wandb 0.17.0__py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (832) hide show
  1. package_readme.md +95 -0
  2. wandb/__init__.py +252 -0
  3. wandb/__main__.py +3 -0
  4. wandb/_globals.py +19 -0
  5. wandb/agents/__init__.py +0 -0
  6. wandb/agents/pyagent.py +364 -0
  7. wandb/analytics/__init__.py +3 -0
  8. wandb/analytics/sentry.py +266 -0
  9. wandb/apis/__init__.py +48 -0
  10. wandb/apis/attrs.py +40 -0
  11. wandb/apis/importers/__init__.py +1 -0
  12. wandb/apis/importers/internals/internal.py +385 -0
  13. wandb/apis/importers/internals/protocols.py +99 -0
  14. wandb/apis/importers/internals/util.py +78 -0
  15. wandb/apis/importers/mlflow.py +254 -0
  16. wandb/apis/importers/validation.py +108 -0
  17. wandb/apis/importers/wandb.py +1598 -0
  18. wandb/apis/internal.py +225 -0
  19. wandb/apis/normalize.py +89 -0
  20. wandb/apis/paginator.py +81 -0
  21. wandb/apis/public/__init__.py +34 -0
  22. wandb/apis/public/api.py +1096 -0
  23. wandb/apis/public/artifacts.py +851 -0
  24. wandb/apis/public/const.py +4 -0
  25. wandb/apis/public/files.py +195 -0
  26. wandb/apis/public/history.py +149 -0
  27. wandb/apis/public/jobs.py +646 -0
  28. wandb/apis/public/projects.py +162 -0
  29. wandb/apis/public/query_generator.py +166 -0
  30. wandb/apis/public/reports.py +469 -0
  31. wandb/apis/public/runs.py +800 -0
  32. wandb/apis/public/sweeps.py +240 -0
  33. wandb/apis/public/teams.py +198 -0
  34. wandb/apis/public/users.py +136 -0
  35. wandb/apis/reports/__init__.py +7 -0
  36. wandb/apis/reports/v1/__init__.py +30 -0
  37. wandb/apis/reports/v1/_blocks.py +1406 -0
  38. wandb/apis/reports/v1/_helpers.py +70 -0
  39. wandb/apis/reports/v1/_panels.py +1282 -0
  40. wandb/apis/reports/v1/_templates.py +478 -0
  41. wandb/apis/reports/v1/blocks.py +27 -0
  42. wandb/apis/reports/v1/helpers.py +2 -0
  43. wandb/apis/reports/v1/mutations.py +66 -0
  44. wandb/apis/reports/v1/panels.py +17 -0
  45. wandb/apis/reports/v1/report.py +268 -0
  46. wandb/apis/reports/v1/runset.py +144 -0
  47. wandb/apis/reports/v1/templates.py +7 -0
  48. wandb/apis/reports/v1/util.py +406 -0
  49. wandb/apis/reports/v1/validators.py +131 -0
  50. wandb/apis/reports/v2/__init__.py +20 -0
  51. wandb/apis/reports/v2/blocks.py +25 -0
  52. wandb/apis/reports/v2/expr_parsing.py +257 -0
  53. wandb/apis/reports/v2/gql.py +68 -0
  54. wandb/apis/reports/v2/interface.py +1911 -0
  55. wandb/apis/reports/v2/internal.py +867 -0
  56. wandb/apis/reports/v2/metrics.py +6 -0
  57. wandb/apis/reports/v2/panels.py +15 -0
  58. wandb/beta/workflows.py +283 -0
  59. wandb/bin/wandb-core +0 -0
  60. wandb/cli/__init__.py +0 -0
  61. wandb/cli/cli.py +2960 -0
  62. wandb/data_types.py +2071 -0
  63. wandb/docker/__init__.py +342 -0
  64. wandb/docker/auth.py +436 -0
  65. wandb/docker/wandb-entrypoint.sh +33 -0
  66. wandb/docker/www_authenticate.py +94 -0
  67. wandb/env.py +496 -0
  68. wandb/errors/__init__.py +46 -0
  69. wandb/errors/term.py +103 -0
  70. wandb/errors/util.py +57 -0
  71. wandb/filesync/__init__.py +0 -0
  72. wandb/filesync/dir_watcher.py +403 -0
  73. wandb/filesync/stats.py +100 -0
  74. wandb/filesync/step_checksum.py +142 -0
  75. wandb/filesync/step_prepare.py +179 -0
  76. wandb/filesync/step_upload.py +290 -0
  77. wandb/filesync/upload_job.py +142 -0
  78. wandb/integration/__init__.py +0 -0
  79. wandb/integration/catboost/__init__.py +5 -0
  80. wandb/integration/catboost/catboost.py +178 -0
  81. wandb/integration/cohere/__init__.py +3 -0
  82. wandb/integration/cohere/cohere.py +21 -0
  83. wandb/integration/cohere/resolver.py +347 -0
  84. wandb/integration/diffusers/__init__.py +3 -0
  85. wandb/integration/diffusers/autologger.py +76 -0
  86. wandb/integration/diffusers/pipeline_resolver.py +50 -0
  87. wandb/integration/diffusers/resolvers/__init__.py +9 -0
  88. wandb/integration/diffusers/resolvers/multimodal.py +882 -0
  89. wandb/integration/diffusers/resolvers/utils.py +102 -0
  90. wandb/integration/fastai/__init__.py +249 -0
  91. wandb/integration/gym/__init__.py +85 -0
  92. wandb/integration/huggingface/__init__.py +3 -0
  93. wandb/integration/huggingface/huggingface.py +18 -0
  94. wandb/integration/huggingface/resolver.py +213 -0
  95. wandb/integration/keras/__init__.py +14 -0
  96. wandb/integration/keras/callbacks/__init__.py +5 -0
  97. wandb/integration/keras/callbacks/metrics_logger.py +130 -0
  98. wandb/integration/keras/callbacks/model_checkpoint.py +200 -0
  99. wandb/integration/keras/callbacks/tables_builder.py +226 -0
  100. wandb/integration/keras/keras.py +1080 -0
  101. wandb/integration/kfp/__init__.py +6 -0
  102. wandb/integration/kfp/helpers.py +28 -0
  103. wandb/integration/kfp/kfp_patch.py +324 -0
  104. wandb/integration/kfp/wandb_logging.py +182 -0
  105. wandb/integration/langchain/__init__.py +3 -0
  106. wandb/integration/langchain/wandb_tracer.py +48 -0
  107. wandb/integration/lightgbm/__init__.py +239 -0
  108. wandb/integration/lightning/__init__.py +0 -0
  109. wandb/integration/lightning/fabric/__init__.py +3 -0
  110. wandb/integration/lightning/fabric/logger.py +762 -0
  111. wandb/integration/magic.py +556 -0
  112. wandb/integration/metaflow/__init__.py +3 -0
  113. wandb/integration/metaflow/metaflow.py +383 -0
  114. wandb/integration/openai/__init__.py +3 -0
  115. wandb/integration/openai/fine_tuning.py +454 -0
  116. wandb/integration/openai/openai.py +22 -0
  117. wandb/integration/openai/resolver.py +240 -0
  118. wandb/integration/prodigy/__init__.py +3 -0
  119. wandb/integration/prodigy/prodigy.py +299 -0
  120. wandb/integration/sacred/__init__.py +117 -0
  121. wandb/integration/sagemaker/__init__.py +12 -0
  122. wandb/integration/sagemaker/auth.py +28 -0
  123. wandb/integration/sagemaker/config.py +49 -0
  124. wandb/integration/sagemaker/files.py +3 -0
  125. wandb/integration/sagemaker/resources.py +34 -0
  126. wandb/integration/sb3/__init__.py +3 -0
  127. wandb/integration/sb3/sb3.py +153 -0
  128. wandb/integration/tensorboard/__init__.py +10 -0
  129. wandb/integration/tensorboard/log.py +358 -0
  130. wandb/integration/tensorboard/monkeypatch.py +185 -0
  131. wandb/integration/tensorflow/__init__.py +5 -0
  132. wandb/integration/tensorflow/estimator_hook.py +54 -0
  133. wandb/integration/torch/__init__.py +0 -0
  134. wandb/integration/ultralytics/__init__.py +11 -0
  135. wandb/integration/ultralytics/bbox_utils.py +208 -0
  136. wandb/integration/ultralytics/callback.py +524 -0
  137. wandb/integration/ultralytics/classification_utils.py +83 -0
  138. wandb/integration/ultralytics/mask_utils.py +202 -0
  139. wandb/integration/ultralytics/pose_utils.py +104 -0
  140. wandb/integration/xgboost/__init__.py +11 -0
  141. wandb/integration/xgboost/xgboost.py +189 -0
  142. wandb/integration/yolov8/__init__.py +0 -0
  143. wandb/integration/yolov8/yolov8.py +284 -0
  144. wandb/jupyter.py +501 -0
  145. wandb/magic.py +3 -0
  146. wandb/mpmain/__init__.py +0 -0
  147. wandb/mpmain/__main__.py +1 -0
  148. wandb/old/__init__.py +0 -0
  149. wandb/old/core.py +131 -0
  150. wandb/old/settings.py +173 -0
  151. wandb/old/summary.py +435 -0
  152. wandb/plot/__init__.py +19 -0
  153. wandb/plot/bar.py +42 -0
  154. wandb/plot/confusion_matrix.py +99 -0
  155. wandb/plot/histogram.py +36 -0
  156. wandb/plot/line.py +40 -0
  157. wandb/plot/line_series.py +88 -0
  158. wandb/plot/pr_curve.py +136 -0
  159. wandb/plot/roc_curve.py +118 -0
  160. wandb/plot/scatter.py +32 -0
  161. wandb/plot/utils.py +183 -0
  162. wandb/proto/__init__.py +0 -0
  163. wandb/proto/v3/__init__.py +0 -0
  164. wandb/proto/v3/wandb_base_pb2.py +54 -0
  165. wandb/proto/v3/wandb_internal_pb2.py +1586 -0
  166. wandb/proto/v3/wandb_server_pb2.py +207 -0
  167. wandb/proto/v3/wandb_settings_pb2.py +111 -0
  168. wandb/proto/v3/wandb_telemetry_pb2.py +105 -0
  169. wandb/proto/v4/__init__.py +0 -0
  170. wandb/proto/v4/wandb_base_pb2.py +29 -0
  171. wandb/proto/v4/wandb_internal_pb2.py +354 -0
  172. wandb/proto/v4/wandb_server_pb2.py +62 -0
  173. wandb/proto/v4/wandb_settings_pb2.py +44 -0
  174. wandb/proto/v4/wandb_telemetry_pb2.py +40 -0
  175. wandb/proto/wandb_base_pb2.py +8 -0
  176. wandb/proto/wandb_deprecated.py +43 -0
  177. wandb/proto/wandb_internal_codegen.py +82 -0
  178. wandb/proto/wandb_internal_pb2.py +8 -0
  179. wandb/proto/wandb_server_pb2.py +8 -0
  180. wandb/proto/wandb_settings_pb2.py +8 -0
  181. wandb/proto/wandb_telemetry_pb2.py +8 -0
  182. wandb/py.typed +0 -0
  183. wandb/sdk/__init__.py +37 -0
  184. wandb/sdk/artifacts/__init__.py +0 -0
  185. wandb/sdk/artifacts/artifact.py +2320 -0
  186. wandb/sdk/artifacts/artifact_download_logger.py +43 -0
  187. wandb/sdk/artifacts/artifact_file_cache.py +229 -0
  188. wandb/sdk/artifacts/artifact_instance_cache.py +15 -0
  189. wandb/sdk/artifacts/artifact_manifest.py +72 -0
  190. wandb/sdk/artifacts/artifact_manifest_entry.py +208 -0
  191. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  192. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +90 -0
  193. wandb/sdk/artifacts/artifact_saver.py +261 -0
  194. wandb/sdk/artifacts/artifact_state.py +11 -0
  195. wandb/sdk/artifacts/artifact_ttl.py +7 -0
  196. wandb/sdk/artifacts/exceptions.py +56 -0
  197. wandb/sdk/artifacts/staging.py +25 -0
  198. wandb/sdk/artifacts/storage_handler.py +60 -0
  199. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  200. wandb/sdk/artifacts/storage_handlers/azure_handler.py +194 -0
  201. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +195 -0
  202. wandb/sdk/artifacts/storage_handlers/http_handler.py +113 -0
  203. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +135 -0
  204. wandb/sdk/artifacts/storage_handlers/multi_handler.py +54 -0
  205. wandb/sdk/artifacts/storage_handlers/s3_handler.py +300 -0
  206. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +68 -0
  207. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +133 -0
  208. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  209. wandb/sdk/artifacts/storage_layout.py +6 -0
  210. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  211. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  212. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +371 -0
  213. wandb/sdk/artifacts/storage_policy.py +72 -0
  214. wandb/sdk/backend/__init__.py +0 -0
  215. wandb/sdk/backend/backend.py +240 -0
  216. wandb/sdk/data_types/__init__.py +0 -0
  217. wandb/sdk/data_types/_dtypes.py +911 -0
  218. wandb/sdk/data_types/_private.py +10 -0
  219. wandb/sdk/data_types/base_types/__init__.py +0 -0
  220. wandb/sdk/data_types/base_types/json_metadata.py +55 -0
  221. wandb/sdk/data_types/base_types/media.py +313 -0
  222. wandb/sdk/data_types/base_types/wb_value.py +274 -0
  223. wandb/sdk/data_types/helper_types/__init__.py +0 -0
  224. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +293 -0
  225. wandb/sdk/data_types/helper_types/classes.py +159 -0
  226. wandb/sdk/data_types/helper_types/image_mask.py +233 -0
  227. wandb/sdk/data_types/histogram.py +96 -0
  228. wandb/sdk/data_types/html.py +115 -0
  229. wandb/sdk/data_types/image.py +687 -0
  230. wandb/sdk/data_types/molecule.py +241 -0
  231. wandb/sdk/data_types/object_3d.py +363 -0
  232. wandb/sdk/data_types/plotly.py +82 -0
  233. wandb/sdk/data_types/saved_model.py +444 -0
  234. wandb/sdk/data_types/trace_tree.py +438 -0
  235. wandb/sdk/data_types/utils.py +180 -0
  236. wandb/sdk/data_types/video.py +245 -0
  237. wandb/sdk/integration_utils/__init__.py +0 -0
  238. wandb/sdk/integration_utils/auto_logging.py +239 -0
  239. wandb/sdk/integration_utils/data_logging.py +475 -0
  240. wandb/sdk/interface/__init__.py +0 -0
  241. wandb/sdk/interface/constants.py +4 -0
  242. wandb/sdk/interface/interface.py +962 -0
  243. wandb/sdk/interface/interface_queue.py +59 -0
  244. wandb/sdk/interface/interface_relay.py +53 -0
  245. wandb/sdk/interface/interface_shared.py +550 -0
  246. wandb/sdk/interface/interface_sock.py +61 -0
  247. wandb/sdk/interface/message_future.py +27 -0
  248. wandb/sdk/interface/message_future_poll.py +50 -0
  249. wandb/sdk/interface/router.py +118 -0
  250. wandb/sdk/interface/router_queue.py +44 -0
  251. wandb/sdk/interface/router_relay.py +39 -0
  252. wandb/sdk/interface/router_sock.py +36 -0
  253. wandb/sdk/interface/summary_record.py +67 -0
  254. wandb/sdk/internal/__init__.py +0 -0
  255. wandb/sdk/internal/context.py +89 -0
  256. wandb/sdk/internal/datastore.py +297 -0
  257. wandb/sdk/internal/file_pusher.py +181 -0
  258. wandb/sdk/internal/file_stream.py +695 -0
  259. wandb/sdk/internal/flow_control.py +263 -0
  260. wandb/sdk/internal/handler.py +909 -0
  261. wandb/sdk/internal/internal.py +417 -0
  262. wandb/sdk/internal/internal_api.py +4074 -0
  263. wandb/sdk/internal/internal_util.py +100 -0
  264. wandb/sdk/internal/job_builder.py +667 -0
  265. wandb/sdk/internal/profiler.py +78 -0
  266. wandb/sdk/internal/progress.py +83 -0
  267. wandb/sdk/internal/run.py +25 -0
  268. wandb/sdk/internal/sample.py +70 -0
  269. wandb/sdk/internal/sender.py +1641 -0
  270. wandb/sdk/internal/sender_config.py +197 -0
  271. wandb/sdk/internal/settings_static.py +83 -0
  272. wandb/sdk/internal/system/__init__.py +0 -0
  273. wandb/sdk/internal/system/assets/__init__.py +27 -0
  274. wandb/sdk/internal/system/assets/aggregators.py +37 -0
  275. wandb/sdk/internal/system/assets/asset_registry.py +20 -0
  276. wandb/sdk/internal/system/assets/cpu.py +163 -0
  277. wandb/sdk/internal/system/assets/disk.py +210 -0
  278. wandb/sdk/internal/system/assets/gpu.py +414 -0
  279. wandb/sdk/internal/system/assets/gpu_amd.py +230 -0
  280. wandb/sdk/internal/system/assets/gpu_apple.py +177 -0
  281. wandb/sdk/internal/system/assets/interfaces.py +207 -0
  282. wandb/sdk/internal/system/assets/ipu.py +177 -0
  283. wandb/sdk/internal/system/assets/memory.py +166 -0
  284. wandb/sdk/internal/system/assets/network.py +125 -0
  285. wandb/sdk/internal/system/assets/open_metrics.py +299 -0
  286. wandb/sdk/internal/system/assets/tpu.py +154 -0
  287. wandb/sdk/internal/system/assets/trainium.py +398 -0
  288. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  289. wandb/sdk/internal/system/system_info.py +247 -0
  290. wandb/sdk/internal/system/system_monitor.py +229 -0
  291. wandb/sdk/internal/tb_watcher.py +518 -0
  292. wandb/sdk/internal/thread_local_settings.py +18 -0
  293. wandb/sdk/internal/update.py +113 -0
  294. wandb/sdk/internal/writer.py +206 -0
  295. wandb/sdk/launch/__init__.py +14 -0
  296. wandb/sdk/launch/_launch.py +328 -0
  297. wandb/sdk/launch/_launch_add.py +255 -0
  298. wandb/sdk/launch/_project_spec.py +538 -0
  299. wandb/sdk/launch/agent/__init__.py +5 -0
  300. wandb/sdk/launch/agent/agent.py +901 -0
  301. wandb/sdk/launch/agent/config.py +296 -0
  302. wandb/sdk/launch/agent/job_status_tracker.py +53 -0
  303. wandb/sdk/launch/agent/run_queue_item_file_saver.py +47 -0
  304. wandb/sdk/launch/builder/__init__.py +0 -0
  305. wandb/sdk/launch/builder/abstract.py +156 -0
  306. wandb/sdk/launch/builder/build.py +280 -0
  307. wandb/sdk/launch/builder/context_manager.py +235 -0
  308. wandb/sdk/launch/builder/docker_builder.py +177 -0
  309. wandb/sdk/launch/builder/kaniko_builder.py +566 -0
  310. wandb/sdk/launch/builder/noop.py +58 -0
  311. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +187 -0
  312. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  313. wandb/sdk/launch/create_job.py +512 -0
  314. wandb/sdk/launch/environment/abstract.py +29 -0
  315. wandb/sdk/launch/environment/aws_environment.py +297 -0
  316. wandb/sdk/launch/environment/azure_environment.py +105 -0
  317. wandb/sdk/launch/environment/gcp_environment.py +335 -0
  318. wandb/sdk/launch/environment/local_environment.py +66 -0
  319. wandb/sdk/launch/errors.py +19 -0
  320. wandb/sdk/launch/git_reference.py +109 -0
  321. wandb/sdk/launch/inputs/files.py +148 -0
  322. wandb/sdk/launch/inputs/internal.py +217 -0
  323. wandb/sdk/launch/inputs/manage.py +95 -0
  324. wandb/sdk/launch/loader.py +249 -0
  325. wandb/sdk/launch/registry/abstract.py +48 -0
  326. wandb/sdk/launch/registry/anon.py +29 -0
  327. wandb/sdk/launch/registry/azure_container_registry.py +124 -0
  328. wandb/sdk/launch/registry/elastic_container_registry.py +192 -0
  329. wandb/sdk/launch/registry/google_artifact_registry.py +219 -0
  330. wandb/sdk/launch/registry/local_registry.py +67 -0
  331. wandb/sdk/launch/runner/__init__.py +0 -0
  332. wandb/sdk/launch/runner/abstract.py +195 -0
  333. wandb/sdk/launch/runner/kubernetes_monitor.py +441 -0
  334. wandb/sdk/launch/runner/kubernetes_runner.py +891 -0
  335. wandb/sdk/launch/runner/local_container.py +298 -0
  336. wandb/sdk/launch/runner/local_process.py +78 -0
  337. wandb/sdk/launch/runner/sagemaker_runner.py +422 -0
  338. wandb/sdk/launch/runner/vertex_runner.py +230 -0
  339. wandb/sdk/launch/sweeps/__init__.py +39 -0
  340. wandb/sdk/launch/sweeps/scheduler.py +738 -0
  341. wandb/sdk/launch/sweeps/scheduler_sweep.py +91 -0
  342. wandb/sdk/launch/sweeps/utils.py +316 -0
  343. wandb/sdk/launch/utils.py +741 -0
  344. wandb/sdk/launch/wandb_reference.py +138 -0
  345. wandb/sdk/lib/__init__.py +5 -0
  346. wandb/sdk/lib/_settings_toposort_generate.py +159 -0
  347. wandb/sdk/lib/_settings_toposort_generated.py +239 -0
  348. wandb/sdk/lib/_wburls_generate.py +25 -0
  349. wandb/sdk/lib/_wburls_generated.py +22 -0
  350. wandb/sdk/lib/apikey.py +258 -0
  351. wandb/sdk/lib/capped_dict.py +26 -0
  352. wandb/sdk/lib/config_util.py +101 -0
  353. wandb/sdk/lib/console.py +39 -0
  354. wandb/sdk/lib/deprecate.py +42 -0
  355. wandb/sdk/lib/disabled.py +190 -0
  356. wandb/sdk/lib/exit_hooks.py +54 -0
  357. wandb/sdk/lib/file_stream_utils.py +118 -0
  358. wandb/sdk/lib/filenames.py +64 -0
  359. wandb/sdk/lib/filesystem.py +372 -0
  360. wandb/sdk/lib/fsm.py +174 -0
  361. wandb/sdk/lib/gitlib.py +239 -0
  362. wandb/sdk/lib/gql_request.py +65 -0
  363. wandb/sdk/lib/handler_util.py +21 -0
  364. wandb/sdk/lib/hashutil.py +62 -0
  365. wandb/sdk/lib/import_hooks.py +275 -0
  366. wandb/sdk/lib/ipython.py +146 -0
  367. wandb/sdk/lib/json_util.py +80 -0
  368. wandb/sdk/lib/lazyloader.py +63 -0
  369. wandb/sdk/lib/mailbox.py +460 -0
  370. wandb/sdk/lib/module.py +69 -0
  371. wandb/sdk/lib/paths.py +106 -0
  372. wandb/sdk/lib/preinit.py +42 -0
  373. wandb/sdk/lib/printer.py +313 -0
  374. wandb/sdk/lib/proto_util.py +90 -0
  375. wandb/sdk/lib/redirect.py +845 -0
  376. wandb/sdk/lib/reporting.py +99 -0
  377. wandb/sdk/lib/retry.py +289 -0
  378. wandb/sdk/lib/run_moment.py +78 -0
  379. wandb/sdk/lib/runid.py +12 -0
  380. wandb/sdk/lib/server.py +52 -0
  381. wandb/sdk/lib/sock_client.py +291 -0
  382. wandb/sdk/lib/sparkline.py +45 -0
  383. wandb/sdk/lib/telemetry.py +100 -0
  384. wandb/sdk/lib/timed_input.py +133 -0
  385. wandb/sdk/lib/timer.py +19 -0
  386. wandb/sdk/lib/tracelog.py +255 -0
  387. wandb/sdk/lib/wburls.py +46 -0
  388. wandb/sdk/service/__init__.py +0 -0
  389. wandb/sdk/service/_startup_debug.py +22 -0
  390. wandb/sdk/service/port_file.py +53 -0
  391. wandb/sdk/service/server.py +119 -0
  392. wandb/sdk/service/server_sock.py +276 -0
  393. wandb/sdk/service/service.py +267 -0
  394. wandb/sdk/service/service_base.py +50 -0
  395. wandb/sdk/service/service_sock.py +70 -0
  396. wandb/sdk/service/streams.py +426 -0
  397. wandb/sdk/verify/__init__.py +0 -0
  398. wandb/sdk/verify/verify.py +501 -0
  399. wandb/sdk/wandb_alerts.py +12 -0
  400. wandb/sdk/wandb_config.py +319 -0
  401. wandb/sdk/wandb_helper.py +54 -0
  402. wandb/sdk/wandb_init.py +1179 -0
  403. wandb/sdk/wandb_login.py +339 -0
  404. wandb/sdk/wandb_manager.py +222 -0
  405. wandb/sdk/wandb_metric.py +110 -0
  406. wandb/sdk/wandb_require.py +92 -0
  407. wandb/sdk/wandb_require_helpers.py +44 -0
  408. wandb/sdk/wandb_run.py +4247 -0
  409. wandb/sdk/wandb_settings.py +1926 -0
  410. wandb/sdk/wandb_setup.py +335 -0
  411. wandb/sdk/wandb_summary.py +150 -0
  412. wandb/sdk/wandb_sweep.py +114 -0
  413. wandb/sdk/wandb_sync.py +75 -0
  414. wandb/sdk/wandb_watch.py +128 -0
  415. wandb/sklearn/__init__.py +37 -0
  416. wandb/sklearn/calculate/__init__.py +32 -0
  417. wandb/sklearn/calculate/calibration_curves.py +125 -0
  418. wandb/sklearn/calculate/class_proportions.py +68 -0
  419. wandb/sklearn/calculate/confusion_matrix.py +92 -0
  420. wandb/sklearn/calculate/decision_boundaries.py +40 -0
  421. wandb/sklearn/calculate/elbow_curve.py +55 -0
  422. wandb/sklearn/calculate/feature_importances.py +67 -0
  423. wandb/sklearn/calculate/learning_curve.py +64 -0
  424. wandb/sklearn/calculate/outlier_candidates.py +69 -0
  425. wandb/sklearn/calculate/residuals.py +86 -0
  426. wandb/sklearn/calculate/silhouette.py +118 -0
  427. wandb/sklearn/calculate/summary_metrics.py +62 -0
  428. wandb/sklearn/plot/__init__.py +35 -0
  429. wandb/sklearn/plot/classifier.py +329 -0
  430. wandb/sklearn/plot/clusterer.py +142 -0
  431. wandb/sklearn/plot/regressor.py +121 -0
  432. wandb/sklearn/plot/shared.py +91 -0
  433. wandb/sklearn/utils.py +183 -0
  434. wandb/sync/__init__.py +3 -0
  435. wandb/sync/sync.py +443 -0
  436. wandb/testing/relay.py +859 -0
  437. wandb/trigger.py +29 -0
  438. wandb/util.py +1915 -0
  439. wandb/vendor/__init__.py +0 -0
  440. wandb/vendor/gql-0.2.0/setup.py +40 -0
  441. wandb/vendor/gql-0.2.0/tests/__init__.py +0 -0
  442. wandb/vendor/gql-0.2.0/tests/starwars/__init__.py +0 -0
  443. wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py +96 -0
  444. wandb/vendor/gql-0.2.0/tests/starwars/schema.py +146 -0
  445. wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py +293 -0
  446. wandb/vendor/gql-0.2.0/tests/starwars/test_query.py +355 -0
  447. wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py +171 -0
  448. wandb/vendor/gql-0.2.0/tests/test_client.py +31 -0
  449. wandb/vendor/gql-0.2.0/tests/test_transport.py +89 -0
  450. wandb/vendor/gql-0.2.0/wandb_gql/__init__.py +4 -0
  451. wandb/vendor/gql-0.2.0/wandb_gql/client.py +75 -0
  452. wandb/vendor/gql-0.2.0/wandb_gql/dsl.py +152 -0
  453. wandb/vendor/gql-0.2.0/wandb_gql/gql.py +10 -0
  454. wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py +0 -0
  455. wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py +6 -0
  456. wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py +15 -0
  457. wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py +46 -0
  458. wandb/vendor/gql-0.2.0/wandb_gql/utils.py +21 -0
  459. wandb/vendor/graphql-core-1.1/setup.py +86 -0
  460. wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py +287 -0
  461. wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py +6 -0
  462. wandb/vendor/graphql-core-1.1/wandb_graphql/error/base.py +42 -0
  463. wandb/vendor/graphql-core-1.1/wandb_graphql/error/format_error.py +11 -0
  464. wandb/vendor/graphql-core-1.1/wandb_graphql/error/located_error.py +29 -0
  465. wandb/vendor/graphql-core-1.1/wandb_graphql/error/syntax_error.py +36 -0
  466. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__init__.py +26 -0
  467. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py +311 -0
  468. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executor.py +398 -0
  469. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py +0 -0
  470. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py +53 -0
  471. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py +22 -0
  472. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py +32 -0
  473. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py +7 -0
  474. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py +35 -0
  475. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py +6 -0
  476. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py +0 -0
  477. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py +66 -0
  478. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py +252 -0
  479. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py +151 -0
  480. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py +7 -0
  481. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py +57 -0
  482. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py +145 -0
  483. wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py +60 -0
  484. wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py +0 -0
  485. wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py +1349 -0
  486. wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py +19 -0
  487. wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py +435 -0
  488. wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py +30 -0
  489. wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py +779 -0
  490. wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py +193 -0
  491. wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py +18 -0
  492. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py +222 -0
  493. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py +82 -0
  494. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__init__.py +0 -0
  495. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py +17 -0
  496. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/contain_subset.py +28 -0
  497. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/default_ordered_dict.py +40 -0
  498. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/ordereddict.py +8 -0
  499. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/pair_set.py +43 -0
  500. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/version.py +78 -0
  501. wandb/vendor/graphql-core-1.1/wandb_graphql/type/__init__.py +67 -0
  502. wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py +619 -0
  503. wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py +132 -0
  504. wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py +440 -0
  505. wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py +131 -0
  506. wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py +100 -0
  507. wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py +145 -0
  508. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py +0 -0
  509. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py +9 -0
  510. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py +65 -0
  511. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py +49 -0
  512. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py +24 -0
  513. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py +75 -0
  514. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py +291 -0
  515. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_client_schema.py +250 -0
  516. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/concat_ast.py +9 -0
  517. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/extend_schema.py +357 -0
  518. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_field_def.py +27 -0
  519. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_operation_ast.py +21 -0
  520. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/introspection_query.py +90 -0
  521. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_literal_value.py +67 -0
  522. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_value.py +66 -0
  523. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/quoted_or_list.py +21 -0
  524. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/schema_printer.py +168 -0
  525. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/suggestion_list.py +56 -0
  526. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_comparators.py +69 -0
  527. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_from_ast.py +21 -0
  528. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_info.py +149 -0
  529. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/value_from_ast.py +69 -0
  530. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py +4 -0
  531. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py +79 -0
  532. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py +24 -0
  533. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py +8 -0
  534. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py +44 -0
  535. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py +113 -0
  536. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py +33 -0
  537. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py +70 -0
  538. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py +97 -0
  539. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py +19 -0
  540. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py +43 -0
  541. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py +23 -0
  542. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py +59 -0
  543. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py +36 -0
  544. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py +38 -0
  545. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py +37 -0
  546. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py +529 -0
  547. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py +44 -0
  548. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py +46 -0
  549. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py +33 -0
  550. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py +32 -0
  551. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py +28 -0
  552. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py +33 -0
  553. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py +31 -0
  554. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py +27 -0
  555. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py +21 -0
  556. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py +53 -0
  557. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py +158 -0
  558. wandb/vendor/promise-2.3.0/conftest.py +30 -0
  559. wandb/vendor/promise-2.3.0/setup.py +64 -0
  560. wandb/vendor/promise-2.3.0/tests/__init__.py +0 -0
  561. wandb/vendor/promise-2.3.0/tests/conftest.py +8 -0
  562. wandb/vendor/promise-2.3.0/tests/test_awaitable.py +32 -0
  563. wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py +47 -0
  564. wandb/vendor/promise-2.3.0/tests/test_benchmark.py +116 -0
  565. wandb/vendor/promise-2.3.0/tests/test_complex_threads.py +23 -0
  566. wandb/vendor/promise-2.3.0/tests/test_dataloader.py +452 -0
  567. wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py +99 -0
  568. wandb/vendor/promise-2.3.0/tests/test_dataloader_extra.py +65 -0
  569. wandb/vendor/promise-2.3.0/tests/test_extra.py +670 -0
  570. wandb/vendor/promise-2.3.0/tests/test_issues.py +132 -0
  571. wandb/vendor/promise-2.3.0/tests/test_promise_list.py +70 -0
  572. wandb/vendor/promise-2.3.0/tests/test_spec.py +584 -0
  573. wandb/vendor/promise-2.3.0/tests/test_thread_safety.py +115 -0
  574. wandb/vendor/promise-2.3.0/tests/utils.py +3 -0
  575. wandb/vendor/promise-2.3.0/wandb_promise/__init__.py +38 -0
  576. wandb/vendor/promise-2.3.0/wandb_promise/async_.py +135 -0
  577. wandb/vendor/promise-2.3.0/wandb_promise/compat.py +32 -0
  578. wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py +326 -0
  579. wandb/vendor/promise-2.3.0/wandb_promise/iterate_promise.py +12 -0
  580. wandb/vendor/promise-2.3.0/wandb_promise/promise.py +848 -0
  581. wandb/vendor/promise-2.3.0/wandb_promise/promise_list.py +151 -0
  582. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/__init__.py +0 -0
  583. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/version.py +83 -0
  584. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/__init__.py +0 -0
  585. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/asyncio.py +22 -0
  586. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/gevent.py +21 -0
  587. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/immediate.py +27 -0
  588. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/thread.py +18 -0
  589. wandb/vendor/promise-2.3.0/wandb_promise/utils.py +56 -0
  590. wandb/vendor/pygments/__init__.py +90 -0
  591. wandb/vendor/pygments/cmdline.py +568 -0
  592. wandb/vendor/pygments/console.py +74 -0
  593. wandb/vendor/pygments/filter.py +74 -0
  594. wandb/vendor/pygments/filters/__init__.py +350 -0
  595. wandb/vendor/pygments/formatter.py +95 -0
  596. wandb/vendor/pygments/formatters/__init__.py +153 -0
  597. wandb/vendor/pygments/formatters/_mapping.py +85 -0
  598. wandb/vendor/pygments/formatters/bbcode.py +109 -0
  599. wandb/vendor/pygments/formatters/html.py +851 -0
  600. wandb/vendor/pygments/formatters/img.py +600 -0
  601. wandb/vendor/pygments/formatters/irc.py +182 -0
  602. wandb/vendor/pygments/formatters/latex.py +482 -0
  603. wandb/vendor/pygments/formatters/other.py +160 -0
  604. wandb/vendor/pygments/formatters/rtf.py +147 -0
  605. wandb/vendor/pygments/formatters/svg.py +153 -0
  606. wandb/vendor/pygments/formatters/terminal.py +136 -0
  607. wandb/vendor/pygments/formatters/terminal256.py +309 -0
  608. wandb/vendor/pygments/lexer.py +871 -0
  609. wandb/vendor/pygments/lexers/__init__.py +329 -0
  610. wandb/vendor/pygments/lexers/_asy_builtins.py +1645 -0
  611. wandb/vendor/pygments/lexers/_cl_builtins.py +232 -0
  612. wandb/vendor/pygments/lexers/_cocoa_builtins.py +72 -0
  613. wandb/vendor/pygments/lexers/_csound_builtins.py +1346 -0
  614. wandb/vendor/pygments/lexers/_lasso_builtins.py +5327 -0
  615. wandb/vendor/pygments/lexers/_lua_builtins.py +295 -0
  616. wandb/vendor/pygments/lexers/_mapping.py +500 -0
  617. wandb/vendor/pygments/lexers/_mql_builtins.py +1172 -0
  618. wandb/vendor/pygments/lexers/_openedge_builtins.py +2547 -0
  619. wandb/vendor/pygments/lexers/_php_builtins.py +4756 -0
  620. wandb/vendor/pygments/lexers/_postgres_builtins.py +621 -0
  621. wandb/vendor/pygments/lexers/_scilab_builtins.py +3094 -0
  622. wandb/vendor/pygments/lexers/_sourcemod_builtins.py +1163 -0
  623. wandb/vendor/pygments/lexers/_stan_builtins.py +532 -0
  624. wandb/vendor/pygments/lexers/_stata_builtins.py +419 -0
  625. wandb/vendor/pygments/lexers/_tsql_builtins.py +1004 -0
  626. wandb/vendor/pygments/lexers/_vim_builtins.py +1939 -0
  627. wandb/vendor/pygments/lexers/actionscript.py +240 -0
  628. wandb/vendor/pygments/lexers/agile.py +24 -0
  629. wandb/vendor/pygments/lexers/algebra.py +221 -0
  630. wandb/vendor/pygments/lexers/ambient.py +76 -0
  631. wandb/vendor/pygments/lexers/ampl.py +87 -0
  632. wandb/vendor/pygments/lexers/apl.py +101 -0
  633. wandb/vendor/pygments/lexers/archetype.py +318 -0
  634. wandb/vendor/pygments/lexers/asm.py +641 -0
  635. wandb/vendor/pygments/lexers/automation.py +374 -0
  636. wandb/vendor/pygments/lexers/basic.py +500 -0
  637. wandb/vendor/pygments/lexers/bibtex.py +160 -0
  638. wandb/vendor/pygments/lexers/business.py +612 -0
  639. wandb/vendor/pygments/lexers/c_cpp.py +252 -0
  640. wandb/vendor/pygments/lexers/c_like.py +541 -0
  641. wandb/vendor/pygments/lexers/capnproto.py +78 -0
  642. wandb/vendor/pygments/lexers/chapel.py +102 -0
  643. wandb/vendor/pygments/lexers/clean.py +288 -0
  644. wandb/vendor/pygments/lexers/compiled.py +34 -0
  645. wandb/vendor/pygments/lexers/configs.py +833 -0
  646. wandb/vendor/pygments/lexers/console.py +114 -0
  647. wandb/vendor/pygments/lexers/crystal.py +393 -0
  648. wandb/vendor/pygments/lexers/csound.py +366 -0
  649. wandb/vendor/pygments/lexers/css.py +689 -0
  650. wandb/vendor/pygments/lexers/d.py +251 -0
  651. wandb/vendor/pygments/lexers/dalvik.py +125 -0
  652. wandb/vendor/pygments/lexers/data.py +555 -0
  653. wandb/vendor/pygments/lexers/diff.py +165 -0
  654. wandb/vendor/pygments/lexers/dotnet.py +691 -0
  655. wandb/vendor/pygments/lexers/dsls.py +878 -0
  656. wandb/vendor/pygments/lexers/dylan.py +289 -0
  657. wandb/vendor/pygments/lexers/ecl.py +125 -0
  658. wandb/vendor/pygments/lexers/eiffel.py +65 -0
  659. wandb/vendor/pygments/lexers/elm.py +121 -0
  660. wandb/vendor/pygments/lexers/erlang.py +533 -0
  661. wandb/vendor/pygments/lexers/esoteric.py +277 -0
  662. wandb/vendor/pygments/lexers/ezhil.py +69 -0
  663. wandb/vendor/pygments/lexers/factor.py +344 -0
  664. wandb/vendor/pygments/lexers/fantom.py +250 -0
  665. wandb/vendor/pygments/lexers/felix.py +273 -0
  666. wandb/vendor/pygments/lexers/forth.py +177 -0
  667. wandb/vendor/pygments/lexers/fortran.py +205 -0
  668. wandb/vendor/pygments/lexers/foxpro.py +428 -0
  669. wandb/vendor/pygments/lexers/functional.py +21 -0
  670. wandb/vendor/pygments/lexers/go.py +101 -0
  671. wandb/vendor/pygments/lexers/grammar_notation.py +213 -0
  672. wandb/vendor/pygments/lexers/graph.py +80 -0
  673. wandb/vendor/pygments/lexers/graphics.py +553 -0
  674. wandb/vendor/pygments/lexers/haskell.py +843 -0
  675. wandb/vendor/pygments/lexers/haxe.py +936 -0
  676. wandb/vendor/pygments/lexers/hdl.py +382 -0
  677. wandb/vendor/pygments/lexers/hexdump.py +103 -0
  678. wandb/vendor/pygments/lexers/html.py +602 -0
  679. wandb/vendor/pygments/lexers/idl.py +270 -0
  680. wandb/vendor/pygments/lexers/igor.py +288 -0
  681. wandb/vendor/pygments/lexers/inferno.py +96 -0
  682. wandb/vendor/pygments/lexers/installers.py +322 -0
  683. wandb/vendor/pygments/lexers/int_fiction.py +1343 -0
  684. wandb/vendor/pygments/lexers/iolang.py +63 -0
  685. wandb/vendor/pygments/lexers/j.py +146 -0
  686. wandb/vendor/pygments/lexers/javascript.py +1525 -0
  687. wandb/vendor/pygments/lexers/julia.py +333 -0
  688. wandb/vendor/pygments/lexers/jvm.py +1573 -0
  689. wandb/vendor/pygments/lexers/lisp.py +2621 -0
  690. wandb/vendor/pygments/lexers/make.py +202 -0
  691. wandb/vendor/pygments/lexers/markup.py +595 -0
  692. wandb/vendor/pygments/lexers/math.py +21 -0
  693. wandb/vendor/pygments/lexers/matlab.py +663 -0
  694. wandb/vendor/pygments/lexers/ml.py +769 -0
  695. wandb/vendor/pygments/lexers/modeling.py +358 -0
  696. wandb/vendor/pygments/lexers/modula2.py +1561 -0
  697. wandb/vendor/pygments/lexers/monte.py +204 -0
  698. wandb/vendor/pygments/lexers/ncl.py +894 -0
  699. wandb/vendor/pygments/lexers/nimrod.py +159 -0
  700. wandb/vendor/pygments/lexers/nit.py +64 -0
  701. wandb/vendor/pygments/lexers/nix.py +136 -0
  702. wandb/vendor/pygments/lexers/oberon.py +105 -0
  703. wandb/vendor/pygments/lexers/objective.py +504 -0
  704. wandb/vendor/pygments/lexers/ooc.py +85 -0
  705. wandb/vendor/pygments/lexers/other.py +41 -0
  706. wandb/vendor/pygments/lexers/parasail.py +79 -0
  707. wandb/vendor/pygments/lexers/parsers.py +835 -0
  708. wandb/vendor/pygments/lexers/pascal.py +644 -0
  709. wandb/vendor/pygments/lexers/pawn.py +199 -0
  710. wandb/vendor/pygments/lexers/perl.py +620 -0
  711. wandb/vendor/pygments/lexers/php.py +267 -0
  712. wandb/vendor/pygments/lexers/praat.py +294 -0
  713. wandb/vendor/pygments/lexers/prolog.py +306 -0
  714. wandb/vendor/pygments/lexers/python.py +939 -0
  715. wandb/vendor/pygments/lexers/qvt.py +152 -0
  716. wandb/vendor/pygments/lexers/r.py +453 -0
  717. wandb/vendor/pygments/lexers/rdf.py +270 -0
  718. wandb/vendor/pygments/lexers/rebol.py +431 -0
  719. wandb/vendor/pygments/lexers/resource.py +85 -0
  720. wandb/vendor/pygments/lexers/rnc.py +67 -0
  721. wandb/vendor/pygments/lexers/roboconf.py +82 -0
  722. wandb/vendor/pygments/lexers/robotframework.py +560 -0
  723. wandb/vendor/pygments/lexers/ruby.py +519 -0
  724. wandb/vendor/pygments/lexers/rust.py +220 -0
  725. wandb/vendor/pygments/lexers/sas.py +228 -0
  726. wandb/vendor/pygments/lexers/scripting.py +1222 -0
  727. wandb/vendor/pygments/lexers/shell.py +794 -0
  728. wandb/vendor/pygments/lexers/smalltalk.py +195 -0
  729. wandb/vendor/pygments/lexers/smv.py +79 -0
  730. wandb/vendor/pygments/lexers/snobol.py +83 -0
  731. wandb/vendor/pygments/lexers/special.py +103 -0
  732. wandb/vendor/pygments/lexers/sql.py +681 -0
  733. wandb/vendor/pygments/lexers/stata.py +108 -0
  734. wandb/vendor/pygments/lexers/supercollider.py +90 -0
  735. wandb/vendor/pygments/lexers/tcl.py +145 -0
  736. wandb/vendor/pygments/lexers/templates.py +2283 -0
  737. wandb/vendor/pygments/lexers/testing.py +207 -0
  738. wandb/vendor/pygments/lexers/text.py +25 -0
  739. wandb/vendor/pygments/lexers/textedit.py +169 -0
  740. wandb/vendor/pygments/lexers/textfmts.py +297 -0
  741. wandb/vendor/pygments/lexers/theorem.py +458 -0
  742. wandb/vendor/pygments/lexers/trafficscript.py +54 -0
  743. wandb/vendor/pygments/lexers/typoscript.py +226 -0
  744. wandb/vendor/pygments/lexers/urbi.py +133 -0
  745. wandb/vendor/pygments/lexers/varnish.py +190 -0
  746. wandb/vendor/pygments/lexers/verification.py +111 -0
  747. wandb/vendor/pygments/lexers/web.py +24 -0
  748. wandb/vendor/pygments/lexers/webmisc.py +988 -0
  749. wandb/vendor/pygments/lexers/whiley.py +116 -0
  750. wandb/vendor/pygments/lexers/x10.py +69 -0
  751. wandb/vendor/pygments/modeline.py +44 -0
  752. wandb/vendor/pygments/plugin.py +68 -0
  753. wandb/vendor/pygments/regexopt.py +92 -0
  754. wandb/vendor/pygments/scanner.py +105 -0
  755. wandb/vendor/pygments/sphinxext.py +158 -0
  756. wandb/vendor/pygments/style.py +155 -0
  757. wandb/vendor/pygments/styles/__init__.py +80 -0
  758. wandb/vendor/pygments/styles/abap.py +29 -0
  759. wandb/vendor/pygments/styles/algol.py +63 -0
  760. wandb/vendor/pygments/styles/algol_nu.py +63 -0
  761. wandb/vendor/pygments/styles/arduino.py +98 -0
  762. wandb/vendor/pygments/styles/autumn.py +65 -0
  763. wandb/vendor/pygments/styles/borland.py +51 -0
  764. wandb/vendor/pygments/styles/bw.py +49 -0
  765. wandb/vendor/pygments/styles/colorful.py +81 -0
  766. wandb/vendor/pygments/styles/default.py +73 -0
  767. wandb/vendor/pygments/styles/emacs.py +72 -0
  768. wandb/vendor/pygments/styles/friendly.py +72 -0
  769. wandb/vendor/pygments/styles/fruity.py +42 -0
  770. wandb/vendor/pygments/styles/igor.py +29 -0
  771. wandb/vendor/pygments/styles/lovelace.py +97 -0
  772. wandb/vendor/pygments/styles/manni.py +75 -0
  773. wandb/vendor/pygments/styles/monokai.py +106 -0
  774. wandb/vendor/pygments/styles/murphy.py +80 -0
  775. wandb/vendor/pygments/styles/native.py +65 -0
  776. wandb/vendor/pygments/styles/paraiso_dark.py +125 -0
  777. wandb/vendor/pygments/styles/paraiso_light.py +125 -0
  778. wandb/vendor/pygments/styles/pastie.py +75 -0
  779. wandb/vendor/pygments/styles/perldoc.py +69 -0
  780. wandb/vendor/pygments/styles/rainbow_dash.py +89 -0
  781. wandb/vendor/pygments/styles/rrt.py +33 -0
  782. wandb/vendor/pygments/styles/sas.py +44 -0
  783. wandb/vendor/pygments/styles/stata.py +40 -0
  784. wandb/vendor/pygments/styles/tango.py +141 -0
  785. wandb/vendor/pygments/styles/trac.py +63 -0
  786. wandb/vendor/pygments/styles/vim.py +63 -0
  787. wandb/vendor/pygments/styles/vs.py +38 -0
  788. wandb/vendor/pygments/styles/xcode.py +51 -0
  789. wandb/vendor/pygments/token.py +213 -0
  790. wandb/vendor/pygments/unistring.py +217 -0
  791. wandb/vendor/pygments/util.py +388 -0
  792. wandb/vendor/pynvml/__init__.py +0 -0
  793. wandb/vendor/pynvml/pynvml.py +4779 -0
  794. wandb/vendor/watchdog_0_9_0/wandb_watchdog/__init__.py +17 -0
  795. wandb/vendor/watchdog_0_9_0/wandb_watchdog/events.py +615 -0
  796. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/__init__.py +98 -0
  797. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/api.py +369 -0
  798. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents.py +172 -0
  799. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents2.py +239 -0
  800. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify.py +218 -0
  801. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_buffer.py +81 -0
  802. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_c.py +575 -0
  803. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/kqueue.py +730 -0
  804. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/polling.py +145 -0
  805. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/read_directory_changes.py +133 -0
  806. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/winapi.py +348 -0
  807. wandb/vendor/watchdog_0_9_0/wandb_watchdog/patterns.py +265 -0
  808. wandb/vendor/watchdog_0_9_0/wandb_watchdog/tricks/__init__.py +174 -0
  809. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/__init__.py +151 -0
  810. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/bricks.py +249 -0
  811. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/compat.py +29 -0
  812. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/decorators.py +198 -0
  813. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/delayed_queue.py +88 -0
  814. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/dirsnapshot.py +293 -0
  815. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/echo.py +157 -0
  816. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/event_backport.py +41 -0
  817. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/importlib2.py +40 -0
  818. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/platform.py +57 -0
  819. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/unicode_paths.py +64 -0
  820. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/win32stat.py +123 -0
  821. wandb/vendor/watchdog_0_9_0/wandb_watchdog/version.py +28 -0
  822. wandb/vendor/watchdog_0_9_0/wandb_watchdog/watchmedo.py +577 -0
  823. wandb/viz.py +123 -0
  824. wandb/wandb_agent.py +586 -0
  825. wandb/wandb_controller.py +720 -0
  826. wandb/wandb_run.py +9 -0
  827. wandb/wandb_torch.py +550 -0
  828. wandb-0.17.0.dist-info/METADATA +217 -0
  829. wandb-0.17.0.dist-info/RECORD +832 -0
  830. wandb-0.17.0.dist-info/WHEEL +6 -0
  831. wandb-0.17.0.dist-info/entry_points.txt +3 -0
  832. wandb-0.17.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,4074 @@
1
+ import ast
2
+ import base64
3
+ import datetime
4
+ import functools
5
+ import http.client
6
+ import json
7
+ import logging
8
+ import os
9
+ import re
10
+ import socket
11
+ import sys
12
+ import threading
13
+ from copy import deepcopy
14
+ from typing import (
15
+ IO,
16
+ TYPE_CHECKING,
17
+ Any,
18
+ Callable,
19
+ Dict,
20
+ Iterable,
21
+ List,
22
+ Mapping,
23
+ MutableMapping,
24
+ Optional,
25
+ Sequence,
26
+ TextIO,
27
+ Tuple,
28
+ Union,
29
+ )
30
+
31
+ import click
32
+ import requests
33
+ import yaml
34
+ from wandb_gql import Client, gql
35
+ from wandb_gql.client import RetryError
36
+
37
+ import wandb
38
+ from wandb import env, util
39
+ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
40
+ from wandb.errors import CommError, UnsupportedError, UsageError
41
+ from wandb.integration.sagemaker import parse_sm_secrets
42
+ from wandb.old.settings import Settings
43
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
44
+ from wandb.sdk.lib.gql_request import GraphQLSession
45
+ from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
46
+
47
+ from ..lib import retry
48
+ from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
49
+ from ..lib.gitlib import GitRepo
50
+ from . import context
51
+ from .progress import Progress
52
+
53
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Everything in this branch exists only for static type checkers;
    # TYPE_CHECKING is False at runtime, so none of these names get defined.
    if sys.version_info >= (3, 8):
        from typing import Literal, TypedDict
    else:
        from typing_extensions import Literal, TypedDict

    from .progress import ProgressFn

    # Leaf payload shapes come first so the composite shapes below can refer
    # to them directly instead of through string forward references.
    class UploadUrlParts(TypedDict):
        partNumber: int  # noqa: N815
        uploadUrl: str  # noqa: N815

    class UploadPartsResponse(TypedDict):
        uploadUrlParts: List[UploadUrlParts]  # noqa: N815
        uploadID: str  # noqa: N815

    class CreateArtifactFilesResponseFileNode(TypedDict):
        id: str

    class CreateArtifactFilesResponseFile(TypedDict):
        id: str
        name: str
        displayName: str  # noqa: N815
        uploadUrl: Optional[str]  # noqa: N815
        uploadHeaders: Sequence[str]  # noqa: N815
        uploadMultipartUrls: UploadPartsResponse  # noqa: N815
        storagePath: str  # noqa: N815
        artifact: CreateArtifactFilesResponseFileNode

    class CreateArtifactFileSpecInput(TypedDict, total=False):
        """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""

        artifactID: str  # noqa: N815
        name: str
        md5: str
        mimetype: Optional[str]
        artifactManifestID: Optional[str]  # noqa: N815
        uploadPartsInput: Optional[List[Dict[str, object]]]  # noqa: N815

    class CompleteMultipartUploadArtifactInput(TypedDict):
        """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""

        completeMultipartAction: str  # noqa: N815
        completedParts: Dict[int, str]  # noqa: N815
        artifactID: str  # noqa: N815
        storagePath: str  # noqa: N815
        uploadID: str  # noqa: N815
        md5: str

    class CompleteMultipartUploadArtifactResponse(TypedDict):
        digest: str

    # Shape of Api.default_settings.
    class DefaultSettings(TypedDict):
        section: str
        git_remote: str
        ignore_globs: Optional[List[str]]
        base_url: Optional[str]
        root_dir: Optional[str]
        api_key: Optional[str]
        entity: Optional[str]
        project: Optional[str]
        _extra_http_headers: Optional[Mapping[str, str]]
        _proxies: Optional[Mapping[str, str]]

    _Response = MutableMapping
    SweepState = Literal["RUNNING", "PAUSED", "CANCELED", "FINISHED"]
    Number = Union[int, float]
123
+ # class _MappingSupportsCopy(Protocol):
124
+ # def copy(self) -> "_MappingSupportsCopy": ...
125
+ # def keys(self) -> Iterable: ...
126
+ # def __getitem__(self, name: str) -> Any: ...
127
+
128
# Logger that receives http.client's debug output once
# check_httpclient_logger_handler() redirects it; it is only lowered to DEBUG
# when WANDB_DEBUG is set in the environment.
httpclient_logger = logging.getLogger("http.client")
if os.environ.get("WANDB_DEBUG"):
    httpclient_logger.setLevel(logging.DEBUG)
131
+
132
+
133
def check_httpclient_logger_handler() -> None:
    """Route ``http.client`` debug output into logging when ``WANDB_DEBUG`` is set.

    Does nothing unless ``WANDB_DEBUG`` is set in ``os.environ``, or when
    ``httpclient_logger`` already has a handler. Otherwise it:

    * replaces ``http.client``'s module-level ``print`` with a function that
      logs through ``httpclient_logger`` at DEBUG level,
    * sets ``http.client.HTTPConnection.debuglevel = 1`` (a class attribute,
      so it affects every connection in the process), and
    * attaches the first handler of the ``wandb`` logger, if it has one.
    """
    # Only enable http.client logging if WANDB_DEBUG is set
    if not os.environ.get("WANDB_DEBUG"):
        return
    if httpclient_logger.handlers:
        return

    # Enable HTTPConnection debug logging to the logging framework
    level = logging.DEBUG

    def httpclient_log(*args: Any, **kwargs: Any) -> None:
        # Stand-in for the print() built-in, so it must accept print's
        # signature: positionals of any type (converted with str(), as print
        # does) and print's keyword arguments (sep=, end=, file=, flush=).
        # A plain " ".join(args) raises TypeError on non-str arguments or
        # any keyword, which would surface inside http.client itself.
        sep = kwargs.get("sep")
        if sep is None:
            sep = " "
        httpclient_logger.log(level, sep.join(str(arg) for arg in args))

    # mask the print() built-in in the http.client module to use logging instead
    http.client.print = httpclient_log  # type: ignore[attr-defined]
    # enable debugging
    http.client.HTTPConnection.debuglevel = 1

    root_logger = logging.getLogger("wandb")
    if root_logger.handlers:
        httpclient_logger.addHandler(root_logger.handlers[0])
154
+
155
+
156
class _ThreadLocalData(threading.local):
    """Per-thread holder for an optional API context override.

    Being a ``threading.local`` subclass, every thread sees its own
    ``context`` attribute, and ``__init__`` runs for each thread that
    touches the instance, so each thread starts without an override.
    """

    context: Optional[context.Context]

    def __init__(self) -> None:
        # No override until one is installed for this thread.
        self.context = None
161
+
162
+
163
+ class Api:
164
+ """W&B Internal Api wrapper.
165
+
166
+ Note:
167
+ Settings are automatically overridden by looking for
168
+ a `wandb/settings` file in the current working directory or its parent
169
+ directory. If none can be found, we look in the current user's home
170
+ directory.
171
+
172
+ Arguments:
173
+ default_settings(dict, optional): If you aren't using a settings
174
+ file, or you wish to override the section to use in the settings file
175
+ Override the settings here.
176
+ """
177
+
178
+ HTTP_TIMEOUT = env.get_http_timeout(20)
179
+ FILE_PUSHER_TIMEOUT = env.get_file_pusher_timeout()
180
+ _global_context: context.Context
181
+ _local_data: _ThreadLocalData
182
+
183
def __init__(
    self,
    default_settings: Optional[
        Union[
            "wandb.sdk.wandb_settings.Settings",
            "wandb.sdk.internal.settings_static.SettingsStatic",
            Settings,
            dict,
        ]
    ] = None,
    load_settings: bool = True,
    retry_timedelta: datetime.timedelta = datetime.timedelta(  # noqa: B008 # okay because it's immutable
        days=7
    ),
    environ: MutableMapping = os.environ,
    retry_callback: Optional[Callable[[int, str], Any]] = None,
) -> None:
    """Build the internal API client.

    Arguments:
        default_settings: Values layered over the built-in defaults below
            (``base_url``, ``api_key``, ``entity``, ``_extra_http_headers``, ...).
        load_settings: Forwarded to ``Settings`` as ``load_settings``.
        retry_timedelta: Retry window used for GraphQL calls and file uploads.
        environ: Mapping used for environment lookups (defaults to ``os.environ``).
        retry_callback: Forwarded to the GraphQL ``retry.Retry`` wrapper.
    """
    self._environ = environ
    self._global_context = context.Context()
    self._local_data = _ThreadLocalData()
    self.default_settings: DefaultSettings = {
        "section": "default",
        "git_remote": "origin",
        "ignore_globs": [],
        "base_url": "https://api.wandb.ai",
        "root_dir": None,
        "api_key": None,
        "entity": None,
        "project": None,
        "_extra_http_headers": None,
        "_proxies": None,
    }
    self.retry_timedelta = retry_timedelta
    # todo: Old Settings do not follow the SupportsKeysAndGetItem Protocol
    default_settings = default_settings or {}
    self.default_settings.update(default_settings)  # type: ignore
    self.retry_uploads = 10
    self._settings = Settings(
        load_settings=load_settings,
        root_dir=self.default_settings.get("root_dir"),
    )
    self.git = GitRepo(remote=self.settings("git_remote"))
    # Mutable settings set by the _file_stream_api
    self.dynamic_settings = {
        "system_sample_seconds": 2,
        "system_samples": 15,
        "heartbeat_seconds": 30,
    }

    # todo: remove these hacky hacks after settings refactor is complete
    # keeping this code here to limit scope and so that it is easy to remove later
    #
    # Take a private copy: settings() only shallow-copies default_settings, so
    # the configured value can be the very mapping the caller passed in (and
    # it is typed as a read-only Mapping). Updating it in place below would
    # leak this thread's headers into the caller's settings object and into
    # every later Api built from it.
    extra_http_headers = dict(
        self.settings("_extra_http_headers")
        or json.loads(self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}"))
    )
    proxies = self.settings("_proxies") or json.loads(
        self._environ.get("WANDB__PROXIES", "{}")
    )

    # Basic auth with the API key is only used when no thread-local cookies
    # are configured; the cookies are handed to the session either way.
    auth = None
    if _thread_local_api_settings.cookies is None:
        auth = ("api", self.api_key or "")
    extra_http_headers.update(_thread_local_api_settings.headers or {})
    self.client = Client(
        transport=GraphQLSession(
            headers={
                "User-Agent": self.user_agent,
                "X-WANDB-USERNAME": env.get_username(env=self._environ),
                "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
                **extra_http_headers,
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=auth,
            url=f"{self.settings('base_url')}/graphql",
            cookies=_thread_local_api_settings.cookies,
            proxies=proxies,
        )
    )

    self.retry_callback = retry_callback
    self._retry_gql = retry.Retry(
        self.execute,
        retry_timedelta=retry_timedelta,
        check_retry_fn=util.no_retry_auth,
        retryable_exceptions=(RetryError, requests.RequestException),
        retry_callback=retry_callback,
    )
    self._current_run_id: Optional[str] = None
    self._file_stream_api = None
    self._upload_file_session = requests.Session()
    if self.FILE_PUSHER_TIMEOUT:
        self._upload_file_session.put = functools.partial(  # type: ignore
            self._upload_file_session.put,
            timeout=self.FILE_PUSHER_TIMEOUT,
        )
    if proxies:
        self._upload_file_session.proxies.update(proxies)
    # This Retry class is initialized once for each Api instance, so this
    # defaults to retrying 1 million times per process or 7 days
    self.upload_file_retry = normalize_exceptions(
        retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file)
    )
    self.upload_multipart_file_chunk_retry = normalize_exceptions(
        retry.retriable(retry_timedelta=retry_timedelta)(
            self.upload_multipart_file_chunk
        )
    )
    self._client_id_mapping: Dict[str, str] = {}
    # Large file uploads to azure can optionally use their SDK
    self._azure_blob_module = util.get_module("azure.storage.blob")

    # Lazily-populated server capability caches (filled by the *_introspection
    # methods).
    self.query_types: Optional[List[str]] = None
    self.mutation_types: Optional[List[str]] = None
    self.server_info_types: Optional[List[str]] = None
    self.server_use_artifact_input_info: Optional[List[str]] = None
    self.server_create_artifact_input_info: Optional[List[str]] = None
    self.server_artifact_fields_info: Optional[List[str]] = None
    self._max_cli_version: Optional[str] = None
    self._server_settings_type: Optional[List[str]] = None
    self.fail_run_queue_item_input_info: Optional[List[str]] = None
    self.create_launch_agent_input_info: Optional[List[str]] = None
    self.server_create_run_queue_supports_drc: Optional[bool] = None
    self.server_create_run_queue_supports_priority: Optional[bool] = None
    self.server_supports_template_variables: Optional[bool] = None
    self.server_push_to_run_queue_supports_priority: Optional[bool] = None
310
+
311
def gql(self, *args: Any, **kwargs: Any) -> Any:
    """Run a GraphQL request through the retrying executor.

    The ``cancel_event`` of the active context (this thread's override, else
    the global one) is forwarded as ``retry_cancel_event``.
    """
    cancel_event = self.context.cancel_event
    return self._retry_gql(*args, retry_cancel_event=cancel_event, **kwargs)
318
+
319
def set_local_context(self, api_context: Optional[context.Context]) -> None:
    """Install ``api_context`` as this thread's context override (None removes it)."""
    local_data = self._local_data
    local_data.context = api_context
321
+
322
def clear_local_context(self) -> None:
    """Drop this thread's context override so the global context applies again."""
    setattr(self._local_data, "context", None)
324
+
325
@property
def context(self) -> context.Context:
    """The active context: this thread's override when set, else the global one."""
    local_context = self._local_data.context
    if local_context:
        return local_context
    return self._global_context
328
+
329
def reauth(self) -> None:
    """Push the current API key (or "" if none) into the transport session's basic auth."""
    key = self.api_key or ""
    session = self.client.transport.session
    session.auth = ("api", key)
332
+
333
def relocate(self) -> None:
    """Re-point the GraphQL transport at ``<base_url>/graphql`` from current settings."""
    base_url = self.settings("base_url")
    self.client.transport.url = f"{base_url}/graphql"
336
+
337
def execute(self, *args: Any, **kwargs: Any) -> "_Response":
    """Execute a GraphQL request, reporting details when it fails with an HTTP error.

    On ``requests.exceptions.HTTPError`` the status code and response body are
    logged, each backend error message parsed from the response is shown via
    ``wandb.termerror``, and the original exception is re-raised unchanged.
    """
    try:
        return self.client.execute(*args, **kwargs)  # type: ignore
    except requests.exceptions.HTTPError as err:
        response = err.response
        if response is None:
            # An HTTPError may carry no response. Re-raise it as-is instead of
            # asserting: an assert would replace the real error with an
            # AssertionError, or with an AttributeError under ``python -O``
            # where asserts are stripped.
            raise
        logger.error("%s response executing GraphQL.", response.status_code)
        logger.error(response.text)
        for error in parse_backend_error_messages(response):
            wandb.termerror(f"Error while calling W&B API: {error} ({response})")
        raise
349
+
350
def disabled(self) -> Union[str, bool]:
    """Return the ``disabled`` value of the default settings section (False if unset)."""
    value = self._settings.get(
        Settings.DEFAULT_SECTION, "disabled", fallback=False
    )
    return value  # type: ignore
352
+
353
def set_current_run_id(self, run_id: str) -> None:
    """Record ``run_id``; ``parse_slug`` falls back to it when no run is given."""
    setattr(self, "_current_run_id", run_id)
355
+
356
@property
def current_run_id(self) -> Optional[str]:
    """Run id recorded by ``set_current_run_id`` (None until one is set)."""
    run_id: Optional[str] = self._current_run_id
    return run_id
359
+
360
@property
def user_agent(self) -> str:
    """Value of the User-Agent header sent by the GraphQL transport."""
    version = wandb.__version__
    return f"W&B Internal Client {version}"
363
+
364
@property
def api_key(self) -> Optional[str]:
    """Resolve the API key to use.

    A key in the thread-local API settings wins outright. Otherwise the first
    truthy value among, in order: the ``env.API_KEY`` entry of the environment
    mapping, the password of the netrc entry for ``api_url``, the SageMaker
    secrets, and the ``api_key`` default setting. If none is truthy, the
    default setting's value is returned as-is.
    """
    thread_key = _thread_local_api_settings.api_key
    if thread_key:
        return thread_key

    netrc_auth = requests.utils.get_netrc_auth(self.api_url)
    netrc_key = netrc_auth[-1] if netrc_auth else None

    # Environment should take precedence
    candidates: Tuple[Optional[str], ...] = (
        self._environ.get(env.API_KEY),
        netrc_key,
        parse_sm_secrets().get(env.API_KEY),
        self.default_settings.get("api_key"),
    )
    for candidate in candidates[:-1]:
        if candidate:
            return candidate
    return candidates[-1]
378
+
379
@property
def api_url(self) -> str:
    """Server base URL, i.e. the resolved ``base_url`` setting."""
    url = self.settings("base_url")
    return url  # type: ignore
382
+
383
@property
def app_url(self) -> str:
    """Web app URL, derived from ``api_url`` by ``wandb.util.app_url``."""
    base = self.api_url
    return wandb.util.app_url(base)
386
+
387
@property
def default_entity(self) -> str:
    """The ``entity`` field of the ``viewer()`` result (None if absent)."""
    viewer_info = self.viewer()
    return viewer_info.get("entity")  # type: ignore
390
+
391
def settings(self, key: Optional[str] = None, section: Optional[str] = None) -> Any:
    """Return the effective settings, or one entry of them.

    Layers, lowest precedence first: ``self.default_settings``, then the items
    of ``section`` from the settings file(s). After that, ``entity``,
    ``project``, ``base_url`` and ``ignore_globs`` are re-resolved through the
    matching ``env.get_*`` helper. Each helper is given the default section's
    value for that key, falling back to the layered value.

    Arguments:
        key (str, optional): If provided, only this setting is returned
            (raises ``KeyError`` when it is absent).
        section (str, optional): Settings-file section whose items are merged;
            passed straight through to ``Settings.items``.

    Returns:
        The merged settings dict, for example

            {
                "entity": "models",
                "base_url": "https://api.wandb.ai",
                "project": None
            }

        or the single value for ``key``.
    """
    merged = self.default_settings.copy()
    merged.update(self._settings.items(section=section))  # type: ignore

    env_resolvers = (
        ("entity", env.get_entity),
        ("project", env.get_project),
        ("base_url", env.get_base_url),
        ("ignore_globs", env.get_ignore),
    )
    # Resolve every key against the pre-override values, then apply them all
    # at once.
    overrides = {}
    for name, resolve in env_resolvers:
        file_value = self._settings.get(
            Settings.DEFAULT_SECTION,
            name,
            fallback=merged.get(name),
        )
        overrides[name] = resolve(file_value, env=self._environ)
    merged.update(overrides)

    if key is None:
        return merged
    return merged[key]  # type: ignore
448
+
449
def clear_setting(
    self, key: str, globally: bool = False, persist: bool = False
) -> None:
    """Remove ``key`` from the default settings section via ``Settings.clear``."""
    section = Settings.DEFAULT_SECTION
    self._settings.clear(section, key, globally=globally, persist=persist)
455
+
456
def set_setting(
    self, key: str, value: Any, globally: bool = False, persist: bool = False
) -> None:
    """Store ``key`` = ``value`` in the default settings section and propagate it.

    ``entity`` and ``project`` are also written to the environment mapping via
    ``env.set_entity`` / ``env.set_project``. A new ``base_url`` re-points the
    GraphQL transport through ``relocate``.
    """
    section = Settings.DEFAULT_SECTION
    self._settings.set(section, key, value, globally=globally, persist=persist)

    if key in ("entity", "project"):
        setter = env.set_entity if key == "entity" else env.set_project
        setter(value, env=self._environ)
    elif key == "base_url":
        self.relocate()
468
+
469
def parse_slug(
    self, slug: str, project: Optional[str] = None, run: Optional[str] = None
) -> Tuple[str, str]:
    """Split a slug into a ``(project, run)`` pair.

    If ``slug`` contains ``/`` it is read as ``"<project>/<run>"``. Anything
    after a second ``/`` is ignored, and the ``project``/``run`` arguments are
    not consulted. Otherwise ``project`` falls back to the configured
    ``project`` setting. ``run`` falls back to ``slug``, then the current run
    id, then the run from the environment.

    Arguments:
        slug (str): The slug to parse
        project (str, optional): Project to use when the slug has no ``/``
        run (str, optional): Run to use when the slug has no ``/``

    Returns:
        A ``(project, run)`` tuple.

    Raises:
        CommError: if no project is given and none is configured.
        AssertionError: if no run can be determined.
    """
    if slug and "/" in slug:
        parts = slug.split("/")
        project, run = parts[0], parts[1]
    else:
        project = project or self.settings().get("project")
        if project is None:
            raise CommError("No default project configured.")
        run = run or slug or self.current_run_id or env.get_run(env=self._environ)
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if not run:
        raise AssertionError("run must be specified")
    return project, run
495
+
496
@normalize_exceptions
def server_info_introspection(self) -> Tuple[List[str], List[str], List[str]]:
    """Probe, once, the field names of the server's Query, Mutation and ServerInfo types.

    The results are cached on the instance. The introspection query is only
    sent while any of the three lists is still None.

    Returns:
        ``(query_types, server_info_types, mutation_types)``.
    """
    query_string = """
        query ProbeServerCapabilities {
            QueryType: __type(name: "Query") {
                ...fieldData
            }
            MutationType: __type(name: "Mutation") {
                ...fieldData
            }
            ServerInfoType: __type(name: "ServerInfo") {
                ...fieldData
            }
        }

        fragment fieldData on __Type {
            fields {
                name
            }
        }
    """
    if (
        self.query_types is None
        or self.mutation_types is None
        or self.server_info_types is None
    ):
        query = gql(query_string)
        res = self.gql(query)

        def field_names(type_alias: str) -> List[str]:
            # ``__type`` resolves to null for a type the server doesn't have,
            # and ``fields`` can be null too. Treat null like a missing key
            # (same defaults as before) instead of crashing on ``None.get``.
            type_info = res.get(type_alias)
            if type_info is None:
                type_info = {}
            fields = type_info.get("fields")
            if fields is None:
                fields = [{}]
            return [field.get("name", "") for field in fields]

        self.query_types = field_names("QueryType")
        self.mutation_types = field_names("MutationType")
        self.server_info_types = field_names("ServerInfoType")
    return self.query_types, self.server_info_types, self.mutation_types
538
+
539
@normalize_exceptions
def server_settings_introspection(self) -> None:
    """Cache the field names of the server's ``ServerSettings`` type.

    The result is stored in ``self._server_settings_type``. The query only
    runs while that attribute is None. Servers without a ``ServerSettings``
    type yield an empty list instead of an error.
    """
    query_string = """
    query ProbeServerSettings {
        ServerSettingsType: __type(name: "ServerSettings") {
            ...fieldData
        }
    }

    fragment fieldData on __Type {
        fields {
            name
        }
    }
    """
    if self._server_settings_type is None:
        res = self.gql(gql(query_string))
        # ``__type`` is null when the type does not exist, so guard both the
        # response, the type object and its ``fields`` list against null.
        settings_type = (res or {}).get("ServerSettingsType") or {}
        self._server_settings_type = [
            field.get("name", "") for field in settings_type.get("fields") or []
        ]
565
+
566
def server_use_artifact_input_introspection(self) -> List:
    """Return the input field names of the server's ``UseArtifactInput`` type.

    The probe runs once. Its result is cached in
    ``self.server_use_artifact_input_info``, and later calls return that
    cached list without querying.
    """
    query_string = """
    query ProbeServerUseArtifactInput {
        UseArtifactInputInfoType: __type(name: "UseArtifactInput") {
            name
            inputFields {
                name
            }
        }
    }
    """

    if self.server_use_artifact_input_info is None:
        res = self.gql(gql(query_string)) or {}
        # ``__type`` is null when the server lacks ``UseArtifactInput``.
        # Treat that, and a null ``inputFields``, as "no fields" rather
        # than failing with AttributeError/TypeError.
        input_type = res.get("UseArtifactInputInfoType") or {}
        self.server_use_artifact_input_info = [
            field.get("name", "") for field in input_type.get("inputFields") or []
        ]
    return self.server_use_artifact_input_info
588
+
589
@normalize_exceptions
def launch_agent_introspection(self) -> Optional[str]:
    """Check whether the server defines a ``LaunchAgent`` type.

    Returns:
        The ``LaunchAgentType`` value from the response (the query selects
        only its ``name``), or None when it is missing or falsy.
        NOTE(review): annotated as ``Optional[str]``, but the value returned
        is the ``__type`` object from the response, not a plain string.
    """
    query_string = """
    query LaunchAgentIntrospection {
        LaunchAgentType: __type(name: "LaunchAgent") {
            name
        }
    }
    """
    response = self.gql(gql(query_string))
    agent_type = response.get("LaunchAgentType")
    return agent_type if agent_type else None
603
+
604
@normalize_exceptions
def create_run_queue_introspection(self) -> Tuple[bool, bool, bool]:
    """Probe server support for run-queue creation features.

    The two input-field flags are cached on ``self`` after the first probe.

    Returns:
        A ``(supports_create_run_queue, supports_default_resource_config,
        supports_prioritization)`` tuple of booleans.

    Raises:
        CommError: if the introspection query returns no data.
    """
    _, _, mutations = self.server_info_introspection()
    query_string = """
    query ProbeCreateRunQueueInput {
        CreateRunQueueInputType: __type(name: "CreateRunQueueInput") {
            name
            inputFields {
                name
            }
        }
    }
    """
    if (
        self.server_create_run_queue_supports_drc is None
        or self.server_create_run_queue_supports_priority is None
    ):
        res = self.gql(gql(query_string))
        if res is None:
            raise CommError("Could not get CreateRunQueue input from GQL.")
        # Guard against a null input type (older servers) and against field
        # entries without a "name" key. The old ``x["name"]`` raised KeyError
        # on the ``[{}]`` fallback.
        input_type = res.get("CreateRunQueueInputType") or {}
        input_field_names = {
            field.get("name") for field in input_type.get("inputFields") or []
        }
        self.server_create_run_queue_supports_drc = (
            "defaultResourceConfigID" in input_field_names
        )
        self.server_create_run_queue_supports_priority = (
            "prioritizationMode" in input_field_names
        )
    return (
        "createRunQueue" in mutations,
        self.server_create_run_queue_supports_drc,
        self.server_create_run_queue_supports_priority,
    )
642
+
643
@normalize_exceptions
def push_to_run_queue_introspection(self) -> Tuple[bool, bool]:
    """Probe ``PushToRunQueueInput`` for template-variable and priority support.

    The results are cached on ``self.server_supports_template_variables`` and
    ``self.server_push_to_run_queue_supports_priority``. The query runs only
    while either flag is still None.

    Returns:
        A ``(supports_template_variables, supports_priority)`` tuple.

    Raises:
        CommError: if the introspection query returns no data.
    """
    query_string = """
    query ProbePushToRunQueueInput {
        PushToRunQueueInputType: __type(name: "PushToRunQueueInput") {
            name
            inputFields {
                name
            }
        }
    }
    """

    if (
        self.server_supports_template_variables is None
        or self.server_push_to_run_queue_supports_priority is None
    ):
        res = self.gql(gql(query_string))
        # Mirror create_run_queue_introspection: fail clearly on an empty
        # response instead of with an AttributeError on ``None.get``.
        if res is None:
            raise CommError("Could not get PushToRunQueue input from GQL.")
        input_type = res.get("PushToRunQueueInputType") or {}
        input_field_names = {
            field.get("name") for field in input_type.get("inputFields") or []
        }
        self.server_supports_template_variables = (
            "templateVariableValues" in input_field_names
        )
        self.server_push_to_run_queue_supports_priority = (
            "priority" in input_field_names
        )

    return (
        self.server_supports_template_variables,
        self.server_push_to_run_queue_supports_priority,
    )
679
+
680
@normalize_exceptions
def create_default_resource_config_introspection(self) -> bool:
    """Return True if the server exposes a ``createDefaultResourceConfig`` mutation."""
    mutation_names = self.server_info_introspection()[2]
    return "createDefaultResourceConfig" in mutation_names
684
+
685
@normalize_exceptions
def fail_run_queue_item_introspection(self) -> bool:
    """Return True if the server exposes a ``failRunQueueItem`` mutation."""
    mutation_names = self.server_info_introspection()[2]
    return "failRunQueueItem" in mutation_names
689
+
690
@normalize_exceptions
def fail_run_queue_item_fields_introspection(self) -> List:
    """Return the input field names of the server's ``FailRunQueueItemInput``.

    A truthy cached list in ``self.fail_run_queue_item_input_info`` is
    returned without querying. Otherwise the server is probed and the result
    is stored there.
    """
    cached_fields = self.fail_run_queue_item_input_info
    if cached_fields:
        return cached_fields

    query_string = """
    query ProbeServerFailRunQueueItemInput {
        FailRunQueueItemInputInfoType: __type(name:"FailRunQueueItemInput") {
            inputFields{
                name
            }
        }
    }
    """
    response = self.gql(gql(query_string))
    input_fields = response.get("FailRunQueueItemInputInfoType", {}).get(
        "inputFields", [{}]
    )
    self.fail_run_queue_item_input_info = [
        input_field.get("name", "") for input_field in input_fields
    ]
    return self.fail_run_queue_item_input_info
714
+
715
@normalize_exceptions
def fail_run_queue_item(
    self,
    run_queue_item_id: str,
    message: str,
    stage: str,
    file_paths: Optional[List[str]] = None,
) -> bool:
    """Mark a run queue item as failed.

    If the server's ``FailRunQueueItemInput`` has a ``message`` field, the
    message, stage and (when given) file paths are sent along. Otherwise
    only the item id is sent.

    Arguments:
        run_queue_item_id (str): ID of the run queue item.
        message (str): Failure message (sent only if supported).
        stage (str): Failure stage (sent only if supported).
        file_paths (list of str, optional): Related file paths (sent only if
            supported and not None).

    Returns:
        False if the server has no ``failRunQueueItem`` mutation; otherwise
        the ``success`` flag from the mutation payload.
    """
    if not self.fail_run_queue_item_introspection():
        return False

    variable_values: Dict[str, Union[str, Optional[List[str]]]] = {
        "runQueueItemId": run_queue_item_id,
    }
    supports_details = "message" in self.fail_run_queue_item_fields_introspection()
    if not supports_details:
        mutation_string = """
        mutation failRunQueueItem($runQueueItemId: ID!) {
            failRunQueueItem(
                input: {
                    runQueueItemId: $runQueueItemId
                }
            ) {
                success
            }
        }
        """
    else:
        variable_values["message"] = message
        variable_values["stage"] = stage
        if file_paths is not None:
            variable_values["filePaths"] = file_paths
        mutation_string = """
        mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
            failRunQueueItem(
                input: {
                    runQueueItemId: $runQueueItemId
                    message: $message
                    stage: $stage
                    filePaths: $filePaths
                }
            ) {
                success
            }
        }
        """

    payload = self.gql(gql(mutation_string), variable_values=variable_values)
    succeeded: bool = payload["failRunQueueItem"]["success"]
    return succeeded
763
+
764
@normalize_exceptions
def update_run_queue_item_warning_introspection(self) -> bool:
    """Return True if the server exposes an ``updateRunQueueItemWarning`` mutation."""
    mutation_names = self.server_info_introspection()[2]
    return "updateRunQueueItemWarning" in mutation_names
768
+
769
@normalize_exceptions
def update_run_queue_item_warning(
    self,
    run_queue_item_id: str,
    message: str,
    stage: str,
    file_paths: Optional[List[str]] = None,
) -> bool:
    """Attach a warning to a run queue item.

    Arguments:
        run_queue_item_id (str): ID of the run queue item.
        message (str): Warning message.
        stage (str): Stage the warning refers to.
        file_paths (list of str, optional): Related file paths. Sent as-is,
            including None.

    Returns:
        False if the server has no ``updateRunQueueItemWarning`` mutation;
        otherwise the ``success`` flag from the mutation payload.
    """
    if not self.update_run_queue_item_warning_introspection():
        return False

    warning_mutation = gql(
        """
    mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
        updateRunQueueItemWarning(
            input: {
                runQueueItemId: $runQueueItemId
                message: $message
                stage: $stage
                filePaths: $filePaths
            }
        ) {
            success
        }
    }
    """
    )
    warning_variables = {
        "runQueueItemId": run_queue_item_id,
        "message": message,
        "stage": stage,
        "filePaths": file_paths,
    }
    payload = self.gql(warning_mutation, variable_values=warning_variables)
    succeeded: bool = payload["updateRunQueueItemWarning"]["success"]
    return succeeded
806
+
807
@normalize_exceptions
def viewer(self) -> Dict[str, Any]:
    """Fetch the authenticated viewer.

    Returns:
        The ``viewer`` object (id, entity, username, flags and team names),
        or ``{}`` when the response has no viewer.
    """
    viewer_query = gql(
        """
    query Viewer{
        viewer {
            id
            entity
            username
            flags
            teams {
                edges {
                    node {
                        name
                    }
                }
            }
        }
    }
    """
    )
    viewer_info = self.gql(viewer_query).get("viewer")
    return viewer_info if viewer_info else {}
830
+
831
@normalize_exceptions
def max_cli_version(self) -> Optional[str]:
    """Return the ``max_cli_version`` the server reports, if any.

    A non-None value is cached in ``self._max_cli_version`` and returned on
    later calls without querying.

    Returns:
        The reported max CLI version, or None when the server does not
        expose ``serverInfo.cliVersionInfo`` or reports no value.
    """
    if self._max_cli_version is not None:
        return self._max_cli_version

    query_types, server_info_types, _ = self.server_info_introspection()
    cli_version_exists = (
        "serverInfo" in query_types and "cliVersionInfo" in server_info_types
    )
    if not cli_version_exists:
        return None

    _, server_info = self.viewer_server_info()
    # ``cliVersionInfo`` can be present but null, and ``.get(key, {})`` would
    # then return None. Fall back to an empty dict explicitly.
    cli_version_info = server_info.get("cliVersionInfo") or {}
    self._max_cli_version = cli_version_info.get("max_cli_version")
    return self._max_cli_version
848
+
849
@normalize_exceptions
def viewer_server_info(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Fetch the viewer plus whatever server-info fields the server supports.

    ``serverInfo.cliVersionInfo`` is requested only if introspection shows
    it exists. ``latestLocalVersionInfo`` is nested inside that same
    ``serverInfo`` selection, so it is requested only when the CLI version
    info is requested too.

    Returns:
        A ``(viewer, server_info)`` tuple; each falls back to ``{}`` when
        missing from the response.
    """
    local_query = """
        latestLocalVersionInfo {
            outOfDate
            latestVersionString
            versionOnThisInstanceString
        }
    """
    cli_query = """
        serverInfo {
            cliVersionInfo
            _LOCAL_QUERY_
        }
    """
    query_template = """
    query Viewer{
        viewer {
            id
            entity
            username
            email
            flags
            teams {
                edges {
                    node {
                        name
                    }
                }
            }
        }
        _CLI_QUERY_
    }
    """
    query_types, server_info_types, _ = self.server_info_introspection()

    has_server_info = "serverInfo" in query_types
    cli_version_exists = has_server_info and "cliVersionInfo" in server_info_types
    local_version_exists = (
        has_server_info and "latestLocalVersionInfo" in server_info_types
    )

    cli_fragment = cli_query if cli_version_exists else ""
    local_fragment = local_query if local_version_exists else ""

    # Substitute the CLI block first: the _LOCAL_QUERY_ placeholder only
    # exists inside it.
    query_string = query_template.replace("_CLI_QUERY_", cli_fragment)
    query_string = query_string.replace("_LOCAL_QUERY_", local_fragment)

    response = self.gql(gql(query_string))
    viewer_info = response.get("viewer") or {}
    server_info = response.get("serverInfo") or {}
    return viewer_info, server_info
903
+
904
@normalize_exceptions
def list_projects(self, entity: Optional[str] = None) -> List[Dict[str, str]]:
    """List projects in W&B scoped by entity.

    Only the first 10 projects are requested.

    Arguments:
        entity (str, optional): The entity to scope this project to; defaults
            to the ``entity`` setting.

    Returns:
        [{"id","name","description"}]
    """
    projects_query = gql(
        """
    query EntityProjects($entity: String) {
        models(first: 10, entityName: $entity) {
            edges {
                node {
                    id
                    name
                    description
                }
            }
        }
    }
    """
    )
    scoped_entity = entity or self.settings("entity")
    response = self.gql(projects_query, variable_values={"entity": scoped_entity})
    project_list: List[Dict[str, str]] = self._flatten_edges(response["models"])
    return project_list
935
+
936
@normalize_exceptions
def project(self, project: str, entity: Optional[str] = None) -> "_Response":
    """Retrieve project.

    Arguments:
        project (str): The project to get details for
        entity (str, optional): The entity to scope this project to. Sent
            as-is; no fallback to the settings is applied here.

    Returns:
        The ``model`` object with id, name, repo, dockerImage and
        description (whatever the server returns for ``model``).
    """
    details_query = gql(
        """
    query ProjectDetails($entity: String, $project: String) {
        model(name: $project, entityName: $entity) {
            id
            name
            repo
            dockerImage
            description
        }
    }
    """
    )
    variables = {"entity": entity, "project": project}
    response: _Response = self.gql(details_query, variable_values=variables)[
        "model"
    ]
    return response
964
+
965
@normalize_exceptions
def sweep(
    self,
    sweep: str,
    specs: str,
    project: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Any]:
    """Retrieve a sweep together with its runs.

    Arguments:
        sweep (str): The sweep to get details for
        specs (str): History specs, sent as the query's ``$specs`` variable
            (declared in the query as ``[JSONString!]!``).
        project (str, optional): The project to scope this sweep to; defaults
            to the ``project`` setting.
        entity (str, optional): The entity to scope this sweep to; defaults
            to the ``entity`` setting.

    Returns:
        The sweep as a dict with the fields selected by the query (id, name,
        method, state, config, timestamps, bestLoss, controller, scheduler,
        ...). Its ``runs`` connection has been passed through
        ``self._flatten_edges``.

    Raises:
        ValueError: if the project or the sweep cannot be found.
    """
    query = gql(
        """
    query SweepWithRuns($entity: String, $project: String, $sweep: String!, $specs: [JSONString!]!) {
        project(name: $project, entityName: $entity) {
            sweep(sweepName: $sweep) {
                id
                name
                method
                state
                description
                config
                createdAt
                heartbeatAt
                updatedAt
                earlyStopJobRunning
                bestLoss
                controller
                scheduler
                runs {
                    edges {
                        node {
                            name
                            state
                            config
                            exitcode
                            heartbeatAt
                            shouldStop
                            failed
                            stopped
                            running
                            summaryMetrics
                            sampledHistory(specs: $specs)
                        }
                    }
                }
            }
        }
    }
    """
    )
    entity = entity or self.settings("entity")
    project = project or self.settings("project")
    response = self.gql(
        query,
        variable_values={
            "entity": entity,
            "project": project,
            "sweep": sweep,
            "specs": specs,
        },
    )
    if response["project"] is None or response["project"]["sweep"] is None:
        raise ValueError(f"Sweep {entity}/{project}/{sweep} not found")
    data: Dict[str, Any] = response["project"]["sweep"]
    if data:
        data["runs"] = self._flatten_edges(data["runs"])
    return data
1041
+
1042
@normalize_exceptions
def list_runs(
    self, project: str, entity: Optional[str] = None
) -> List[Dict[str, str]]:
    """List runs in W&B scoped by project (only the first 10 are requested).

    Arguments:
        project (str): The project to scope the runs to; falls back to the
            ``project`` setting when empty.
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting.

    Returns:
        The project's ``buckets`` connection passed through
        ``self._flatten_edges``. Each node carries id, name, displayName and
        description.
    """
    runs_query = gql(
        """
    query ProjectRuns($model: String!, $entity: String) {
        model(name: $model, entityName: $entity) {
            buckets(first: 10) {
                edges {
                    node {
                        id
                        name
                        displayName
                        description
                    }
                }
            }
        }
    }
    """
    )
    variables = {
        "entity": entity or self.settings("entity"),
        "model": project or self.settings("project"),
    }
    response = self.gql(runs_query, variable_values=variables)
    run_buckets = response["model"]["buckets"]
    return self._flatten_edges(run_buckets)
1082
+
1083
@normalize_exceptions
def run_config(
    self, project: str, run: Optional[str] = None, entity: Optional[str] = None
) -> Tuple[str, Dict[str, Any], Optional[str], Dict[str, Any]]:
    """Get the relevant configs for a run.

    Arguments:
        project (str): The project to download, (can include bucket)
        run (str, optional): The run to download
        entity (str, optional): The entity to scope this project to.

    Returns:
        A ``(commit, config, patch, metadata)`` tuple:

        * ``commit``: the run's commit.
        * ``config``: its config, JSON-decoded (``{}`` when empty).
        * ``patch``: the text of the diff file, or None if absent.
        * ``metadata``: the metadata file JSON-decoded, or ``{}`` if absent.

    Raises:
        CommError: if the project or the run cannot be found.
    """
    check_httpclient_logger_handler()

    query = gql(
        """
    query RunConfigs(
        $name: String!,
        $entity: String,
        $run: String!,
        $pattern: String!,
        $includeConfig: Boolean!,
    ) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run) {
                config @include(if: $includeConfig)
                commit @include(if: $includeConfig)
                files(pattern: $pattern) {
                    pageInfo {
                        hasNextPage
                        endCursor
                    }
                    edges {
                        node {
                            name
                            directUrl
                        }
                    }
                }
            }
        }
    }
    """
    )

    variable_values = {
        "name": project,
        "run": run,
        "entity": entity,
        "includeConfig": True,
    }

    commit: str = ""
    config: Dict[str, Any] = {}
    patch: Optional[str] = None
    metadata: Dict[str, Any] = {}

    # If we use the `names` parameter on the `files` node, then the server
    # will helpfully give us an 'open' file handle to the files that don't
    # exist. This is so that we can upload data to it. However, in this
    # case, we just want to download that file and not upload to it, so
    # let's instead query for the files that do exist using `pattern`
    # (with no wildcards).
    #
    # Unfortunately we're unable to construct a single pattern that matches
    # our 2 files, we would need something like regex for that.
    for filename in [DIFF_FNAME, METADATA_FNAME]:
        variable_values["pattern"] = filename
        response = self.gql(query, variable_values=variable_values)
        model = response["model"]
        # A missing run shows up as a null ``bucket`` under an existing
        # project. Report it the same way as a missing project instead of
        # failing later with a TypeError on ``None["commit"]``.
        if model is None or model["bucket"] is None:
            raise CommError(f"Run {entity}/{project}/{run} not found")
        run_obj: Dict = model["bucket"]
        # we only need to fetch this config once
        if variable_values["includeConfig"]:
            commit = run_obj["commit"]
            config = json.loads(run_obj["config"] or "{}")
            variable_values["includeConfig"] = False
        if run_obj["files"] is not None:
            for file_edge in run_obj["files"]["edges"]:
                name = file_edge["node"]["name"]
                url = file_edge["node"]["directUrl"]
                res = requests.get(url)
                res.raise_for_status()
                if name == METADATA_FNAME:
                    metadata = res.json()
                elif name == DIFF_FNAME:
                    patch = res.text

    return commit, config, patch, metadata
1171
+
1172
@normalize_exceptions
def run_resume_status(
    self, entity: str, project_name: str, name: str
) -> Optional[Dict[str, Any]]:
    """Check if a run exists and get resume information.

    Side effect: when the project is found, the ``project`` setting is set
    to ``project_name``. The ``entity`` setting is set to the project's
    entity name when the server returns one.

    Arguments:
        entity (str): The entity to scope this project to.
        project_name (str): The project to download, (can include bucket)
        name (str): The run to download

    Returns:
        The run (``bucket``) with its summary, history/events tails, config
        and tags. Returns None when the project is missing. The query uses
        ``missingOk: true``, so the bucket itself is presumably null for a
        run that does not exist. TODO confirm against the server.
    """
    query = gql(
        """
    query RunResumeStatus($project: String, $entity: String, $name: String!) {
        model(name: $project, entityName: $entity) {
            id
            name
            entity {
                id
                name
            }

            bucket(name: $name, missingOk: true) {
                id
                name
                summaryMetrics
                displayName
                logLineCount
                historyLineCount
                eventsLineCount
                historyTail
                eventsTail
                config
                tags
            }
        }
    }
    """
    )

    response = self.gql(
        query,
        variable_values={
            "entity": entity,
            "project": project_name,
            "name": name,
        },
    )

    if "model" not in response or "bucket" not in (response["model"] or {}):
        return None

    project = response["model"]
    self.set_setting("project", project_name)
    # The ``entity`` key is always selected, so a key-membership test is not
    # enough. Its value can be null, and ``None["name"]`` would raise.
    project_entity = project.get("entity")
    if project_entity:
        self.set_setting("entity", project_entity["name"])

    result: Dict[str, Any] = project["bucket"]

    return result
1232
+
1233
@normalize_exceptions
def check_stop_requested(
    self, project_name: str, entity_name: str, run_id: str
) -> bool:
    """Return the run's ``stopped`` flag.

    Returns False when either the project or the run is missing or falsy
    in the response.
    """
    stopped_query = gql(
        """
    query RunStoppedStatus($projectName: String, $entityName: String, $runId: String!) {
        project(name:$projectName, entityName:$entityName) {
            run(name:$runId) {
                stopped
            }
        }
    }
    """
    )

    response = self.gql(
        stopped_query,
        variable_values={
            "projectName": project_name,
            "entityName": entity_name,
            "runId": run_id,
        },
    )

    run_info = (response.get("project", None) or {}).get("run", None)
    if not run_info:
        return False

    is_stopped: bool = run_info["stopped"]
    return is_stopped
1267
+
1268
def format_project(self, project: str) -> str:
    """Normalize a project name.

    The name is lowercased, every run of non-word characters becomes a
    single ``-``, and leading/trailing ``-`` and ``_`` are stripped.
    """
    lowered = project.lower()
    dashed = re.sub(r"\W+", "-", lowered)
    return dashed.strip("-_")
1270
+
1271
@normalize_exceptions
def upsert_project(
    self,
    project: str,
    id: Optional[str] = None,
    description: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a new project (or update it via ``upsertModel``).

    Arguments:
        project (str): The project to create. It is normalized with
            ``self.format_project`` before being sent.
        id (str, optional): ID of an existing project, sent as ``id``.
        description (str, optional): A description of this project
        entity (str, optional): The entity to scope this project to; defaults
            to the ``entity`` setting.

    Returns:
        The upserted model's ``name`` and ``description``.
    """
    upsert_mutation = gql(
        """
    mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) {
        upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
            model {
                name
                description
            }
        }
    }
    """
    )
    # TODO(jhr): 'repo' is intentionally not sent yet (was disabled for cling);
    # add it back alongside description/id once supported.
    upsert_variables = {
        "name": self.format_project(project),
        "entity": entity or self.settings("entity"),
        "description": description,
        "id": id,
    }
    response = self.gql(upsert_mutation, variable_values=upsert_variables)
    upserted: Dict[str, Any] = response["upsertModel"]["model"]
    return upserted
1311
+
1312
@normalize_exceptions
def entity_is_team(self, entity: str) -> bool:
    """Return the ``isTeam`` flag of ``entity``.

    Raises:
        Exception: if the entity is not returned (missing or inaccessible).
    """
    team_query = gql(
        """
    query EntityIsTeam($entity: String!) {
        entity(name: $entity) {
            id
            isTeam
        }
    }
    """
    )
    response = self.gql(team_query, {"entity": entity})
    entity_info = response.get("entity")
    if entity_info is None:
        raise Exception(
            f"Error fetching entity {entity} "
            "check that you have access to this entity"
        )

    is_team: bool = entity_info["isTeam"]
    return is_team
1337
+
1338
@normalize_exceptions
def get_project_run_queues(self, entity: str, project: str) -> List[Dict[str, str]]:
    """Return the run queues (id, name, createdBy, access) of a project.

    Raises:
        Exception: if the project is not returned (missing or inaccessible).
    """
    queues_query = gql(
        """
    query ProjectRunQueues($entity: String!, $projectName: String!){
        project(entityName: $entity, name: $projectName) {
            runQueues {
                id
                name
                createdBy
                access
            }
        }
    }
    """
    )
    response = self.gql(queues_query, {"projectName": project, "entity": entity})
    project_info = response.get("project")
    if project_info is None:
        # circular dependency: (LAUNCH_DEFAULT_PROJECT = model-registry), so
        # the default project name is compared literally and left out of
        # the message.
        scope = entity if project == "model-registry" else f"{entity}/{project}"
        raise Exception(
            f"Error fetching run queues for {scope} "
            "check that you have access to this entity and project"
        )

    run_queues: List[Dict[str, str]] = project_info["runQueues"]
    return run_queues
1377
+
1378
@normalize_exceptions
def create_default_resource_config(
    self,
    entity: str,
    resource: str,
    config: str,
    template_variables: Optional[Dict[str, Union[float, int, str]]],
) -> Optional[Dict[str, Any]]:
    """Create a default resource config for ``resource`` under ``entity``.

    Arguments:
        entity (str): Entity that will own the config.
        resource (str): Resource name, sent as ``resource``.
        config (str): Config, sent as the ``JSONString`` ``config`` variable.
        template_variables (dict, optional): Template variables. They are
            JSON-encoded before sending (``"{}"`` when None and the server
            supports them).

    Returns:
        The ``createDefaultResourceConfig`` payload
        (``defaultResourceConfigID`` and ``success``).

    Raises:
        UnsupportedError: if the server lacks the mutation, or if
            ``template_variables`` is given but the server does not support
            template variables.
    """
    if not self.create_default_resource_config_introspection():
        # Say what is missing so the (possibly wrapped) error is actionable.
        # A bare ``Exception()`` surfaced with an empty message.
        raise UnsupportedError(
            "default resource configurations are not supported by this version "
            "of wandb server. Consider updating to the latest version."
        )
    supports_template_vars, _ = self.push_to_run_queue_introspection()

    mutation_params = """
        $entityName: String!,
        $resource: String!,
        $config: JSONString!
    """
    mutation_inputs = """
        entityName: $entityName,
        resource: $resource,
        config: $config
    """

    if supports_template_vars:
        mutation_params += ", $templateVariables: JSONString"
        mutation_inputs += ", templateVariables: $templateVariables"
    else:
        if template_variables is not None:
            raise UnsupportedError(
                "server does not support template variables, please update server instance to >=0.46"
            )

    variable_values = {
        "entityName": entity,
        "resource": resource,
        "config": config,
    }
    if supports_template_vars:
        if template_variables is not None:
            variable_values["templateVariables"] = json.dumps(template_variables)
        else:
            variable_values["templateVariables"] = "{}"

    query = gql(
        f"""
    mutation createDefaultResourceConfig(
        {mutation_params}
    ) {{
        createDefaultResourceConfig(
        input: {{
                {mutation_inputs}
            }}
        ) {{
            defaultResourceConfigID
            success
        }}
    }}
    """
    )

    result: Optional[Dict[str, Any]] = self.gql(query, variable_values)[
        "createDefaultResourceConfig"
    ]
    return result
1442
+
1443
@normalize_exceptions
def create_run_queue(
    self,
    entity: str,
    project: str,
    queue_name: str,
    access: str,
    prioritization_mode: Optional[str] = None,
    config_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Create a run queue in ``entity/project``.

    Arguments:
        entity (str): Entity that owns the queue.
        project (str): Project that owns the queue.
        queue_name (str): Name of the new queue.
        access (str): Access type, sent as ``RunQueueAccessType``.
        prioritization_mode (str, optional): Sent only when the server
            supports prioritization.
        config_id (str, optional): Default resource config ID.

    Returns:
        The ``createRunQueue`` payload (``success`` and ``queueID``).

    Raises:
        UnsupportedError: if the server cannot create run queues, or if
            ``config_id`` / ``prioritization_mode`` is given but unsupported.
    """
    (
        can_create_queue,
        supports_drc,
        supports_prioritization,
    ) = self.create_run_queue_introspection()

    # Checked in order; the first unsupported feature raises.
    capability_checks = [
        (
            not can_create_queue,
            "run queue creation is not supported by this version of "
            "wandb server. Consider updating to the latest version.",
        ),
        (
            not supports_drc and config_id is not None,
            "default resource configurations are not supported by this version "
            "of wandb server. Consider updating to the latest version.",
        ),
        (
            not supports_prioritization and prioritization_mode is not None,
            "launch prioritization is not supported by this version of "
            "wandb server. Consider updating to the latest version.",
        ),
    ]
    for is_unsupported, reason in capability_checks:
        if is_unsupported:
            raise UnsupportedError(reason)

    variable_values: Dict[str, Any] = {
        "entity": entity,
        "project": project,
        "queueName": queue_name,
        "access": access,
    }
    if supports_prioritization:
        variable_values["prioritizationMode"] = prioritization_mode
        mutation_string = """
    mutation createRunQueue(
        $entity: String!,
        $project: String!,
        $queueName: String!,
        $access: RunQueueAccessType!,
        $prioritizationMode: RunQueuePrioritizationMode,
        $defaultResourceConfigID: ID,
    ) {
        createRunQueue(
            input: {
                entityName: $entity,
                projectName: $project,
                queueName: $queueName,
                access: $access,
                prioritizationMode: $prioritizationMode
                defaultResourceConfigID: $defaultResourceConfigID
            }
        ) {
            success
            queueID
        }
    }
    """
    else:
        mutation_string = """
    mutation createRunQueue(
        $entity: String!,
        $project: String!,
        $queueName: String!,
        $access: RunQueueAccessType!,
        $defaultResourceConfigID: ID,
    ) {
        createRunQueue(
            input: {
                entityName: $entity,
                projectName: $project,
                queueName: $queueName,
                access: $access,
                defaultResourceConfigID: $defaultResourceConfigID
            }
        ) {
            success
            queueID
        }
    }
    """
    variable_values["defaultResourceConfigID"] = config_id

    created: Optional[Dict[str, Any]] = self.gql(
        gql(mutation_string), variable_values
    )["createRunQueue"]
    return created
1546
+
1547
@normalize_exceptions
def push_to_run_queue_by_name(
    self,
    entity: str,
    project: str,
    queue_name: str,
    run_spec: str,
    template_variables: Optional[Dict[str, Union[int, float, str]]],
    priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Queryless mutation, should be used before legacy fallback method.

    Push ``run_spec`` onto the queue named ``queue_name`` in
    ``entity/project`` with ``pushToRunQueueByName``. ``priority`` and
    ``template_variables`` are included only when the server supports them.
    If the server cannot return ``runSpec`` in the payload, the push is
    retried with a mutation that selects only ``runQueueItemId``.

    Arguments:
        entity (str): Entity that owns the queue.
        project (str): Project that owns the queue.
        queue_name (str): Name of the target queue.
        run_spec (str): JSON-encoded run spec.
        template_variables (dict, optional): Template variable values; they
            are JSON-encoded before sending.
        priority (int, optional): Queue priority.

    Returns:
        The ``pushToRunQueueByName`` payload, with ``runSpec`` JSON-decoded
        when present. Returns None when the push fails (errors are swallowed
        deliberately so the caller can try its legacy path).

    Raises:
        UnsupportedError: if ``priority`` or ``template_variables`` is given
            but the server does not support it.
    """
    # Populates the server_* capability flags read below; the return value
    # is not needed.
    self.push_to_run_queue_introspection()

    mutation_params = """
        $entityName: String!,
        $projectName: String!,
        $queueName: String!,
        $runSpec: JSONString!
    """

    mutation_input = """
        entityName: $entityName,
        projectName: $projectName,
        queueName: $queueName,
        runSpec: $runSpec
    """

    variables: Dict[str, Any] = {
        "entityName": entity,
        "projectName": project,
        "queueName": queue_name,
        "runSpec": run_spec,
    }
    if self.server_push_to_run_queue_supports_priority:
        if priority is not None:
            variables["priority"] = priority
            mutation_params += ", $priority: Int"
            mutation_input += ", priority: $priority"
    else:
        if priority is not None:
            raise UnsupportedError(
                "server does not support priority, please update server instance to >=0.46"
            )

    if self.server_supports_template_variables:
        if template_variables is not None:
            variables.update(
                {"templateVariableValues": json.dumps(template_variables)}
            )
            mutation_params += ", $templateVariableValues: JSONString"
            mutation_input += ", templateVariableValues: $templateVariableValues"
    else:
        if template_variables is not None:
            raise UnsupportedError(
                "server does not support template variables, please update server instance to >=0.46"
            )

    mutation = gql(
        f"""
    mutation pushToRunQueueByName(
        {mutation_params}
    ) {{
        pushToRunQueueByName(
            input: {{
                {mutation_input}
            }}
        ) {{
            runQueueItemId
            runSpec
        }}
    }}
    """
    )

    try:
        result: Optional[Dict[str, Any]] = self.gql(
            mutation, variables, check_retry_fn=util.no_retry_4xx
        ).get("pushToRunQueueByName")
        if not result:
            return None

        if result.get("runSpec"):
            run_spec = json.loads(str(result["runSpec"]))
            result["runSpec"] = run_spec

        return result
    except Exception as e:
        # Only an older server that cannot return ``runSpec`` in the payload
        # gets the retry below; any other failure is reported as None.
        if (
            'Cannot query field "runSpec" on type "PushToRunQueueByNamePayload"'
            not in str(e)
        ):
            return None

    # NOTE(review): ``variables`` may still hold priority/template values
    # that this fallback mutation does not declare. Confirm the server
    # ignores undeclared variables.
    mutation_no_runspec = gql(
        """
    mutation pushToRunQueueByName(
        $entityName: String!,
        $projectName: String!,
        $queueName: String!,
        $runSpec: JSONString!,
    ) {
        pushToRunQueueByName(
            input: {
                entityName: $entityName,
                projectName: $projectName,
                queueName: $queueName,
                runSpec: $runSpec
            }
        ) {
            runQueueItemId
        }
    }
    """
    )

    try:
        result = self.gql(
            mutation_no_runspec, variables, check_retry_fn=util.no_retry_4xx
        ).get("pushToRunQueueByName")
    except Exception:
        result = None

    return result
1670
+
1671
@normalize_exceptions
def push_to_run_queue(
    self,
    queue_name: str,
    launch_spec: Dict[str, str],
    template_variables: Optional[dict],
    project_queue: str,
    priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Push a launch spec onto a run queue.

    First tries ``push_to_run_queue_by_name``. If that returns nothing and no
    priority was requested, falls back to the legacy flow. The legacy flow
    finds the queue by name among the project's queues, creating ``default``
    if that is the requested queue and it is missing. It then pushes by
    queue id with the ``pushToRunQueue`` mutation.

    Arguments:
        queue_name: Name of the queue to push to.
        launch_spec: The launch spec to enqueue. ``queue_entity`` (falling
            back to ``entity``) selects the entity that owns the queue.
        template_variables: Template variable values, or None.
        project_queue: The project that owns the queue.
        priority: Optional priority. The legacy flow cannot set it, so when
            it is given and the by-name push fails, None is returned.

    Returns:
        The push payload (contains ``runQueueItemId``), or None when the item
        could not be pushed.

    Raises:
        UnsupportedError: template variables were given but the server does
            not support them.
        CommError: the legacy ``pushToRunQueue`` mutation returned no payload.
    """
    self.push_to_run_queue_introspection()
    entity = launch_spec.get("queue_entity") or launch_spec["entity"]
    run_spec = json.dumps(launch_spec)

    push_result = self.push_to_run_queue_by_name(
        entity, project_queue, queue_name, run_spec, template_variables, priority
    )

    if push_result:
        return push_result

    if priority is not None:
        # Cannot proceed with legacy method if priority is set
        return None

    # Legacy method: resolve the queue id by name, then push by id.
    queues_found = self.get_project_run_queues(entity, project_queue)
    matching_queues = [
        q
        for q in queues_found
        if q["name"] == queue_name
        # ensure user has access to queue
        and (
            # TODO: User created queues in the UI have USER access
            q["access"] in ["PROJECT", "USER"]
            or q["createdBy"] == self.default_entity
        )
    ]
    if not matching_queues:
        # in the case of a missing default queue. create it
        if queue_name == "default":
            wandb.termlog(
                f"No default queue existing for entity: {entity} in project: {project_queue}, creating one."
            )
            # Create the queue under the same entity it was looked up in
            # (the queue entity), so the run lands where the caller asked.
            res = self.create_run_queue(
                entity,
                project_queue,
                queue_name,
                access="PROJECT",
            )

            if res is None or res.get("queueID") is None:
                wandb.termerror(
                    f"Unable to create default queue for entity: {entity} on project: {project_queue}. Run could not be added to a queue"
                )
                return None
            queue_id = res["queueID"]

        else:
            if project_queue == "model-registry":
                _msg = f"Unable to push to run queue {queue_name}. Queue not found."
            else:
                _msg = f"Unable to push to run queue {project_queue}/{queue_name}. Queue not found."
            wandb.termwarn(_msg)
            return None
    elif len(matching_queues) > 1:
        wandb.termerror(
            f"Unable to push to run queue {queue_name}. More than one queue found with this name."
        )
        return None
    else:
        queue_id = matching_queues[0]["id"]

    # run_spec already holds json.dumps(launch_spec); reuse it.
    variables = {"queueID": queue_id, "runSpec": run_spec}

    mutation_params = """
        $queueID: ID!,
        $runSpec: JSONString!
    """
    mutation_input = """
        queueID: $queueID,
        runSpec: $runSpec
    """
    if self.server_supports_template_variables:
        if template_variables is not None:
            mutation_params += ", $templateVariableValues: JSONString"
            mutation_input += ", templateVariableValues: $templateVariableValues"
            variables.update(
                {"templateVariableValues": json.dumps(template_variables)}
            )
    else:
        if template_variables is not None:
            raise UnsupportedError(
                "server does not support template variables, please update server instance to >=0.46"
            )

    mutation = gql(
        f"""
    mutation pushToRunQueue(
        {mutation_params}
    ) {{
        pushToRunQueue(
            input: {{{mutation_input}}}
        ) {{
            runQueueItemId
        }}
    }}
    """
    )

    response = self.gql(mutation, variable_values=variables)
    if not response.get("pushToRunQueue"):
        raise CommError(f"Error pushing run queue item to queue {queue_name}.")

    result: Optional[Dict[str, Any]] = response["pushToRunQueue"]
    return result
+
1787
@normalize_exceptions
def pop_from_run_queue(
    self,
    queue_name: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    agent_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Pop the next item off a run queue via the ``popFromRunQueue`` mutation.

    Arguments:
        queue_name: Name of the queue to pop from.
        entity: Entity owning the queue.
        project: Project owning the queue.
        agent_id: Optional launch agent id performing the pop.

    Returns:
        The ``popFromRunQueue`` payload (``runQueueItemId`` and ``runSpec``)
        as returned by the server; may be None.
    """
    pop_variables = {
        "entity": entity,
        "project": project,
        "queueName": queue_name,
        "launchAgentId": agent_id,
    }
    pop_mutation = gql(
        """
    mutation popFromRunQueue($entity: String!, $project: String!, $queueName: String!, $launchAgentId: ID) {
        popFromRunQueue(input: {
            entityName: $entity,
            projectName: $project,
            queueName: $queueName,
            launchAgentId: $launchAgentId
        }) {
            runQueueItemId
            runSpec
        }
    }
    """
    )
    popped = self.gql(pop_mutation, variable_values=pop_variables)
    payload: Optional[Dict[str, Any]] = popped["popFromRunQueue"]
    return payload
+
1822
@normalize_exceptions
def ack_run_queue_item(self, item_id: str, run_id: Optional[str] = None) -> bool:
    """Acknowledge a run queue item and associate it with a run name.

    Arguments:
        item_id: The run queue item id to acknowledge.
        run_id: The run name to attach to the item.

    Returns:
        The ``success`` flag from the server (always truthy when returned).

    Raises:
        CommError: the server reported the acknowledgement as unsuccessful.
    """
    ack_mutation = gql(
        """
    mutation ackRunQueueItem($itemId: ID!, $runId: String!) {
        ackRunQueueItem(input: { runQueueItemId: $itemId, runName: $runId }) {
            success
        }
    }
    """
    )
    # NOTE(review): run_id is stringified unconditionally, so a None run_id is
    # sent as the literal "None" ($runId is String!). Confirm this is intended.
    ack_payload = self.gql(
        ack_mutation, variable_values={"itemId": item_id, "runId": str(run_id)}
    )["ackRunQueueItem"]
    acked: bool = ack_payload["success"]
    if not acked:
        raise CommError(
            "Error acking run queue item. Item may have already been acknowledged by another process"
        )
    return acked
+
1843
@normalize_exceptions
def create_launch_agent_fields_introspection(self) -> List:
    """Return the input field names of ``CreateLaunchAgentInput`` on the server.

    The result is cached on ``self.create_launch_agent_input_info`` once it is
    non-empty. An empty result is queried again on the next call.

    Returns:
        A list of input field names. The list is empty if the server does not
        expose the type or its input fields.
    """
    if self.create_launch_agent_input_info:
        return self.create_launch_agent_input_info
    query_string = """
       query ProbeServerCreateLaunchAgentInput {
            CreateLaunchAgentInputInfoType: __type(name:"CreateLaunchAgentInput") {
                inputFields{
                    name
                }
            }
        }
    """

    query = gql(query_string)
    res = self.gql(query)

    # __type resolves to null when the server does not know the type, so a
    # present-but-None value must be handled as well as a missing key.
    type_info = res.get("CreateLaunchAgentInputInfoType") or {}
    input_fields = type_info.get("inputFields") or []
    self.create_launch_agent_input_info = [
        field.get("name", "") for field in input_fields
    ]
    return self.create_launch_agent_input_info
+
1868
@normalize_exceptions
def create_launch_agent(
    self,
    entity: str,
    project: str,
    queues: List[str],
    agent_config: Dict[str, Any],
    version: str,
    gorilla_agent_support: bool,
) -> dict:
    """Register a launch agent that polls the given queues.

    If the project has no queues, a ``default`` queue with PROJECT access is
    created first. Every name in ``queues`` must match a project queue.

    Arguments:
        entity: Entity owning the queues.
        project: Project owning the queues.
        queues: Names of the queues the agent should poll.
        agent_config: Agent configuration. It is sent as JSON only if the
            server's ``CreateLaunchAgentInput`` has an ``agentConfig`` field.
        version: Agent version. It is sent only if the server accepts
            ``version``.
        gorilla_agent_support: If False, no agent is registered on the server.

    Returns:
        The ``createLaunchAgent`` payload. When the server lacks agent support
        this is ``{"success": True, "launchAgentId": None}``.

    Raises:
        CommError: the default queue could not be created, or some requested
            queues were not found.
    """
    project_queues = self.get_project_run_queues(entity, project)
    if not project_queues:
        # create default queue if it doesn't already exist
        default = self.create_run_queue(entity, project, "default", access="PROJECT")
        if default is None or default.get("queueID") is None:
            raise CommError(
                "Unable to create default queue for {}/{}. No queues for agent to poll".format(
                    entity, project
                )
            )
        project_queues = [{"id": default["queueID"], "name": "default"}]

    # Keep only the queue ids the agent was asked to poll.
    polling_queue_ids = []
    for queue in project_queues:
        if queue["name"] in queues:
            polling_queue_ids.append(queue["id"])
    if len(polling_queue_ids) != len(queues):
        available_names = ",".join([queue["name"] for queue in project_queues])
        raise CommError(
            f"Could not start launch agent: Not all of requested queues ({', '.join(queues)}) found. "
            f"Available queues for this project: {available_names}"
        )

    if not gorilla_agent_support:
        # Server cannot register agents: report success with no agent id.
        return {
            "success": True,
            "launchAgentId": None,
        }

    variable_values = {
        "entity": entity,
        "project": project,
        "queues": polling_queue_ids,
        "hostname": socket.gethostname(),
    }

    mutation_params = """
        $entity: String!,
        $project: String!,
        $queues: [ID!]!,
        $hostname: String!
    """

    mutation_input = """
        entityName: $entity,
        projectName: $project,
        runQueues: $queues,
        hostname: $hostname
    """

    # Only send optional fields the server's input type actually declares.
    if "agentConfig" in self.create_launch_agent_fields_introspection():
        variable_values["agentConfig"] = json.dumps(agent_config)
        mutation_params += ", $agentConfig: JSONString"
        mutation_input += ", agentConfig: $agentConfig"
    if "version" in self.create_launch_agent_fields_introspection():
        variable_values["version"] = version
        mutation_params += ", $version: String"
        mutation_input += ", version: $version"

    mutation = gql(
        f"""
        mutation createLaunchAgent(
            {mutation_params}
        ) {{
            createLaunchAgent(
                input: {{
                    {mutation_input}
                }}
            ) {{
                launchAgentId
            }}
        }}
        """
    )
    created: dict = self.gql(mutation, variable_values)["createLaunchAgent"]
    return created
+
1957
@normalize_exceptions
def update_launch_agent_status(
    self,
    agent_id: str,
    status: str,
    gorilla_agent_support: bool,
) -> dict:
    """Report a launch agent's status via ``updateLaunchAgent``.

    Arguments:
        agent_id: The launch agent id.
        status: The status string to record.
        gorilla_agent_support: If False, nothing is sent and success is
            reported.

    Returns:
        The ``updateLaunchAgent`` payload, or ``{"success": True}`` when the
        server lacks agent support.
    """
    if gorilla_agent_support:
        status_mutation = gql(
            """
        mutation updateLaunchAgent($agentId: ID!, $agentStatus: String){
            updateLaunchAgent(
                input: {
                    launchAgentId: $agentId
                    agentStatus: $agentStatus
                }
            ) {
                success
            }
        }
        """
        )
        payload: dict = self.gql(
            status_mutation, {"agentId": agent_id, "agentStatus": status}
        )["updateLaunchAgent"]
        return payload

    # Server cannot track agents, so the update is a no-op.
    return {
        "success": True,
    }
+
1991
@normalize_exceptions
def get_launch_agent(self, agent_id: str, gorilla_agent_support: bool) -> dict:
    """Fetch a launch agent record by id.

    Arguments:
        agent_id: The launch agent id.
        gorilla_agent_support: If False, a placeholder record is returned
            without contacting the server.

    Returns:
        The ``launchAgent`` object from the server, or a placeholder with
        ``id=None``, an empty ``name`` and ``stopPolling=False``.
    """
    if not gorilla_agent_support:
        placeholder = {"id": None, "name": "", "stopPolling": False}
        return placeholder
    agent_query = gql(
        """
        query LaunchAgent($agentId: ID!) {
            launchAgent(id: $agentId) {
                id
                name
                runQueues
                hostname
                agentStatus
                stopPolling
                heartbeatAt
            }
        }
        """
    )
    agent: dict = self.gql(agent_query, {"agentId": agent_id})["launchAgent"]
    return agent
+
2020
@normalize_exceptions
def upsert_run(
    self,
    id: Optional[str] = None,
    name: Optional[str] = None,
    project: Optional[str] = None,
    host: Optional[str] = None,
    group: Optional[str] = None,
    tags: Optional[List[str]] = None,
    config: Optional[dict] = None,
    description: Optional[str] = None,
    entity: Optional[str] = None,
    state: Optional[str] = None,
    display_name: Optional[str] = None,
    notes: Optional[str] = None,
    repo: Optional[str] = None,
    job_type: Optional[str] = None,
    program_path: Optional[str] = None,
    commit: Optional[str] = None,
    sweep_name: Optional[str] = None,
    summary_metrics: Optional[str] = None,
    num_retries: Optional[int] = None,
) -> Tuple[dict, bool, Optional[List]]:
    """Create or update a run via the ``upsertBucket`` mutation.

    As a side effect, the "project" and "entity" settings are updated from
    the returned bucket.

    Arguments:
        id (str, optional): The existing run to update
        name (str, optional): The name of the run to create
        group (str, optional): Name of the group this run is a part of
        project (str, optional): The name of the project. Defaults to a name
            derived from program_path.
        host (str, optional): The name of the host. Omitted when the
            "anonymous" setting is "true".
        tags (list, optional): A list of tags to apply to the run
        config (dict, optional): The latest config params
        description (str, optional): A description of this run. Blank or
            whitespace-only values are sent as null.
        entity (str, optional): The entity to scope this run to. Defaults
            to the "entity" setting.
        display_name (str, optional): The display name of this run
        notes (str, optional): Notes about this run
        repo (str, optional): Url of the program's repository.
        state (str, optional): State of the program.
        job_type (str, optional): Type of job, e.g 'train'.
        program_path (str, optional): Path to the program.
        commit (str, optional): The Git SHA to associate the run with
        sweep_name (str, optional): The name of the sweep this run is a part of
        summary_metrics (str, optional): The JSON summary metrics
        num_retries (int, optional): Number of retries

    Returns:
        A tuple ``(bucket, inserted, server_messages)``. ``server_messages``
        is None when the server does not expose ``serverSettings``.
    """
    query_string = """
    mutation UpsertBucket(
        $id: String,
        $name: String,
        $project: String,
        $entity: String,
        $groupName: String,
        $description: String,
        $displayName: String,
        $notes: String,
        $commit: String,
        $config: JSONString,
        $host: String,
        $debug: Boolean,
        $program: String,
        $repo: String,
        $jobType: String,
        $state: String,
        $sweep: String,
        $tags: [String!],
        $summaryMetrics: JSONString,
    ) {
        upsertBucket(input: {
            id: $id,
            name: $name,
            groupName: $groupName,
            modelName: $project,
            entityName: $entity,
            description: $description,
            displayName: $displayName,
            notes: $notes,
            config: $config,
            commit: $commit,
            host: $host,
            debug: $debug,
            jobProgram: $program,
            jobRepo: $repo,
            jobType: $jobType,
            state: $state,
            sweep: $sweep,
            tags: $tags,
            summaryMetrics: $summaryMetrics,
        }) {
            bucket {
                id
                name
                displayName
                description
                config
                sweepName
                project {
                    id
                    name
                    entity {
                        id
                        name
                    }
                }
                historyLineCount
            }
            inserted
            _Server_Settings_
        }
    }
    """
    self.server_settings_introspection()

    # Only request serverSettings when the server schema has the type.
    server_settings_string = (
        """
    serverSettings {
            serverMessages{
                utfText
                plainText
                htmlText
                messageType
                messageLevel
            }
    }
    """
        if self._server_settings_type
        else ""
    )

    query_string = query_string.replace("_Server_Settings_", server_settings_string)
    mutation = gql(query_string)
    config_str = json.dumps(config) if config else None
    if not description or description.isspace():
        description = None

    kwargs = {}
    if num_retries is not None:
        kwargs["num_retries"] = num_retries

    variable_values = {
        "id": id,
        "entity": entity or self.settings("entity"),
        "name": name,
        "project": project or util.auto_project_name(program_path),
        "groupName": group,
        "tags": tags,
        "description": description,
        "config": config_str,
        "commit": commit,
        "displayName": display_name,
        "notes": notes,
        "host": None if self.settings().get("anonymous") == "true" else host,
        "debug": env.is_debug(env=self._environ),
        "repo": repo,
        "program": program_path,
        "jobType": job_type,
        "state": state,
        "sweep": sweep_name,
        "summaryMetrics": summary_metrics,
    }

    # retry conflict errors for 2 minutes, default to no_auth_retry
    check_retry_fn = util.make_check_retry_fn(
        check_fn=util.check_retry_conflict_or_gone,
        check_timedelta=datetime.timedelta(minutes=2),
        fallback_retry_fn=util.no_retry_auth,
    )

    response = self.gql(
        mutation,
        variable_values=variable_values,
        check_retry_fn=check_retry_fn,
        **kwargs,
    )

    upsert_result = response["upsertBucket"]
    run_obj: Dict[str, Dict[str, Dict[str, str]]] = upsert_result["bucket"]
    project_obj: Dict[str, Dict[str, str]] = run_obj.get("project", {})
    if project_obj:
        self.set_setting("project", project_obj["name"])
        entity_obj = project_obj.get("entity", {})
        if entity_obj:
            self.set_setting("entity", entity_obj["name"])

    server_messages = None
    if self._server_settings_type:
        # serverSettings may be returned as null, so guard before descending.
        server_settings = upsert_result.get("serverSettings") or {}
        server_messages = server_settings.get("serverMessages", [])

    return (
        upsert_result["bucket"],
        upsert_result["inserted"],
        server_messages,
    )
+
2219
@normalize_exceptions
def get_run_info(
    self,
    entity: str,
    project: str,
    name: str,
) -> dict:
    """Fetch the ``runInfo`` block (program, args, env, git) for a run.

    Arguments:
        entity: Entity owning the project.
        project: Project containing the run.
        name: The run name/id.

    Returns:
        The ``runInfo`` dict from the server.

    Raises:
        CommError: the project or the run could not be found.
    """
    info_query = gql(
        """
    query RunInfo($project: String!, $entity: String!, $name: String!) {
        project(name: $project, entityName: $entity) {
            run(name: $name) {
                runInfo {
                    program
                    args
                    os
                    python
                    colab
                    executable
                    codeSaved
                    cpuCount
                    gpuCount
                    gpu
                    git {
                        remote
                        commit
                    }
                }
            }
        }
    }
    """
    )
    res = self.gql(info_query, {"project": project, "entity": entity, "name": name})

    project_obj = res.get("project")
    if project_obj is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this project exists and you have access to this entity and project"
        )
    run_obj = project_obj.get("run")
    if run_obj is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this run id exists"
        )
    run_info: dict = run_obj["runInfo"]
    return run_info
+
2269
@normalize_exceptions
def get_run_state(self, entity: str, project: str, name: str) -> str:
    """Return the ``state`` field of a run.

    Raises:
        CommError: the project or run could not be found.
    """
    state_query = gql(
        """
    query RunState(
        $project: String!,
        $entity: String!,
        $name: String!) {
        project(name: $project, entityName: $entity) {
            run(name: $name) {
                state
            }
        }
    }
    """
    )
    res = self.gql(
        state_query, {"project": project, "entity": entity, "name": name}
    )
    project_obj = res.get("project")
    found = project_obj is not None and project_obj.get("run") is not None
    if not found:
        raise CommError(f"Error fetching run state for {entity}/{project}/{name}.")
    run_state: str = project_obj["run"]["state"]
    return run_state
+
2296
@normalize_exceptions
def create_run_files_introspection(self) -> bool:
    """Return True if the server exposes the ``createRunFiles`` mutation."""
    available_mutations = self.server_info_introspection()[2]
    return "createRunFiles" in available_mutations
+
2301
@normalize_exceptions
def upload_urls(
    self,
    project: str,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    description: Optional[str] = None,
) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
    """Generate temporary resumable upload urls.

    Uses the ``createRunFiles`` mutation when the server has it, and
    otherwise falls back to ``legacy_upload_urls``.

    Arguments:
        project (str): The project containing the run
        files (list or dict): The filenames to upload (dict keys are used)
        run (str, optional): The run to upload to; defaults to the current run
        entity (str, optional): The entity to scope this project to.
        description (str, optional): description (only used by the legacy path)

    Returns:
        (run_id, upload_headers, file_info)
            run_id: id of run we uploaded files to
            upload_headers: A list of headers to use when uploading files.
            file_info: A dict mapping each filename to its file record, e.g.
                {"weights.h5": {"name": "weights.h5", "uploadUrl": "https://weights.url"}}

    Raises:
        CommError: the server returned no run id.
    """
    run_name = run or self.current_run_id
    assert run_name, "run must be specified"
    entity = entity or self.settings("entity")
    assert entity, "entity must be specified"

    if not self.create_run_files_introspection():
        # Older servers lack createRunFiles.
        return self.legacy_upload_urls(project, files, run, entity, description)

    create_files = gql(
        """
    mutation CreateRunFiles($entity: String!, $project: String!, $run: String!, $files: [String!]!) {
        createRunFiles(input: {entityName: $entity, projectName: $project, runName: $run, files: $files}) {
            runID
            uploadHeaders
            files {
                name
                uploadUrl
            }
        }
    }
    """
    )

    created = self.gql(
        create_files,
        variable_values={
            "project": project,
            "run": run_name,
            "entity": entity,
            "files": list(files),
        },
    )["createRunFiles"]

    run_id = created["runID"]
    if not run_id:
        raise CommError(
            f"Error uploading files to {entity}/{project}/{run_name}. Check that this project exists and you have access to this entity and project"
        )
    by_name = {}
    for file_record in created["files"]:
        by_name[file_record["name"]] = file_record
    return run_id, created["uploadHeaders"], by_name
+
2376
def legacy_upload_urls(
    self,
    project: str,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    description: Optional[str] = None,
) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
    """Generate temporary resumable upload urls.

    A new mutation createRunFiles was introduced after 0.15.4.
    This function is used to support older versions.

    Returns:
        (run_id, upload_headers, file_info), where file_info maps each file
        name to its node, with ``url`` renamed to ``uploadUrl``.

    Raises:
        CommError: the project or the run does not exist.
    """
    query = gql(
        """
    query RunUploadUrls($name: String!, $files: [String]!, $entity: String, $run: String!, $description: String) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run, desc: $description) {
                id
                files(names: $files) {
                    uploadHeaders
                    edges {
                        node {
                            name
                            url(upload: true)
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    """
    )
    run_id = run or self.current_run_id
    assert run_id, "run must be specified"
    entity = entity or self.settings("entity")
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run_id,
            "entity": entity,
            "files": list(files),
            "description": description,
        },
    )

    # ``model`` is null when the project does not exist; indexing it directly
    # raised TypeError instead of the intended CommError.
    model = query_result["model"]
    run_obj = model["bucket"] if model else None
    if not run_obj:
        raise CommError(f"Run does not exist {entity}/{project}/{run_id}.")

    for file_node in run_obj["files"]["edges"]:
        file = file_node["node"]
        # we previously used "url" field but now use "uploadUrl"
        # replace the "url" field with "uploadUrl for downstream compatibility
        if "url" in file and "uploadUrl" not in file:
            file["uploadUrl"] = file.pop("url")

    result = {file["name"]: file for file in self._flatten_edges(run_obj["files"])}
    return run_obj["id"], run_obj["files"]["uploadHeaders"], result
+
2440
@normalize_exceptions
def download_urls(
    self,
    project: str,
    run: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Dict[str, str]]:
    """Generate download urls for every file in a run.

    Arguments:
        project (str): The project to download
        run (str): The run to download from; defaults to the current run
        entity (str, optional): The entity to scope this project to. Defaults
            to the "entity" setting.

    Returns:
        A dict of file names to file records

        {
            'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
            'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
        }

    Raises:
        CommError: the project or the run does not exist.
    """
    query = gql(
        """
    query RunDownloadUrls($name: String!, $entity: String, $run: String!) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run) {
                files {
                    edges {
                        node {
                            name
                            url
                            md5
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    """
    )
    run = run or self.current_run_id
    assert run, "run must be specified"
    entity = entity or self.settings("entity")
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run,
            "entity": entity,
        },
    )
    model = query_result["model"]
    # A missing run yields a null bucket inside an existing model; treat it
    # the same as a missing project instead of failing with TypeError.
    bucket = model["bucket"] if model is not None else None
    if bucket is None:
        raise CommError(f"Run does not exist {entity}/{project}/{run}.")
    files = self._flatten_edges(bucket["files"])
    return {file["name"]: file for file in files if file}
+
2498
@normalize_exceptions
def download_url(
    self,
    project: str,
    file_name: str,
    run: Optional[str] = None,
    entity: Optional[str] = None,
) -> Optional[Dict[str, str]]:
    """Generate a download url for a single file in a run.

    Arguments:
        project (str): The project to download
        file_name (str): The name of the file to download
        run (str): The run to download from; defaults to the current run
        entity (str, optional): The entity to scope this project to. Defaults
            to the "entity" setting.

    Returns:
        The file record, or None if the project, run or file is missing or
        the file has no ``updatedAt``.

        { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
    """
    query = gql(
        """
    query RunDownloadUrl($name: String!, $fileName: String!, $entity: String, $run: String!) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run) {
                files(names: [$fileName]) {
                    edges {
                        node {
                            name
                            url
                            md5
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    """
    )
    run = run or self.current_run_id
    assert run, "run must be specified"
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run,
            "fileName": file_name,
            "entity": entity or self.settings("entity"),
        },
    )
    model = query_result["model"]
    # The bucket is null when the run does not exist; report "not found"
    # (None) like a missing model rather than raising TypeError.
    bucket = model["bucket"] if model else None
    if not bucket:
        return None
    files = self._flatten_edges(bucket["files"])
    return files[0] if files and files[0].get("updatedAt") else None
+
2557
@normalize_exceptions
def download_file(self, url: str) -> Tuple[int, requests.Response]:
    """Initiate a streaming download.

    Basic auth with the API key is used only when no thread-local cookies are
    set.

    Arguments:
        url (str): The url to download

    Returns:
        A tuple of the content length (0 if the header is absent) and the
        streaming response

    Raises:
        requests.exceptions.HTTPError: the server returned an error status.
    """
    check_httpclient_logger_handler()
    auth = None
    if _thread_local_api_settings.cookies is None:
        auth = ("user", self.api_key or "")
    response = requests.get(
        url,
        auth=auth,
        cookies=_thread_local_api_settings.cookies or {},
        headers=_thread_local_api_settings.headers or {},
        stream=True,
    )
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError:
        # With stream=True the body is not consumed, so the connection is
        # held until closed; release it before propagating the error.
        response.close()
        raise
    return int(response.headers.get("content-length", 0)), response
+
2581
@normalize_exceptions
def download_write_file(
    self,
    metadata: Dict[str, str],
    out_dir: Optional[str] = None,
) -> Tuple[str, Optional[requests.Response]]:
    """Download a run file described by ``metadata`` and write it to disk.

    Arguments:
        metadata (obj): File metadata with "name", "md5" and "url" keys, as
            produced by Api.download_urls().
        out_dir (str, optional): Target directory. Defaults to the
            "wandb_dir" setting.

    Returns:
        A tuple of the local path and the streaming response. The response
        is None if ``file_current`` reports the file as already current.
    """
    filename = metadata["name"]
    target_dir = out_dir or self.settings("wandb_dir")
    path = os.path.join(target_dir, filename)

    if not self.file_current(filename, B64MD5(metadata["md5"])):
        _, response = self.download_file(metadata["url"])
        with util.fsync_open(path, "wb") as out_file:
            # Stream in 1 KiB chunks instead of loading the body in memory.
            for chunk in response.iter_content(chunk_size=1024):
                out_file.write(chunk)
        return path, response

    return path, None
+
2610
def upload_file_azure(
    self, url: str, file: Any, extra_headers: Dict[str, str]
) -> None:
    """Upload a file to Azure Blob Storage using the azure SDK.

    Arguments:
        url: The blob URL to upload to.
        file: The data to upload; must support ``len()``.
        extra_headers: Request headers. Only "Content-MD5" (base64) and
            "Content-Type" are used.

    Raises:
        requests.exceptions.RequestException: Azure returned an HTTP response;
            its status code and headers are copied onto ``response``.
        requests.exceptions.ConnectionError: the upload failed without an HTTP
            response.
    """
    from azure.core.exceptions import AzureError  # type: ignore

    # Configure the client without retries so our existing logic can handle them
    client = self._azure_blob_module.BlobClient.from_blob_url(
        url, retry_policy=self._azure_blob_module.LinearRetry(retry_total=0)
    )
    try:
        if extra_headers.get("Content-MD5") is not None:
            md5: Optional[bytes] = base64.b64decode(extra_headers["Content-MD5"])
        else:
            md5 = None
        content_settings = self._azure_blob_module.ContentSettings(
            content_md5=md5,
            content_type=extra_headers.get("Content-Type"),
        )
        client.upload_blob(
            file,
            max_concurrency=4,
            length=len(file),
            overwrite=True,
            content_settings=content_settings,
        )
    except AzureError as e:
        # HttpResponseError always has a ``response`` attribute, but it can
        # be None; treat that like a connection-level failure.
        azure_response = getattr(e, "response", None)
        if azure_response is not None:
            response = requests.models.Response()
            response.status_code = azure_response.status_code
            response.headers = azure_response.headers
            raise requests.exceptions.RequestException(
                e.message, response=response
            ) from e
        raise requests.exceptions.ConnectionError(e.message) from e
+
2645
def upload_multipart_file_chunk(
    self,
    url: str,
    upload_chunk: bytes,
    extra_headers: Optional[Dict[str, str]] = None,
) -> Optional[requests.Response]:
    """PUT one chunk of a multipart upload, classifying failures for retry.

    Arguments:
        url: The (presigned) url to upload the chunk to
        upload_chunk: The bytes of this chunk
        extra_headers: A dictionary of extra headers to send with the request

    Returns:
        The `requests` library response object

    Raises:
        retry.TransientError: on retryable HTTP statuses (308, 408, 409, 429,
            500, 502, 503, 504), timeouts, connection errors, or an S3 400
            "RequestTimeout".
        Other request errors are re-raised via ``wandb._sentry.reraise``.
    """
    check_httpclient_logger_handler()
    try:
        if env.is_debug(env=self._environ):
            logger.debug("upload_file: %s", url)
        response = self._upload_file_session.put(
            url, data=upload_chunk, headers=extra_headers
        )
        if env.is_debug(env=self._environ):
            logger.debug("upload_file: %s complete", url)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        failed_request = e.request
        failed_response = e.response
        logger.error("upload_file exception %s: %s", url, e)
        logger.error(
            "upload_file request headers: %s",
            failed_request.headers if failed_request is not None else "",
        )
        body = failed_response.content if failed_response is not None else ""
        logger.error("upload_file response body: %s", body)
        status_code = failed_response.status_code if failed_response is not None else 0

        # S3 reports retryable request timeouts out-of-band
        is_aws_retryable = status_code == 400 and "RequestTimeout" in str(body)
        is_network_error = isinstance(
            e, (requests.exceptions.Timeout, requests.exceptions.ConnectionError)
        )
        retryable_status = status_code in (308, 408, 409, 429, 500, 502, 503, 504)

        # Retry errors from cloud storage or local network issues
        if not (retryable_status or is_network_error or is_aws_retryable):
            wandb._sentry.reraise(e)
        else:
            transient = retry.TransientError(exc=e)
            raise transient.with_traceback(sys.exc_info()[2])
    return response
+
2697
def upload_file(
    self,
    url: str,
    file: IO[bytes],
    callback: Optional["ProgressFn"] = None,
    extra_headers: Optional[Dict[str, str]] = None,
) -> Optional[requests.Response]:
    """Upload a file to W&B with failure resumption.

    The file is wrapped in a `Progress` object so `callback` is told how many
    bytes were sent. If the headers contain `x-ms-blob-type` and the Azure
    SDK module is available, the upload goes through `upload_file_azure`.
    Otherwise a single HTTP PUT is issued with the upload session.

    Arguments:
        url: The url to upload to
        file: The open binary file whose contents are uploaded
        callback: A callback which is passed the number of bytes uploaded
            since the last time it was called, used to report progress
        extra_headers: A dictionary of extra headers to send with the request

    Returns:
        The `requests` library response object, or None when the upload was
        handled by `upload_file_azure`.

    Raises:
        retry.TransientError: for retryable failures (selected HTTP status
            codes, timeouts, connection errors, S3 out-of-band RequestTimeout).
    """
    check_httpclient_logger_handler()
    headers = extra_headers.copy() if extra_headers else {}
    response: Optional[requests.Response] = None
    progress = Progress(file, callback=callback)
    try:
        wants_azure = "x-ms-blob-type" in headers
        if wants_azure and self._azure_blob_module:
            self.upload_file_azure(url, progress, headers)
        else:
            if wants_azure:
                wandb.termwarn(
                    "Azure uploads over 256MB require the azure SDK, install with pip install wandb[azure]",
                    repeat=False,
                )
            if env.is_debug(env=self._environ):
                logger.debug("upload_file: %s", url)
            response = self._upload_file_session.put(
                url, data=progress, headers=headers
            )
            if env.is_debug(env=self._environ):
                logger.debug("upload_file: %s complete", url)
            response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error(f"upload_file exception {url}: {e}")
        sent_headers = "" if e.request is None else e.request.headers
        logger.error(f"upload_file request headers: {sent_headers}")
        body = "" if e.response is None else e.response.content
        logger.error(f"upload_file response body: {body}")
        status = 0 if e.response is None else e.response.status_code
        # S3 reports retryable request timeouts out-of-band (HTTP 400).
        aws_timeout = (
            "x-amz-meta-md5" in headers
            and status == 400
            and "RequestTimeout" in str(body)
        )
        # Rewind so a retry re-sends the file from its beginning.
        progress.rewind()
        # Retry errors from cloud storage or local network issues.
        retryable = (
            status in (308, 408, 409, 429, 500, 502, 503, 504)
            or isinstance(
                e,
                (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
            )
            or aws_timeout
        )
        if retryable:
            raise retry.TransientError(exc=e).with_traceback(sys.exc_info()[2])
        wandb._sentry.reraise(e)

    return response
2767
+
2768
@normalize_exceptions
def register_agent(
    self,
    host: str,
    sweep_id: Optional[str] = None,
    project_name: Optional[str] = None,
    entity: Optional[str] = None,
) -> dict:
    """Register a new sweep agent with the server.

    When `entity` or `project_name` is None, the value is taken from the
    "entity" / "project" settings.

    Arguments:
        host (str): hostname
        sweep_id (str): sweep id
        project_name: (str): model that contains sweep
        entity: (str): entity that contains sweep

    Returns:
        The `agent` object from the CreateAgent response.
    """
    mutation = gql(
        """
    mutation CreateAgent(
        $host: String!
        $projectName: String,
        $entityName: String,
        $sweep: String!
    ) {
        createAgent(input: {
            host: $host,
            projectName: $projectName,
            entityName: $entityName,
            sweep: $sweep,
        }) {
            agent {
                id
            }
        }
    }
    """
    )
    resolved_entity = self.settings("entity") if entity is None else entity
    resolved_project = (
        self.settings("project") if project_name is None else project_name
    )

    reply = self.gql(
        mutation,
        variable_values={
            "host": host,
            "entityName": resolved_entity,
            "projectName": resolved_project,
            "sweep": sweep_id,
        },
        check_retry_fn=util.no_retry_4xx,
    )
    agent: dict = reply["createAgent"]["agent"]
    return agent
2822
+
2823
def agent_heartbeat(
    self, agent_id: str, metrics: dict, run_states: dict
) -> List[Dict[str, Any]]:
    """Notify server about agent state, receive commands.

    Arguments:
        agent_id (str): id of a registered agent; must not be None
        metrics (dict): system metrics (sent JSON-encoded)
        run_states (dict): run_id: state mapping (sent JSON-encoded)

    Returns:
        List of commands to execute. If the request fails, the error is
        logged and an empty list is returned.

    Raises:
        ValueError: if `agent_id` is None.
    """
    if agent_id is None:
        raise ValueError("Cannot call heartbeat with an unregistered agent.")

    mutation = gql(
        """
    mutation Heartbeat(
        $id: ID!,
        $metrics: JSONString,
        $runState: JSONString
    ) {
        agentHeartbeat(input: {
            id: $id,
            metrics: $metrics,
            runState: $runState
        }) {
            agent {
                id
            }
            commands
        }
    }
    """
    )

    try:
        response = self.gql(
            mutation,
            variable_values={
                "id": agent_id,
                "metrics": json.dumps(metrics),
                "runState": json.dumps(run_states),
            },
            timeout=60,
        )
    except Exception as e:
        # Parsing the message must never raise from inside this handler,
        # otherwise a transient server error would crash the agent loop.
        message = self._parse_gql_error_message(e)
        logger.error("Error communicating with W&B: %s", message)
        return []

    result: List[Dict[str, Any]] = json.loads(
        response["agentHeartbeat"]["commands"]
    )
    return result


@staticmethod
def _parse_gql_error_message(exc: BaseException) -> str:
    """Extract a readable message from an exception raised by a GQL call.

    GQL typically raises exceptions whose first argument is a stringified
    python dict with a "message" key. That form (or an actual dict) yields
    the "message" value. Anything else falls back to `str(exc)`.
    """
    first = exc.args[0] if exc.args else None
    parsed: Any = None
    if isinstance(first, dict):
        parsed = first
    elif isinstance(first, str):
        try:
            parsed = ast.literal_eval(first)
        except (ValueError, SyntaxError, RecursionError):
            parsed = None
    if isinstance(parsed, dict) and "message" in parsed:
        return str(parsed["message"])
    return str(exc)
2879
+
2880
@staticmethod
def _validate_config_and_fill_distribution(config: dict) -> dict:
    """Return a validated copy of a sweep config with distributions filled in.

    A parameter that has both "min" and "max" but no "distribution" gets one
    inferred from the bound types: two ints give "int_uniform" and two floats
    give "uniform". Any other pairing is rejected. The caller's config is
    never modified.

    TODO(dag): deprecate this in favor of jsonschema validation once
    apiVersion 2 is released and local controller is integrated with
    wandb/client.

    Raises:
        ValueError: if a parameter's bounds are neither both ints nor both floats.
    """
    # Deep-copy first so a config reused by the caller is untouched. Then
    # coerce to a plain dict: sweepconfig subclasses dict but does not
    # serialize cleanly to yaml and breaks graphql.
    checked = dict(deepcopy(config))

    if "parameters" not in checked:
        # still shows an anaconda warning, but doesn't error
        return checked

    params = checked["parameters"]
    for name in params:
        spec = params[name]
        if "min" not in spec or "max" not in spec or "distribution" in spec:
            continue
        low, high = spec["min"], spec["max"]
        if isinstance(low, int) and isinstance(high, int):
            spec["distribution"] = "int_uniform"
        elif isinstance(low, float) and isinstance(high, float):
            spec["distribution"] = "uniform"
        else:
            raise ValueError(
                "Parameter %s is ambiguous, please specify bounds as both floats (for a float_"
                "uniform distribution) or ints (for an int_uniform distribution)."
                % name
            )
    return checked
2919
+
2920
@normalize_exceptions
def upsert_sweep(
    self,
    config: dict,
    controller: Optional[str] = None,
    launch_scheduler: Optional[str] = None,
    scheduler: Optional[str] = None,
    obj_id: Optional[str] = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    state: Optional[str] = None,
) -> Tuple[str, List[str]]:
    """Upsert a sweep object.

    Several variants of the UpsertSweep mutation are tried in order, newest
    schema first, so servers lacking newer fields still work. A `UsageError`
    is raised immediately. Any other error moves on to the next variant, and
    the last error is raised if every variant fails. When the response
    includes the sweep's project (and its entity), the "project" (and
    "entity") settings are updated from it.

    Arguments:
        config (dict): sweep config (will be converted to yaml)
        controller (str): controller to use
        launch_scheduler (str): launch scheduler to use
        scheduler (str): scheduler to use
        obj_id (str): object id
        project (str): project to use
        entity (str): entity to use
        state (str): state

    Returns:
        A tuple of the sweep name and the config validation warnings. The
        warnings list is empty when the server reports none or does not
        support the field.
    """
    project_query = """
        project {
            id
            name
            entity {
                id
                name
            }
        }
    """
    mutation_str = """
    mutation UpsertSweep(
        $id: ID,
        $config: String,
        $description: String,
        $entityName: String,
        $projectName: String,
        $controller: JSONString,
        $scheduler: JSONString,
        $state: String
    ) {
        upsertSweep(input: {
            id: $id,
            config: $config,
            description: $description,
            entityName: $entityName,
            projectName: $projectName,
            controller: $controller,
            scheduler: $scheduler,
            state: $state
        }) {
            sweep {
                name
                _PROJECT_QUERY_
            }
            configValidationWarnings
        }
    }
    """
    # TODO(jhr): we need protocol versioning to know schema is not supported
    # for now we will just try both new and old query

    # launchScheduler was introduced in core v0.14.0
    mutation_4 = gql(
        mutation_str.replace(
            "$controller: JSONString,",
            "$controller: JSONString,$launchScheduler: JSONString,",
        )
        .replace(
            "controller: $controller,",
            "controller: $controller,launchScheduler: $launchScheduler,",
        )
        .replace("_PROJECT_QUERY_", project_query)
    )

    # mutation 3 maps to backend that can support CLI version of at least 0.10.31
    mutation_3 = gql(mutation_str.replace("_PROJECT_QUERY_", project_query))
    mutation_2 = gql(
        mutation_str.replace("_PROJECT_QUERY_", project_query).replace(
            "configValidationWarnings", ""
        )
    )
    mutation_1 = gql(
        mutation_str.replace("_PROJECT_QUERY_", "").replace(
            "configValidationWarnings", ""
        )
    )

    # TODO(dag): replace this with a query for protocol versioning
    mutations = [mutation_4, mutation_3, mutation_2, mutation_1]

    config = self._validate_config_and_fill_distribution(config)

    # Silly, but attr-dicts like EasyDicts don't serialize correctly to yaml.
    # This sanitizes them with a round trip pass through json to get a regular dict.
    config_str = yaml.dump(json.loads(json.dumps(config)))

    err: Optional[Exception] = None
    for mutation in mutations:
        try:
            variables = {
                "id": obj_id,
                "config": config_str,
                "description": config.get("description"),
                "entityName": entity or self.settings("entity"),
                "projectName": project or self.settings("project"),
                "controller": controller,
                "launchScheduler": launch_scheduler,
                "scheduler": scheduler,
            }
            if state:
                variables["state"] = state

            response = self.gql(
                mutation,
                variable_values=variables,
                check_retry_fn=util.no_retry_4xx,
            )
        except UsageError:
            raise
        except Exception as e:
            # graphql schema exception is generic
            err = e
            continue
        err = None
        break
    if err:
        raise err

    upsert: Dict[str, Any] = response["upsertSweep"]
    sweep: Dict[str, Any] = upsert["sweep"]
    # The project may be absent (oldest mutation variant) or null.
    project_obj: Dict[str, Any] = sweep.get("project") or {}
    if project_obj:
        self.set_setting("project", project_obj["name"])
        entity_obj: Dict[str, Any] = project_obj.get("entity") or {}
        if entity_obj:
            self.set_setting("entity", entity_obj["name"])

    # Older mutation variants do not request the field, and servers may
    # return null; normalize both to an empty list.
    warnings: List[str] = upsert.get("configValidationWarnings") or []
    return sweep["name"], warnings
3063
+
3064
@normalize_exceptions
def create_anonymous_api_key(self) -> str:
    """Create a new anonymous user on the server and return its new API key.

    The key is the `name` field of the `apiKey` returned by the
    CreateAnonymousApiKey mutation, converted to `str`.
    """
    anonymous_mutation = gql(
        """
    mutation CreateAnonymousApiKey {
        createAnonymousEntity(input: {}) {
            apiKey {
                name
            }
        }
    }
    """
    )

    payload = self.gql(anonymous_mutation, variable_values={})
    api_key = payload["createAnonymousEntity"]["apiKey"]
    return str(api_key["name"])
3082
+
3083
@staticmethod
def file_current(fname: str, md5: "B64MD5") -> bool:
    """Return True if `fname` is an existing regular file with base64 MD5 `md5`.

    The checksum is only computed when the path is a regular file.
    """
    if not os.path.isfile(fname):
        return False
    return md5_file_b64(fname) == md5
3087
+
3088
@normalize_exceptions
def pull(
    self, project: str, run: Optional[str] = None, entity: Optional[str] = None
) -> "List[requests.Response]":
    """Download files from W&B.

    `project` and `run` are first passed through `parse_slug`. Every file
    listed by `download_urls` is then written locally via
    `download_write_file`.

    Arguments:
        project (str): The project to download from
        run (str, optional): The run to download files from
        entity (str, optional): The entity to scope this project to. Defaults to wandb models

    Returns:
        A list of `requests` library response objects, keeping only the
        truthy responses returned by `download_write_file`.
    """
    project, run = self.parse_slug(project, run=run)
    file_urls = self.download_urls(project, run, entity)
    downloaded = []
    for file_meta in file_urls.values():
        _, reply = self.download_write_file(file_meta)
        if reply:
            downloaded.append(reply)

    return downloaded
3111
+
3112
def get_project(self) -> str:
    """Return the default-settings project, else the current "project" setting."""
    default_project = self.default_settings.get("project")
    if default_project:
        return default_project
    configured: str = self.settings("project")
    return configured
3115
+
3116
@normalize_exceptions
def push(
    self,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    description: Optional[str] = None,
    force: bool = True,
    progress: Union[TextIO, bool] = False,
) -> "List[Optional[requests.Response]]":
    """Uploads multiple files to W&B.

    Arguments:
        files (list or dict): The filenames to upload, when dict the values are open files
        run (str, optional): The run to upload to
        entity (str, optional): The entity to scope this project to. Defaults to wandb models
        project (str, optional): The name of the project to upload to. Defaults to the one in settings.
        description (str, optional): The description of the changes (not used by this method)
        force (bool, optional): Whether to prevent push if git has uncommitted changes
            (not used by this method)
        progress (callable, or stream): If callable, will be called with (chunk_bytes,
            total_bytes) as argument else if True, renders a progress bar to stream.

    Returns:
        A list of `requests.Response` objects, one per file that was opened.

    Raises:
        CommError: if no project is given or configured.
    """
    if project is None:
        project = self.get_project()
    if project is None:
        raise CommError("No project configured.")
    if run is None:
        run = self.current_run_id

    # TODO(adrian): we use a retriable version of self.upload_file() so
    # will never retry self.upload_urls() here. Instead, maybe we should
    # make push itself retriable.
    _, upload_headers, result = self.upload_urls(
        project,
        files,
        run,
        entity,
    )
    extra_headers = {}
    for upload_header in upload_headers:
        key, val = upload_header.split(":", 1)
        extra_headers[key] = val
    responses = []
    for file_name, file_info in result.items():
        file_url = file_info["uploadUrl"]

        # If the upload URL is relative, fill it in with the base URL,
        # since it's a proxied file store like the on-prem VM.
        if file_url.startswith("/"):
            file_url = f"{self.api_url}{file_url}"

        try:
            # To handle Windows paths
            # TODO: this doesn't handle absolute paths...
            normal_name = os.path.join(*file_name.split("/"))
            open_file = (
                files[file_name]
                if isinstance(files, dict)
                else open(normal_name, "rb")
            )
        except OSError:
            print(f"{file_name} does not exist")
            continue

        # Always close the handle, even if the upload raises.
        try:
            if progress is False:
                # Use the resolved URL here too, so relative URLs work
                # without a progress reporter.
                responses.append(
                    self.upload_file_retry(
                        file_url, open_file, extra_headers=extra_headers
                    )
                )
            elif callable(progress):
                responses.append(  # type: ignore
                    self.upload_file_retry(
                        file_url, open_file, progress, extra_headers=extra_headers
                    )
                )
            else:
                length = os.fstat(open_file.fileno()).st_size
                with click.progressbar(
                    file=progress,  # type: ignore
                    length=length,
                    label=f"Uploading file: {file_name}",
                    fill_char=click.style("&", fg="green"),
                ) as bar:
                    responses.append(
                        self.upload_file_retry(
                            file_url,
                            open_file,
                            lambda bites, _: bar.update(bites),
                            extra_headers=extra_headers,
                        )
                    )
        finally:
            open_file.close()
    return responses
3214
+
3215
def link_artifact(
    self,
    client_id: str,
    server_id: str,
    portfolio_name: str,
    entity: str,
    project: str,
    aliases: Sequence[str],
) -> Dict[str, Any]:
    """Link an artifact into a portfolio (artifact collection).

    The artifact is identified by `server_id` when it is truthy, otherwise by
    `client_id`. Each alias is attached within `portfolio_name`.

    Returns:
        The `linkArtifact` object from the response.

    Raises:
        ValueError: if neither `server_id` nor `client_id` is provided.
    """
    if server_id:
        id_type, id_value = "$artifactID: ID", "artifactID: $artifactID"
    elif client_id:
        id_type, id_value = "$clientID: ID", "clientID: $clientID"
    else:
        # Without an id the ID_TYPE/ID_VALUE placeholders would stay in the
        # mutation and fail later with an opaque GraphQL syntax error.
        raise ValueError("link_artifact requires either a server_id or a client_id")

    template = """
    mutation LinkArtifact(
        $artifactPortfolioName: String!,
        $entityName: String!,
        $projectName: String!,
        $aliases: [ArtifactAliasInput!],
        ID_TYPE
    ) {
        linkArtifact(input: {
            artifactPortfolioName: $artifactPortfolioName,
            entityName: $entityName,
            projectName: $projectName,
            aliases: $aliases,
            ID_VALUE
        }) {
            versionIndex
        }
    }
    """
    template = template.replace("ID_TYPE", id_type).replace("ID_VALUE", id_value)

    variable_values = {
        "clientID": client_id,
        "artifactID": server_id,
        "artifactPortfolioName": portfolio_name,
        "entityName": entity,
        "projectName": project,
        "aliases": [
            {"alias": alias, "artifactCollectionName": portfolio_name}
            for alias in aliases
        ],
    }

    mutation = gql(template)
    response = self.gql(mutation, variable_values=variable_values)
    link_artifact: Dict[str, Any] = response["linkArtifact"]
    return link_artifact
3271
+
3272
def use_artifact(
    self,
    artifact_id: str,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    run_name: Optional[str] = None,
    use_as: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Record that a run uses the given artifact.

    The `usedAs` argument is only included in the mutation when
    `server_use_artifact_input_introspection` reports it. Missing entity,
    project and run names fall back to the "entity"/"project" settings and
    `self.current_run_id`.

    Returns:
        The artifact object from the response, or None if it is empty/null.
    """
    template = """
    mutation UseArtifact(
        $entityName: String!,
        $projectName: String!,
        $runName: String!,
        $artifactID: ID!,
        _USED_AS_TYPE_
    ) {
        useArtifact(input: {
            entityName: $entityName,
            projectName: $projectName,
            runName: $runName,
            artifactID: $artifactID,
            _USED_AS_VALUE_
        }) {
            artifact {
                id
                digest
                description
                state
                createdAt
                metadata
            }
        }
    }
    """

    supports_used_as = "usedAs" in self.server_use_artifact_input_introspection()
    used_as_type = "$usedAs: String" if supports_used_as else ""
    used_as_value = "usedAs: $usedAs" if supports_used_as else ""
    query = gql(
        template.replace("_USED_AS_TYPE_", used_as_type).replace(
            "_USED_AS_VALUE_", used_as_value
        )
    )

    variables = {
        "entityName": entity_name or self.settings("entity"),
        "projectName": project_name or self.settings("project"),
        "runName": run_name or self.current_run_id,
        "artifactID": artifact_id,
        "usedAs": use_as,
    }
    reply = self.gql(query, variable_values=variables)

    used: Optional[Dict[str, Any]] = reply["useArtifact"]["artifact"]
    return used if used else None
3338
+
3339
def create_artifact_type(
    self,
    artifact_type_name: str,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    description: Optional[str] = None,
) -> Optional[str]:
    """Create an artifact type in a project and return the new type's id.

    Missing entity/project names fall back to the "entity"/"project" settings.
    """
    create_type = gql(
        """
    mutation CreateArtifactType(
        $entityName: String!,
        $projectName: String!,
        $artifactTypeName: String!,
        $description: String
    ) {
        createArtifactType(input: {
            entityName: $entityName,
            projectName: $projectName,
            name: $artifactTypeName,
            description: $description
        }) {
            artifactType {
                id
            }
        }
    }
    """
    )
    variables = {
        "entityName": entity_name or self.settings("entity"),
        "projectName": project_name or self.settings("project"),
        "artifactTypeName": artifact_type_name,
        "description": description,
    }
    reply = self.gql(create_type, variable_values=variables)
    type_id: Optional[str] = reply["createArtifactType"]["artifactType"]["id"]
    return type_id
3380
+
3381
def server_artifact_introspection(self) -> List:
    """Return the field names of the server's `Artifact` GraphQL type.

    The result is cached in `self.server_artifact_fields_info` after the
    first probe. A null type, a null field list, or null entries in the
    introspection response are treated as "no fields" rather than raising.
    """
    if self.server_artifact_fields_info is None:
        query = gql(
            """
        query ProbeServerArtifact {
            ArtifactInfoType: __type(name:"Artifact") {
                fields {
                    name
                }
            }
        }
        """
        )
        res = self.gql(query)
        type_info = res.get("ArtifactInfoType") or {}
        fields = type_info.get("fields") or []
        self.server_artifact_fields_info = [
            field["name"] for field in fields if field and "name" in field
        ]

    return self.server_artifact_fields_info
3401
+
3402
def server_create_artifact_introspection(self) -> List:
    """Return the input field names of the server's `CreateArtifactInput` type.

    The result is cached in `self.server_create_artifact_input_info` after
    the first probe. A null type, a null input-field list, or null entries in
    the introspection response are treated as "no fields" rather than raising.
    """
    if self.server_create_artifact_input_info is None:
        query = gql(
            """
        query ProbeServerCreateArtifactInput {
            CreateArtifactInputInfoType: __type(name:"CreateArtifactInput") {
                inputFields{
                    name
                }
            }
        }
        """
        )
        res = self.gql(query)
        type_info = res.get("CreateArtifactInputInfoType") or {}
        input_fields = type_info.get("inputFields") or []
        self.server_create_artifact_input_info = [
            field["name"] for field in input_fields if field and "name" in field
        ]

    return self.server_create_artifact_input_info
3424
+
3425
def _get_create_artifact_mutation(
    self,
    fields: List,
    history_step: Optional[int],
    distributed_id: Optional[str],
) -> str:
    """Build the CreateArtifact mutation text for this server's capabilities.

    Optional variables and inputs are added only when the server's
    CreateArtifactInput lists them in `fields`. For distributedID the caller
    must supply a truthy `distributed_id`. For historyStep, `history_step`
    must also be neither 0 nor None.
    """
    # (enabled, variable declaration, input assignment), in emission order.
    optional_inputs = [
        (
            "historyStep" in fields and history_step not in [0, None],
            "$historyStep: Int64!,",
            "historyStep: $historyStep,",
        ),
        (
            bool(distributed_id),
            "$distributedID: String,",
            "distributedID: $distributedID,",
        ),
        ("clientID" in fields, "$clientID: ID,", "clientID: $clientID,"),
        (
            "sequenceClientID" in fields,
            "$sequenceClientID: ID,",
            "sequenceClientID: $sequenceClientID,",
        ),
        # Deduplication is a literal flag, so no variable is declared for it.
        (
            "enableDigestDeduplication" in fields,
            "",
            "enableDigestDeduplication: true,",
        ),
        (
            "ttlDurationSeconds" in fields,
            "$ttlDurationSeconds: Int64,",
            "ttlDurationSeconds: $ttlDurationSeconds,",
        ),
    ]
    extra_types = "".join(decl for on, decl, _ in optional_inputs if on)
    extra_values = "".join(arg for on, _, arg in optional_inputs if on)

    template = """
    mutation CreateArtifact(
        $artifactTypeName: String!,
        $artifactCollectionNames: [String!],
        $entityName: String!,
        $projectName: String!,
        $runName: String,
        $description: String,
        $digest: String!,
        $aliases: [ArtifactAliasInput!],
        $metadata: JSONString,
        _CREATE_ARTIFACT_ADDITIONAL_TYPE_
    ) {
        createArtifact(input: {
            artifactTypeName: $artifactTypeName,
            artifactCollectionNames: $artifactCollectionNames,
            entityName: $entityName,
            projectName: $projectName,
            runName: $runName,
            description: $description,
            digest: $digest,
            digestAlgorithm: MANIFEST_MD5,
            aliases: $aliases,
            metadata: $metadata,
            _CREATE_ARTIFACT_ADDITIONAL_VALUE_
        }) {
            artifact {
                id
                state
                artifactSequence {
                    id
                    latestArtifact {
                        id
                        versionIndex
                    }
                }
            }
        }
    }
    """

    with_types = template.replace("_CREATE_ARTIFACT_ADDITIONAL_TYPE_", extra_types)
    return with_types.replace("_CREATE_ARTIFACT_ADDITIONAL_VALUE_", extra_values)
3501
+
3502
def create_artifact(
    self,
    artifact_type_name: str,
    artifact_collection_name: str,
    digest: str,
    client_id: Optional[str] = None,
    sequence_client_id: Optional[str] = None,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    run_name: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[Dict] = None,
    ttl_duration_seconds: Optional[int] = None,
    aliases: Optional[List[Dict[str, str]]] = None,
    distributed_id: Optional[str] = None,
    is_user_created: Optional[bool] = False,
    history_step: Optional[int] = None,
) -> Tuple[Dict, Dict]:
    """Create a new artifact version on the server.

    The mutation is tailored to the server via the CreateArtifactInput and
    Artifact introspection probes. A requested TTL is dropped, with a
    warning, if the server's Artifact type lacks `ttlIsInherited`. Missing
    entity/project names fall back to settings. The run name falls back to
    `self.current_run_id` unless `is_user_created` is set.

    Returns:
        A tuple of the created artifact object and its sequence's
        `latestArtifact` (None if the response omits it).
    """
    input_fields = self.server_create_artifact_introspection()
    artifact_fields = self.server_artifact_introspection()
    if ttl_duration_seconds and "ttlIsInherited" not in artifact_fields:
        wandb.termwarn(
            "Server not compatible with setting Artifact TTLs, please upgrade the server to use Artifact TTL"
        )
        # ttlDurationSeconds is only usable if ttlIsInherited is also present
        ttl_duration_seconds = None
    mutation_text = self._get_create_artifact_mutation(
        input_fields, history_step, distributed_id
    )

    entity_name = entity_name or self.settings("entity")
    project_name = project_name or self.settings("project")
    if not is_user_created:
        run_name = run_name or self.current_run_id
    alias_list = [] if aliases is None else list(aliases)

    if metadata:
        metadata_json = json.dumps(util.make_safe_for_json(metadata))
    else:
        metadata_json = None

    reply = self.gql(
        gql(mutation_text),
        variable_values={
            "entityName": entity_name,
            "projectName": project_name,
            "runName": run_name,
            "artifactTypeName": artifact_type_name,
            "artifactCollectionNames": [artifact_collection_name],
            "clientID": client_id,
            "sequenceClientID": sequence_client_id,
            "digest": digest,
            "description": description,
            "aliases": alias_list,
            "metadata": metadata_json,
            "ttlDurationSeconds": ttl_duration_seconds,
            "distributedID": distributed_id,
            "historyStep": history_step,
        },
    )
    created = reply["createArtifact"]["artifact"]
    latest = created["artifactSequence"].get("latestArtifact")
    return created, latest
3566
+
3567
def commit_artifact(self, artifact_id: str) -> "_Response":
    """Commit the artifact with the given server id and return the raw response.

    The request uses a 60 second timeout.
    """
    commit_mutation = gql(
        """
    mutation CommitArtifact(
        $artifactID: ID!,
    ) {
        commitArtifact(input: {
            artifactID: $artifactID,
        }) {
            artifact {
                id
                digest
            }
        }
    }
    """
    )

    reply: _Response = self.gql(
        commit_mutation,
        variable_values={"artifactID": artifact_id},
        timeout=60,
    )
    return reply
3591
+
3592
def complete_multipart_upload_artifact(
    self,
    artifact_id: str,
    storage_path: str,
    completed_parts: List[Dict[str, Any]],
    upload_id: Optional[str],
    complete_multipart_action: str = "Complete",
) -> Optional[str]:
    """Finish (or otherwise act on) a multipart upload of an artifact file.

    `complete_multipart_action` is passed through as the
    CompleteMultipartAction enum value.

    Returns:
        The `digest` reported by the server.
    """
    complete_mutation = gql(
        """
    mutation CompleteMultipartUploadArtifact(
        $completeMultipartAction: CompleteMultipartAction!,
        $completedParts: [UploadPartsInput!]!,
        $artifactID: ID!
        $storagePath: String!
        $uploadID: String!
    ) {
        completeMultipartUploadArtifact(
        input: {
            completeMultipartAction: $completeMultipartAction,
            completedParts: $completedParts,
            artifactID: $artifactID,
            storagePath: $storagePath
            uploadID: $uploadID
        }
        ) {
            digest
        }
    }
    """
    )
    variables = {
        "completeMultipartAction": complete_multipart_action,
        "artifactID": artifact_id,
        "storagePath": storage_path,
        "completedParts": completed_parts,
        "uploadID": upload_id,
    }
    reply = self.gql(complete_mutation, variable_values=variables)
    result_digest: Optional[str] = reply["completeMultipartUploadArtifact"]["digest"]
    return result_digest
3635
+
3636
def create_artifact_manifest(
    self,
    name: str,
    digest: str,
    artifact_id: Optional[str],
    base_artifact_id: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    run: Optional[str] = None,
    include_upload: bool = True,
    type: str = "FULL",
) -> Tuple[str, Dict[str, Any]]:
    """Create an artifact manifest and return its id and file record.

    The `$type` variable and `type` input are only added to the mutation
    when `type` is not "FULL". Upload URL/headers are requested only when
    `include_upload` is true. Missing entity/project/run fall back to the
    settings and `self.current_run_id`.
    """
    non_full = type != "FULL"
    type_declaration = "$type: ArtifactManifestType = FULL" if non_full else ""
    type_argument = "type: $type" if non_full else ""
    template = """
    mutation CreateArtifactManifest(
        $name: String!,
        $digest: String!,
        $artifactID: ID!,
        $baseArtifactID: ID,
        $entityName: String!,
        $projectName: String!,
        $runName: String!,
        $includeUpload: Boolean!,
        _MANIFEST_TYPE_DECLARATION_
    ) {
        createArtifactManifest(input: {
            name: $name,
            digest: $digest,
            artifactID: $artifactID,
            baseArtifactID: $baseArtifactID,
            entityName: $entityName,
            projectName: $projectName,
            runName: $runName,
            _MANIFEST_TYPE_ARGUMENT_
        }) {
            artifactManifest {
                id
                file {
                    id
                    name
                    displayName
                    uploadUrl @include(if: $includeUpload)
                    uploadHeaders @include(if: $includeUpload)
                }
            }
        }
    }
    """
    manifest_mutation = gql(
        template.replace("_MANIFEST_TYPE_DECLARATION_", type_declaration).replace(
            "_MANIFEST_TYPE_ARGUMENT_", type_argument
        )
    )

    variables = {
        "name": name,
        "digest": digest,
        "artifactID": artifact_id,
        "baseArtifactID": base_artifact_id,
        "entityName": entity or self.settings("entity"),
        "projectName": project or self.settings("project"),
        "runName": run or self.current_run_id,
        "includeUpload": include_upload,
        "type": type,
    }
    reply = self.gql(manifest_mutation, variable_values=variables)
    manifest = reply["createArtifactManifest"]["artifactManifest"]
    return manifest["id"], manifest["file"]
3711
+
3712
def update_artifact_manifest(
    self,
    artifact_manifest_id: str,
    base_artifact_id: Optional[str] = None,
    digest: Optional[str] = None,
    include_upload: Optional[bool] = True,
) -> Tuple[str, Dict[str, Any]]:
    """Update an existing artifact manifest and return its id and file record.

    Upload URL/headers are requested only when `include_upload` is true.
    """
    update_mutation = gql(
        """
    mutation UpdateArtifactManifest(
        $artifactManifestID: ID!,
        $digest: String,
        $baseArtifactID: ID,
        $includeUpload: Boolean!,
    ) {
        updateArtifactManifest(input: {
            artifactManifestID: $artifactManifestID,
            digest: $digest,
            baseArtifactID: $baseArtifactID,
        }) {
            artifactManifest {
                id
                file {
                    id
                    name
                    displayName
                    uploadUrl @include(if: $includeUpload)
                    uploadHeaders @include(if: $includeUpload)
                }
            }
        }
    }
    """
    )

    variables = {
        "artifactManifestID": artifact_manifest_id,
        "digest": digest,
        "baseArtifactID": base_artifact_id,
        "includeUpload": include_upload,
    }
    reply = self.gql(update_mutation, variable_values=variables)

    manifest = reply["updateArtifactManifest"]["artifactManifest"]
    return manifest["id"], manifest["file"]
3761
+
3762
+ def _resolve_client_id(
3763
+ self,
3764
+ client_id: str,
3765
+ ) -> Optional[str]:
3766
+ if client_id in self._client_id_mapping:
3767
+ return self._client_id_mapping[client_id]
3768
+
3769
+ query = gql(
3770
+ """
3771
+ query ClientIDMapping($clientID: ID!) {
3772
+ clientIDMapping(clientID: $clientID) {
3773
+ serverID
3774
+ }
3775
+ }
3776
+ """
3777
+ )
3778
+ response = self.gql(
3779
+ query,
3780
+ variable_values={
3781
+ "clientID": client_id,
3782
+ },
3783
+ )
3784
+ server_id = None
3785
+ if response is not None:
3786
+ client_id_mapping = response.get("clientIDMapping")
3787
+ if client_id_mapping is not None:
3788
+ server_id = client_id_mapping.get("serverID")
3789
+ if server_id is not None:
3790
+ self._client_id_mapping[client_id] = server_id
3791
+ return server_id
3792
+
3793
+ def server_create_artifact_file_spec_input_introspection(self) -> List:
3794
+ query_string = """
3795
+ query ProbeServerCreateArtifactFileSpecInput {
3796
+ CreateArtifactFileSpecInputInfoType: __type(name:"CreateArtifactFileSpecInput") {
3797
+ inputFields{
3798
+ name
3799
+ }
3800
+ }
3801
+ }
3802
+ """
3803
+
3804
+ query = gql(query_string)
3805
+ res = self.gql(query)
3806
+ create_artifact_file_spec_input_info = [
3807
+ field.get("name", "")
3808
+ for field in res.get("CreateArtifactFileSpecInputInfoType", {}).get(
3809
+ "inputFields", [{}]
3810
+ )
3811
+ ]
3812
+ return create_artifact_file_spec_input_info
3813
+
3814
+ @normalize_exceptions
3815
+ def create_artifact_files(
3816
+ self, artifact_files: Iterable["CreateArtifactFileSpecInput"]
3817
+ ) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
3818
+ query_template = """
3819
+ mutation CreateArtifactFiles(
3820
+ $storageLayout: ArtifactStorageLayout!
3821
+ $artifactFiles: [CreateArtifactFileSpecInput!]!
3822
+ ) {
3823
+ createArtifactFiles(input: {
3824
+ artifactFiles: $artifactFiles,
3825
+ storageLayout: $storageLayout,
3826
+ }) {
3827
+ files {
3828
+ edges {
3829
+ node {
3830
+ id
3831
+ name
3832
+ displayName
3833
+ uploadUrl
3834
+ uploadHeaders
3835
+ _MULTIPART_UPLOAD_FIELDS_
3836
+ artifact {
3837
+ id
3838
+ }
3839
+ }
3840
+ }
3841
+ }
3842
+ }
3843
+ }
3844
+ """
3845
+ multipart_upload_url_query = """
3846
+ storagePath
3847
+ uploadMultipartUrls {
3848
+ uploadID
3849
+ uploadUrlParts {
3850
+ partNumber
3851
+ uploadUrl
3852
+ }
3853
+ }
3854
+ """
3855
+
3856
+ # TODO: we should use constants here from interface/artifacts.py
3857
+ # but probably don't want the dependency. We're going to remove
3858
+ # this setting in a future release, so I'm just hard-coding the strings.
3859
+ storage_layout = "V2"
3860
+ if env.get_use_v1_artifacts():
3861
+ storage_layout = "V1"
3862
+
3863
+ create_artifact_file_spec_input_fields = (
3864
+ self.server_create_artifact_file_spec_input_introspection()
3865
+ )
3866
+ if "uploadPartsInput" in create_artifact_file_spec_input_fields:
3867
+ query_template = query_template.replace(
3868
+ "_MULTIPART_UPLOAD_FIELDS_", multipart_upload_url_query
3869
+ )
3870
+ else:
3871
+ query_template = query_template.replace("_MULTIPART_UPLOAD_FIELDS_", "")
3872
+
3873
+ mutation = gql(query_template)
3874
+ response = self.gql(
3875
+ mutation,
3876
+ variable_values={
3877
+ "storageLayout": storage_layout,
3878
+ "artifactFiles": [af for af in artifact_files],
3879
+ },
3880
+ )
3881
+
3882
+ result = {}
3883
+ for edge in response["createArtifactFiles"]["files"]["edges"]:
3884
+ node = edge["node"]
3885
+ result[node["displayName"]] = node
3886
+ return result
3887
+
3888
+ @normalize_exceptions
3889
+ def notify_scriptable_run_alert(
3890
+ self,
3891
+ title: str,
3892
+ text: str,
3893
+ level: Optional[str] = None,
3894
+ wait_duration: Optional["Number"] = None,
3895
+ ) -> bool:
3896
+ mutation = gql(
3897
+ """
3898
+ mutation NotifyScriptableRunAlert(
3899
+ $entityName: String!,
3900
+ $projectName: String!,
3901
+ $runName: String!,
3902
+ $title: String!,
3903
+ $text: String!,
3904
+ $severity: AlertSeverity = INFO,
3905
+ $waitDuration: Duration
3906
+ ) {
3907
+ notifyScriptableRunAlert(input: {
3908
+ entityName: $entityName,
3909
+ projectName: $projectName,
3910
+ runName: $runName,
3911
+ title: $title,
3912
+ text: $text,
3913
+ severity: $severity,
3914
+ waitDuration: $waitDuration
3915
+ }) {
3916
+ success
3917
+ }
3918
+ }
3919
+ """
3920
+ )
3921
+
3922
+ response = self.gql(
3923
+ mutation,
3924
+ variable_values={
3925
+ "entityName": self.settings("entity"),
3926
+ "projectName": self.settings("project"),
3927
+ "runName": self.current_run_id,
3928
+ "title": title,
3929
+ "text": text,
3930
+ "severity": level,
3931
+ "waitDuration": wait_duration,
3932
+ },
3933
+ )
3934
+ success: bool = response["notifyScriptableRunAlert"]["success"]
3935
+ return success
3936
+
3937
+ def get_sweep_state(
3938
+ self, sweep: str, entity: Optional[str] = None, project: Optional[str] = None
3939
+ ) -> "SweepState":
3940
+ state: SweepState = self.sweep(
3941
+ sweep=sweep, entity=entity, project=project, specs="{}"
3942
+ )["state"]
3943
+ return state
3944
+
3945
+ def set_sweep_state(
3946
+ self,
3947
+ sweep: str,
3948
+ state: "SweepState",
3949
+ entity: Optional[str] = None,
3950
+ project: Optional[str] = None,
3951
+ ) -> None:
3952
+ assert state in ("RUNNING", "PAUSED", "CANCELED", "FINISHED")
3953
+ s = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
3954
+ curr_state = s["state"].upper()
3955
+ if state == "PAUSED" and curr_state not in ("PAUSED", "RUNNING"):
3956
+ raise Exception("Cannot pause %s sweep." % curr_state.lower())
3957
+ elif state != "RUNNING" and curr_state not in ("RUNNING", "PAUSED", "PENDING"):
3958
+ raise Exception("Sweep already %s." % curr_state.lower())
3959
+ sweep_id = s["id"]
3960
+ mutation = gql(
3961
+ """
3962
+ mutation UpsertSweep(
3963
+ $id: ID,
3964
+ $state: String,
3965
+ $entityName: String,
3966
+ $projectName: String
3967
+ ) {
3968
+ upsertSweep(input: {
3969
+ id: $id,
3970
+ state: $state,
3971
+ entityName: $entityName,
3972
+ projectName: $projectName
3973
+ }){
3974
+ sweep {
3975
+ name
3976
+ }
3977
+ }
3978
+ }
3979
+ """
3980
+ )
3981
+ self.gql(
3982
+ mutation,
3983
+ variable_values={
3984
+ "id": sweep_id,
3985
+ "state": state,
3986
+ "entityName": entity or self.settings("entity"),
3987
+ "projectName": project or self.settings("project"),
3988
+ },
3989
+ )
3990
+
3991
+ def stop_sweep(
3992
+ self,
3993
+ sweep: str,
3994
+ entity: Optional[str] = None,
3995
+ project: Optional[str] = None,
3996
+ ) -> None:
3997
+ """Finish the sweep to stop running new runs and let currently running runs finish."""
3998
+ self.set_sweep_state(
3999
+ sweep=sweep, state="FINISHED", entity=entity, project=project
4000
+ )
4001
+
4002
+ def cancel_sweep(
4003
+ self,
4004
+ sweep: str,
4005
+ entity: Optional[str] = None,
4006
+ project: Optional[str] = None,
4007
+ ) -> None:
4008
+ """Cancel the sweep to kill all running runs and stop running new runs."""
4009
+ self.set_sweep_state(
4010
+ sweep=sweep, state="CANCELED", entity=entity, project=project
4011
+ )
4012
+
4013
+ def pause_sweep(
4014
+ self,
4015
+ sweep: str,
4016
+ entity: Optional[str] = None,
4017
+ project: Optional[str] = None,
4018
+ ) -> None:
4019
+ """Pause the sweep to temporarily stop running new runs."""
4020
+ self.set_sweep_state(
4021
+ sweep=sweep, state="PAUSED", entity=entity, project=project
4022
+ )
4023
+
4024
+ def resume_sweep(
4025
+ self,
4026
+ sweep: str,
4027
+ entity: Optional[str] = None,
4028
+ project: Optional[str] = None,
4029
+ ) -> None:
4030
+ """Resume the sweep to continue running new runs."""
4031
+ self.set_sweep_state(
4032
+ sweep=sweep, state="RUNNING", entity=entity, project=project
4033
+ )
4034
+
4035
+ def _status_request(self, url: str, length: int) -> requests.Response:
4036
+ """Ask google how much we've uploaded."""
4037
+ check_httpclient_logger_handler()
4038
+ return requests.put(
4039
+ url=url,
4040
+ headers={"Content-Length": "0", "Content-Range": "bytes */%i" % length},
4041
+ )
4042
+
4043
+ def _flatten_edges(self, response: "_Response") -> List[Dict]:
4044
+ """Return an array from the nested graphql relay structure."""
4045
+ return [node["node"] for node in response["edges"]]
4046
+
4047
+ @normalize_exceptions
4048
+ def stop_run(
4049
+ self,
4050
+ run_id: str,
4051
+ ) -> bool:
4052
+ mutation = gql(
4053
+ """
4054
+ mutation stopRun($id: ID!) {
4055
+ stopRun(input: {
4056
+ id: $id
4057
+ }) {
4058
+ clientMutationId
4059
+ success
4060
+ }
4061
+ }
4062
+ """
4063
+ )
4064
+
4065
+ response = self.gql(
4066
+ mutation,
4067
+ variable_values={
4068
+ "id": run_id,
4069
+ },
4070
+ )
4071
+
4072
+ success: bool = response["stopRun"].get("success")
4073
+
4074
+ return success