wandb 0.18.2__py3-none-musllinux_1_2_x86_64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (827)
  1. package_readme.md +89 -0
  2. wandb/__init__.py +245 -0
  3. wandb/__init__.pyi +1139 -0
  4. wandb/__main__.py +3 -0
  5. wandb/_globals.py +19 -0
  6. wandb/agents/__init__.py +0 -0
  7. wandb/agents/pyagent.py +363 -0
  8. wandb/analytics/__init__.py +3 -0
  9. wandb/analytics/sentry.py +266 -0
  10. wandb/apis/__init__.py +48 -0
  11. wandb/apis/attrs.py +40 -0
  12. wandb/apis/importers/__init__.py +1 -0
  13. wandb/apis/importers/internals/internal.py +385 -0
  14. wandb/apis/importers/internals/protocols.py +99 -0
  15. wandb/apis/importers/internals/util.py +78 -0
  16. wandb/apis/importers/mlflow.py +254 -0
  17. wandb/apis/importers/validation.py +108 -0
  18. wandb/apis/importers/wandb.py +1603 -0
  19. wandb/apis/internal.py +232 -0
  20. wandb/apis/normalize.py +89 -0
  21. wandb/apis/paginator.py +81 -0
  22. wandb/apis/public/__init__.py +34 -0
  23. wandb/apis/public/api.py +1305 -0
  24. wandb/apis/public/artifacts.py +1090 -0
  25. wandb/apis/public/const.py +4 -0
  26. wandb/apis/public/files.py +195 -0
  27. wandb/apis/public/history.py +149 -0
  28. wandb/apis/public/jobs.py +659 -0
  29. wandb/apis/public/projects.py +154 -0
  30. wandb/apis/public/query_generator.py +166 -0
  31. wandb/apis/public/reports.py +469 -0
  32. wandb/apis/public/runs.py +914 -0
  33. wandb/apis/public/sweeps.py +240 -0
  34. wandb/apis/public/teams.py +198 -0
  35. wandb/apis/public/users.py +136 -0
  36. wandb/apis/reports/__init__.py +1 -0
  37. wandb/apis/reports/v1/__init__.py +8 -0
  38. wandb/apis/reports/v2/__init__.py +8 -0
  39. wandb/apis/workspaces/__init__.py +8 -0
  40. wandb/beta/workflows.py +288 -0
  41. wandb/bin/nvidia_gpu_stats +0 -0
  42. wandb/bin/wandb-core +0 -0
  43. wandb/cli/__init__.py +0 -0
  44. wandb/cli/cli.py +3004 -0
  45. wandb/data_types.py +63 -0
  46. wandb/docker/__init__.py +342 -0
  47. wandb/docker/auth.py +436 -0
  48. wandb/docker/wandb-entrypoint.sh +33 -0
  49. wandb/docker/www_authenticate.py +94 -0
  50. wandb/env.py +514 -0
  51. wandb/errors/__init__.py +17 -0
  52. wandb/errors/errors.py +37 -0
  53. wandb/errors/term.py +103 -0
  54. wandb/errors/util.py +57 -0
  55. wandb/errors/warnings.py +2 -0
  56. wandb/filesync/__init__.py +0 -0
  57. wandb/filesync/dir_watcher.py +403 -0
  58. wandb/filesync/stats.py +100 -0
  59. wandb/filesync/step_checksum.py +142 -0
  60. wandb/filesync/step_prepare.py +179 -0
  61. wandb/filesync/step_upload.py +290 -0
  62. wandb/filesync/upload_job.py +142 -0
  63. wandb/integration/__init__.py +0 -0
  64. wandb/integration/catboost/__init__.py +5 -0
  65. wandb/integration/catboost/catboost.py +178 -0
  66. wandb/integration/cohere/__init__.py +3 -0
  67. wandb/integration/cohere/cohere.py +21 -0
  68. wandb/integration/cohere/resolver.py +347 -0
  69. wandb/integration/diffusers/__init__.py +3 -0
  70. wandb/integration/diffusers/autologger.py +76 -0
  71. wandb/integration/diffusers/pipeline_resolver.py +50 -0
  72. wandb/integration/diffusers/resolvers/__init__.py +9 -0
  73. wandb/integration/diffusers/resolvers/multimodal.py +882 -0
  74. wandb/integration/diffusers/resolvers/utils.py +102 -0
  75. wandb/integration/fastai/__init__.py +249 -0
  76. wandb/integration/gym/__init__.py +105 -0
  77. wandb/integration/huggingface/__init__.py +3 -0
  78. wandb/integration/huggingface/huggingface.py +18 -0
  79. wandb/integration/huggingface/resolver.py +213 -0
  80. wandb/integration/keras/__init__.py +11 -0
  81. wandb/integration/keras/callbacks/__init__.py +5 -0
  82. wandb/integration/keras/callbacks/metrics_logger.py +136 -0
  83. wandb/integration/keras/callbacks/model_checkpoint.py +195 -0
  84. wandb/integration/keras/callbacks/tables_builder.py +226 -0
  85. wandb/integration/keras/keras.py +1091 -0
  86. wandb/integration/kfp/__init__.py +6 -0
  87. wandb/integration/kfp/helpers.py +28 -0
  88. wandb/integration/kfp/kfp_patch.py +324 -0
  89. wandb/integration/kfp/wandb_logging.py +182 -0
  90. wandb/integration/langchain/__init__.py +3 -0
  91. wandb/integration/langchain/wandb_tracer.py +48 -0
  92. wandb/integration/lightgbm/__init__.py +239 -0
  93. wandb/integration/lightning/__init__.py +0 -0
  94. wandb/integration/lightning/fabric/__init__.py +3 -0
  95. wandb/integration/lightning/fabric/logger.py +762 -0
  96. wandb/integration/magic.py +556 -0
  97. wandb/integration/metaflow/__init__.py +3 -0
  98. wandb/integration/metaflow/metaflow.py +383 -0
  99. wandb/integration/openai/__init__.py +3 -0
  100. wandb/integration/openai/fine_tuning.py +480 -0
  101. wandb/integration/openai/openai.py +22 -0
  102. wandb/integration/openai/resolver.py +240 -0
  103. wandb/integration/prodigy/__init__.py +3 -0
  104. wandb/integration/prodigy/prodigy.py +299 -0
  105. wandb/integration/sacred/__init__.py +117 -0
  106. wandb/integration/sagemaker/__init__.py +12 -0
  107. wandb/integration/sagemaker/auth.py +28 -0
  108. wandb/integration/sagemaker/config.py +49 -0
  109. wandb/integration/sagemaker/files.py +3 -0
  110. wandb/integration/sagemaker/resources.py +34 -0
  111. wandb/integration/sb3/__init__.py +3 -0
  112. wandb/integration/sb3/sb3.py +153 -0
  113. wandb/integration/sklearn/__init__.py +37 -0
  114. wandb/integration/sklearn/calculate/__init__.py +32 -0
  115. wandb/integration/sklearn/calculate/calibration_curves.py +125 -0
  116. wandb/integration/sklearn/calculate/class_proportions.py +68 -0
  117. wandb/integration/sklearn/calculate/confusion_matrix.py +93 -0
  118. wandb/integration/sklearn/calculate/decision_boundaries.py +40 -0
  119. wandb/integration/sklearn/calculate/elbow_curve.py +55 -0
  120. wandb/integration/sklearn/calculate/feature_importances.py +67 -0
  121. wandb/integration/sklearn/calculate/learning_curve.py +64 -0
  122. wandb/integration/sklearn/calculate/outlier_candidates.py +69 -0
  123. wandb/integration/sklearn/calculate/residuals.py +86 -0
  124. wandb/integration/sklearn/calculate/silhouette.py +118 -0
  125. wandb/integration/sklearn/calculate/summary_metrics.py +62 -0
  126. wandb/integration/sklearn/plot/__init__.py +35 -0
  127. wandb/integration/sklearn/plot/classifier.py +329 -0
  128. wandb/integration/sklearn/plot/clusterer.py +146 -0
  129. wandb/integration/sklearn/plot/regressor.py +121 -0
  130. wandb/integration/sklearn/plot/shared.py +91 -0
  131. wandb/integration/sklearn/utils.py +183 -0
  132. wandb/integration/tensorboard/__init__.py +10 -0
  133. wandb/integration/tensorboard/log.py +355 -0
  134. wandb/integration/tensorboard/monkeypatch.py +185 -0
  135. wandb/integration/tensorflow/__init__.py +5 -0
  136. wandb/integration/tensorflow/estimator_hook.py +54 -0
  137. wandb/integration/torch/__init__.py +0 -0
  138. wandb/integration/torch/wandb_torch.py +554 -0
  139. wandb/integration/ultralytics/__init__.py +11 -0
  140. wandb/integration/ultralytics/bbox_utils.py +208 -0
  141. wandb/integration/ultralytics/callback.py +524 -0
  142. wandb/integration/ultralytics/classification_utils.py +83 -0
  143. wandb/integration/ultralytics/mask_utils.py +202 -0
  144. wandb/integration/ultralytics/pose_utils.py +103 -0
  145. wandb/integration/xgboost/__init__.py +11 -0
  146. wandb/integration/xgboost/xgboost.py +189 -0
  147. wandb/integration/yolov8/__init__.py +0 -0
  148. wandb/integration/yolov8/yolov8.py +284 -0
  149. wandb/jupyter.py +515 -0
  150. wandb/magic.py +3 -0
  151. wandb/mpmain/__init__.py +0 -0
  152. wandb/mpmain/__main__.py +1 -0
  153. wandb/old/__init__.py +0 -0
  154. wandb/old/core.py +53 -0
  155. wandb/old/settings.py +173 -0
  156. wandb/old/summary.py +440 -0
  157. wandb/plot/__init__.py +19 -0
  158. wandb/plot/bar.py +45 -0
  159. wandb/plot/confusion_matrix.py +100 -0
  160. wandb/plot/histogram.py +39 -0
  161. wandb/plot/line.py +43 -0
  162. wandb/plot/line_series.py +88 -0
  163. wandb/plot/pr_curve.py +136 -0
  164. wandb/plot/roc_curve.py +118 -0
  165. wandb/plot/scatter.py +32 -0
  166. wandb/plot/utils.py +183 -0
  167. wandb/plot/viz.py +123 -0
  168. wandb/proto/__init__.py +0 -0
  169. wandb/proto/v3/__init__.py +0 -0
  170. wandb/proto/v3/wandb_base_pb2.py +55 -0
  171. wandb/proto/v3/wandb_internal_pb2.py +1608 -0
  172. wandb/proto/v3/wandb_server_pb2.py +208 -0
  173. wandb/proto/v3/wandb_settings_pb2.py +112 -0
  174. wandb/proto/v3/wandb_telemetry_pb2.py +106 -0
  175. wandb/proto/v4/__init__.py +0 -0
  176. wandb/proto/v4/wandb_base_pb2.py +30 -0
  177. wandb/proto/v4/wandb_internal_pb2.py +360 -0
  178. wandb/proto/v4/wandb_server_pb2.py +63 -0
  179. wandb/proto/v4/wandb_settings_pb2.py +45 -0
  180. wandb/proto/v4/wandb_telemetry_pb2.py +41 -0
  181. wandb/proto/v5/wandb_base_pb2.py +31 -0
  182. wandb/proto/v5/wandb_internal_pb2.py +361 -0
  183. wandb/proto/v5/wandb_server_pb2.py +64 -0
  184. wandb/proto/v5/wandb_settings_pb2.py +46 -0
  185. wandb/proto/v5/wandb_telemetry_pb2.py +42 -0
  186. wandb/proto/wandb_base_pb2.py +10 -0
  187. wandb/proto/wandb_deprecated.py +53 -0
  188. wandb/proto/wandb_generate_deprecated.py +34 -0
  189. wandb/proto/wandb_generate_proto.py +49 -0
  190. wandb/proto/wandb_internal_pb2.py +16 -0
  191. wandb/proto/wandb_server_pb2.py +10 -0
  192. wandb/proto/wandb_settings_pb2.py +10 -0
  193. wandb/proto/wandb_telemetry_pb2.py +10 -0
  194. wandb/py.typed +0 -0
  195. wandb/sdk/__init__.py +37 -0
  196. wandb/sdk/artifacts/__init__.py +0 -0
  197. wandb/sdk/artifacts/_validators.py +90 -0
  198. wandb/sdk/artifacts/artifact.py +2389 -0
  199. wandb/sdk/artifacts/artifact_download_logger.py +43 -0
  200. wandb/sdk/artifacts/artifact_file_cache.py +253 -0
  201. wandb/sdk/artifacts/artifact_instance_cache.py +17 -0
  202. wandb/sdk/artifacts/artifact_manifest.py +74 -0
  203. wandb/sdk/artifacts/artifact_manifest_entry.py +249 -0
  204. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  205. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +92 -0
  206. wandb/sdk/artifacts/artifact_saver.py +269 -0
  207. wandb/sdk/artifacts/artifact_state.py +11 -0
  208. wandb/sdk/artifacts/artifact_ttl.py +7 -0
  209. wandb/sdk/artifacts/exceptions.py +57 -0
  210. wandb/sdk/artifacts/staging.py +25 -0
  211. wandb/sdk/artifacts/storage_handler.py +62 -0
  212. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  213. wandb/sdk/artifacts/storage_handlers/azure_handler.py +208 -0
  214. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +228 -0
  215. wandb/sdk/artifacts/storage_handlers/http_handler.py +114 -0
  216. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +141 -0
  217. wandb/sdk/artifacts/storage_handlers/multi_handler.py +56 -0
  218. wandb/sdk/artifacts/storage_handlers/s3_handler.py +300 -0
  219. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +72 -0
  220. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +135 -0
  221. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +74 -0
  222. wandb/sdk/artifacts/storage_layout.py +6 -0
  223. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  224. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  225. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +378 -0
  226. wandb/sdk/artifacts/storage_policy.py +72 -0
  227. wandb/sdk/backend/__init__.py +0 -0
  228. wandb/sdk/backend/backend.py +222 -0
  229. wandb/sdk/data_types/__init__.py +0 -0
  230. wandb/sdk/data_types/_dtypes.py +914 -0
  231. wandb/sdk/data_types/_private.py +10 -0
  232. wandb/sdk/data_types/audio.py +165 -0
  233. wandb/sdk/data_types/base_types/__init__.py +0 -0
  234. wandb/sdk/data_types/base_types/json_metadata.py +55 -0
  235. wandb/sdk/data_types/base_types/media.py +315 -0
  236. wandb/sdk/data_types/base_types/wb_value.py +272 -0
  237. wandb/sdk/data_types/bokeh.py +70 -0
  238. wandb/sdk/data_types/graph.py +405 -0
  239. wandb/sdk/data_types/helper_types/__init__.py +0 -0
  240. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +295 -0
  241. wandb/sdk/data_types/helper_types/classes.py +159 -0
  242. wandb/sdk/data_types/helper_types/image_mask.py +235 -0
  243. wandb/sdk/data_types/histogram.py +96 -0
  244. wandb/sdk/data_types/html.py +115 -0
  245. wandb/sdk/data_types/image.py +845 -0
  246. wandb/sdk/data_types/molecule.py +241 -0
  247. wandb/sdk/data_types/object_3d.py +474 -0
  248. wandb/sdk/data_types/plotly.py +82 -0
  249. wandb/sdk/data_types/saved_model.py +446 -0
  250. wandb/sdk/data_types/table.py +1204 -0
  251. wandb/sdk/data_types/trace_tree.py +438 -0
  252. wandb/sdk/data_types/utils.py +229 -0
  253. wandb/sdk/data_types/video.py +247 -0
  254. wandb/sdk/integration_utils/__init__.py +0 -0
  255. wandb/sdk/integration_utils/auto_logging.py +239 -0
  256. wandb/sdk/integration_utils/data_logging.py +475 -0
  257. wandb/sdk/interface/__init__.py +0 -0
  258. wandb/sdk/interface/constants.py +4 -0
  259. wandb/sdk/interface/interface.py +972 -0
  260. wandb/sdk/interface/interface_queue.py +59 -0
  261. wandb/sdk/interface/interface_relay.py +53 -0
  262. wandb/sdk/interface/interface_shared.py +537 -0
  263. wandb/sdk/interface/interface_sock.py +61 -0
  264. wandb/sdk/interface/message_future.py +27 -0
  265. wandb/sdk/interface/message_future_poll.py +50 -0
  266. wandb/sdk/interface/router.py +118 -0
  267. wandb/sdk/interface/router_queue.py +44 -0
  268. wandb/sdk/interface/router_relay.py +39 -0
  269. wandb/sdk/interface/router_sock.py +36 -0
  270. wandb/sdk/interface/summary_record.py +67 -0
  271. wandb/sdk/internal/__init__.py +0 -0
  272. wandb/sdk/internal/context.py +89 -0
  273. wandb/sdk/internal/datastore.py +297 -0
  274. wandb/sdk/internal/file_pusher.py +181 -0
  275. wandb/sdk/internal/file_stream.py +695 -0
  276. wandb/sdk/internal/flow_control.py +263 -0
  277. wandb/sdk/internal/handler.py +901 -0
  278. wandb/sdk/internal/internal.py +417 -0
  279. wandb/sdk/internal/internal_api.py +4358 -0
  280. wandb/sdk/internal/internal_util.py +100 -0
  281. wandb/sdk/internal/job_builder.py +629 -0
  282. wandb/sdk/internal/profiler.py +78 -0
  283. wandb/sdk/internal/progress.py +83 -0
  284. wandb/sdk/internal/run.py +25 -0
  285. wandb/sdk/internal/sample.py +70 -0
  286. wandb/sdk/internal/sender.py +1686 -0
  287. wandb/sdk/internal/sender_config.py +197 -0
  288. wandb/sdk/internal/settings_static.py +90 -0
  289. wandb/sdk/internal/system/__init__.py +0 -0
  290. wandb/sdk/internal/system/assets/__init__.py +27 -0
  291. wandb/sdk/internal/system/assets/aggregators.py +37 -0
  292. wandb/sdk/internal/system/assets/asset_registry.py +20 -0
  293. wandb/sdk/internal/system/assets/cpu.py +163 -0
  294. wandb/sdk/internal/system/assets/disk.py +210 -0
  295. wandb/sdk/internal/system/assets/gpu.py +416 -0
  296. wandb/sdk/internal/system/assets/gpu_amd.py +239 -0
  297. wandb/sdk/internal/system/assets/gpu_apple.py +177 -0
  298. wandb/sdk/internal/system/assets/interfaces.py +207 -0
  299. wandb/sdk/internal/system/assets/ipu.py +177 -0
  300. wandb/sdk/internal/system/assets/memory.py +166 -0
  301. wandb/sdk/internal/system/assets/network.py +125 -0
  302. wandb/sdk/internal/system/assets/open_metrics.py +299 -0
  303. wandb/sdk/internal/system/assets/tpu.py +154 -0
  304. wandb/sdk/internal/system/assets/trainium.py +399 -0
  305. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  306. wandb/sdk/internal/system/system_info.py +249 -0
  307. wandb/sdk/internal/system/system_monitor.py +229 -0
  308. wandb/sdk/internal/tb_watcher.py +518 -0
  309. wandb/sdk/internal/thread_local_settings.py +18 -0
  310. wandb/sdk/internal/writer.py +206 -0
  311. wandb/sdk/launch/__init__.py +14 -0
  312. wandb/sdk/launch/_launch.py +330 -0
  313. wandb/sdk/launch/_launch_add.py +255 -0
  314. wandb/sdk/launch/_project_spec.py +566 -0
  315. wandb/sdk/launch/agent/__init__.py +5 -0
  316. wandb/sdk/launch/agent/agent.py +924 -0
  317. wandb/sdk/launch/agent/config.py +296 -0
  318. wandb/sdk/launch/agent/job_status_tracker.py +53 -0
  319. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  320. wandb/sdk/launch/builder/__init__.py +0 -0
  321. wandb/sdk/launch/builder/abstract.py +156 -0
  322. wandb/sdk/launch/builder/build.py +297 -0
  323. wandb/sdk/launch/builder/context_manager.py +235 -0
  324. wandb/sdk/launch/builder/docker_builder.py +177 -0
  325. wandb/sdk/launch/builder/kaniko_builder.py +595 -0
  326. wandb/sdk/launch/builder/noop.py +58 -0
  327. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +188 -0
  328. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  329. wandb/sdk/launch/create_job.py +528 -0
  330. wandb/sdk/launch/environment/abstract.py +29 -0
  331. wandb/sdk/launch/environment/aws_environment.py +322 -0
  332. wandb/sdk/launch/environment/azure_environment.py +105 -0
  333. wandb/sdk/launch/environment/gcp_environment.py +335 -0
  334. wandb/sdk/launch/environment/local_environment.py +66 -0
  335. wandb/sdk/launch/errors.py +19 -0
  336. wandb/sdk/launch/git_reference.py +109 -0
  337. wandb/sdk/launch/inputs/files.py +148 -0
  338. wandb/sdk/launch/inputs/internal.py +315 -0
  339. wandb/sdk/launch/inputs/manage.py +113 -0
  340. wandb/sdk/launch/inputs/schema.py +39 -0
  341. wandb/sdk/launch/loader.py +249 -0
  342. wandb/sdk/launch/registry/abstract.py +48 -0
  343. wandb/sdk/launch/registry/anon.py +29 -0
  344. wandb/sdk/launch/registry/azure_container_registry.py +124 -0
  345. wandb/sdk/launch/registry/elastic_container_registry.py +192 -0
  346. wandb/sdk/launch/registry/google_artifact_registry.py +219 -0
  347. wandb/sdk/launch/registry/local_registry.py +67 -0
  348. wandb/sdk/launch/runner/__init__.py +0 -0
  349. wandb/sdk/launch/runner/abstract.py +195 -0
  350. wandb/sdk/launch/runner/kubernetes_monitor.py +474 -0
  351. wandb/sdk/launch/runner/kubernetes_runner.py +963 -0
  352. wandb/sdk/launch/runner/local_container.py +301 -0
  353. wandb/sdk/launch/runner/local_process.py +78 -0
  354. wandb/sdk/launch/runner/sagemaker_runner.py +426 -0
  355. wandb/sdk/launch/runner/vertex_runner.py +230 -0
  356. wandb/sdk/launch/sweeps/__init__.py +39 -0
  357. wandb/sdk/launch/sweeps/scheduler.py +742 -0
  358. wandb/sdk/launch/sweeps/scheduler_sweep.py +91 -0
  359. wandb/sdk/launch/sweeps/utils.py +316 -0
  360. wandb/sdk/launch/utils.py +746 -0
  361. wandb/sdk/launch/wandb_reference.py +138 -0
  362. wandb/sdk/lib/__init__.py +5 -0
  363. wandb/sdk/lib/_settings_toposort_generate.py +159 -0
  364. wandb/sdk/lib/_settings_toposort_generated.py +250 -0
  365. wandb/sdk/lib/_wburls_generate.py +25 -0
  366. wandb/sdk/lib/_wburls_generated.py +22 -0
  367. wandb/sdk/lib/apikey.py +273 -0
  368. wandb/sdk/lib/capped_dict.py +26 -0
  369. wandb/sdk/lib/config_util.py +101 -0
  370. wandb/sdk/lib/credentials.py +141 -0
  371. wandb/sdk/lib/deprecate.py +42 -0
  372. wandb/sdk/lib/disabled.py +29 -0
  373. wandb/sdk/lib/exit_hooks.py +54 -0
  374. wandb/sdk/lib/file_stream_utils.py +118 -0
  375. wandb/sdk/lib/filenames.py +64 -0
  376. wandb/sdk/lib/filesystem.py +372 -0
  377. wandb/sdk/lib/fsm.py +174 -0
  378. wandb/sdk/lib/gitlib.py +239 -0
  379. wandb/sdk/lib/gql_request.py +65 -0
  380. wandb/sdk/lib/handler_util.py +21 -0
  381. wandb/sdk/lib/hashutil.py +84 -0
  382. wandb/sdk/lib/import_hooks.py +275 -0
  383. wandb/sdk/lib/ipython.py +146 -0
  384. wandb/sdk/lib/json_util.py +80 -0
  385. wandb/sdk/lib/lazyloader.py +63 -0
  386. wandb/sdk/lib/mailbox.py +460 -0
  387. wandb/sdk/lib/module.py +69 -0
  388. wandb/sdk/lib/paths.py +106 -0
  389. wandb/sdk/lib/preinit.py +42 -0
  390. wandb/sdk/lib/printer.py +313 -0
  391. wandb/sdk/lib/proto_util.py +90 -0
  392. wandb/sdk/lib/redirect.py +845 -0
  393. wandb/sdk/lib/reporting.py +99 -0
  394. wandb/sdk/lib/retry.py +289 -0
  395. wandb/sdk/lib/run_moment.py +78 -0
  396. wandb/sdk/lib/runid.py +12 -0
  397. wandb/sdk/lib/server.py +52 -0
  398. wandb/sdk/lib/service_connection.py +216 -0
  399. wandb/sdk/lib/service_token.py +94 -0
  400. wandb/sdk/lib/sock_client.py +295 -0
  401. wandb/sdk/lib/sparkline.py +45 -0
  402. wandb/sdk/lib/telemetry.py +100 -0
  403. wandb/sdk/lib/timed_input.py +133 -0
  404. wandb/sdk/lib/timer.py +19 -0
  405. wandb/sdk/lib/tracelog.py +255 -0
  406. wandb/sdk/lib/wburls.py +46 -0
  407. wandb/sdk/service/__init__.py +0 -0
  408. wandb/sdk/service/_startup_debug.py +22 -0
  409. wandb/sdk/service/port_file.py +53 -0
  410. wandb/sdk/service/server.py +116 -0
  411. wandb/sdk/service/server_sock.py +276 -0
  412. wandb/sdk/service/service.py +242 -0
  413. wandb/sdk/service/streams.py +417 -0
  414. wandb/sdk/verify/__init__.py +0 -0
  415. wandb/sdk/verify/verify.py +501 -0
  416. wandb/sdk/wandb_alerts.py +12 -0
  417. wandb/sdk/wandb_config.py +322 -0
  418. wandb/sdk/wandb_helper.py +54 -0
  419. wandb/sdk/wandb_init.py +1266 -0
  420. wandb/sdk/wandb_login.py +349 -0
  421. wandb/sdk/wandb_metric.py +110 -0
  422. wandb/sdk/wandb_require.py +97 -0
  423. wandb/sdk/wandb_require_helpers.py +44 -0
  424. wandb/sdk/wandb_run.py +4236 -0
  425. wandb/sdk/wandb_settings.py +2001 -0
  426. wandb/sdk/wandb_setup.py +409 -0
  427. wandb/sdk/wandb_summary.py +150 -0
  428. wandb/sdk/wandb_sweep.py +119 -0
  429. wandb/sdk/wandb_sync.py +81 -0
  430. wandb/sdk/wandb_watch.py +144 -0
  431. wandb/sklearn.py +35 -0
  432. wandb/sync/__init__.py +3 -0
  433. wandb/sync/sync.py +443 -0
  434. wandb/trigger.py +29 -0
  435. wandb/util.py +1956 -0
  436. wandb/vendor/__init__.py +0 -0
  437. wandb/vendor/gql-0.2.0/setup.py +40 -0
  438. wandb/vendor/gql-0.2.0/tests/__init__.py +0 -0
  439. wandb/vendor/gql-0.2.0/tests/starwars/__init__.py +0 -0
  440. wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py +96 -0
  441. wandb/vendor/gql-0.2.0/tests/starwars/schema.py +146 -0
  442. wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py +293 -0
  443. wandb/vendor/gql-0.2.0/tests/starwars/test_query.py +355 -0
  444. wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py +171 -0
  445. wandb/vendor/gql-0.2.0/tests/test_client.py +31 -0
  446. wandb/vendor/gql-0.2.0/tests/test_transport.py +89 -0
  447. wandb/vendor/gql-0.2.0/wandb_gql/__init__.py +4 -0
  448. wandb/vendor/gql-0.2.0/wandb_gql/client.py +75 -0
  449. wandb/vendor/gql-0.2.0/wandb_gql/dsl.py +152 -0
  450. wandb/vendor/gql-0.2.0/wandb_gql/gql.py +10 -0
  451. wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py +0 -0
  452. wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py +6 -0
  453. wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py +15 -0
  454. wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py +46 -0
  455. wandb/vendor/gql-0.2.0/wandb_gql/utils.py +21 -0
  456. wandb/vendor/graphql-core-1.1/setup.py +86 -0
  457. wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py +287 -0
  458. wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py +6 -0
  459. wandb/vendor/graphql-core-1.1/wandb_graphql/error/base.py +42 -0
  460. wandb/vendor/graphql-core-1.1/wandb_graphql/error/format_error.py +11 -0
  461. wandb/vendor/graphql-core-1.1/wandb_graphql/error/located_error.py +29 -0
  462. wandb/vendor/graphql-core-1.1/wandb_graphql/error/syntax_error.py +36 -0
  463. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__init__.py +26 -0
  464. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py +311 -0
  465. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executor.py +398 -0
  466. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py +0 -0
  467. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py +53 -0
  468. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py +22 -0
  469. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py +32 -0
  470. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py +7 -0
  471. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py +35 -0
  472. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py +6 -0
  473. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py +0 -0
  474. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py +66 -0
  475. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py +252 -0
  476. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py +151 -0
  477. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py +7 -0
  478. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py +57 -0
  479. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py +145 -0
  480. wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py +60 -0
  481. wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py +0 -0
  482. wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py +1349 -0
  483. wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py +19 -0
  484. wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py +435 -0
  485. wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py +30 -0
  486. wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py +779 -0
  487. wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py +193 -0
  488. wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py +18 -0
  489. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py +222 -0
  490. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py +82 -0
  491. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__init__.py +0 -0
  492. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py +17 -0
  493. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/contain_subset.py +28 -0
  494. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/default_ordered_dict.py +40 -0
  495. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/ordereddict.py +8 -0
  496. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/pair_set.py +43 -0
  497. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/version.py +78 -0
  498. wandb/vendor/graphql-core-1.1/wandb_graphql/type/__init__.py +67 -0
  499. wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py +619 -0
  500. wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py +132 -0
  501. wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py +440 -0
  502. wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py +131 -0
  503. wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py +100 -0
  504. wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py +145 -0
  505. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py +0 -0
  506. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py +9 -0
  507. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py +65 -0
  508. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py +49 -0
  509. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py +24 -0
  510. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py +75 -0
  511. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py +291 -0
  512. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_client_schema.py +250 -0
  513. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/concat_ast.py +9 -0
  514. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/extend_schema.py +357 -0
  515. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_field_def.py +27 -0
  516. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_operation_ast.py +21 -0
  517. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/introspection_query.py +90 -0
  518. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_literal_value.py +67 -0
  519. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_value.py +66 -0
  520. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/quoted_or_list.py +21 -0
  521. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/schema_printer.py +168 -0
  522. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/suggestion_list.py +56 -0
  523. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_comparators.py +69 -0
  524. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_from_ast.py +21 -0
  525. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_info.py +149 -0
  526. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/value_from_ast.py +69 -0
  527. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py +4 -0
  528. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py +79 -0
  529. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py +24 -0
  530. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py +8 -0
  531. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py +44 -0
  532. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py +113 -0
  533. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py +33 -0
  534. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py +70 -0
  535. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py +97 -0
  536. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py +19 -0
  537. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py +43 -0
  538. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py +23 -0
  539. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py +59 -0
  540. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py +36 -0
  541. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py +38 -0
  542. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py +37 -0
  543. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py +529 -0
  544. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py +44 -0
  545. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py +46 -0
  546. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py +33 -0
  547. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py +32 -0
  548. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py +28 -0
  549. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py +33 -0
  550. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py +31 -0
  551. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py +27 -0
  552. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py +21 -0
  553. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py +53 -0
  554. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py +158 -0
  555. wandb/vendor/promise-2.3.0/conftest.py +30 -0
  556. wandb/vendor/promise-2.3.0/setup.py +64 -0
  557. wandb/vendor/promise-2.3.0/tests/__init__.py +0 -0
  558. wandb/vendor/promise-2.3.0/tests/conftest.py +8 -0
  559. wandb/vendor/promise-2.3.0/tests/test_awaitable.py +32 -0
  560. wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py +47 -0
  561. wandb/vendor/promise-2.3.0/tests/test_benchmark.py +116 -0
  562. wandb/vendor/promise-2.3.0/tests/test_complex_threads.py +23 -0
  563. wandb/vendor/promise-2.3.0/tests/test_dataloader.py +452 -0
  564. wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py +99 -0
  565. wandb/vendor/promise-2.3.0/tests/test_dataloader_extra.py +65 -0
  566. wandb/vendor/promise-2.3.0/tests/test_extra.py +670 -0
  567. wandb/vendor/promise-2.3.0/tests/test_issues.py +132 -0
  568. wandb/vendor/promise-2.3.0/tests/test_promise_list.py +70 -0
  569. wandb/vendor/promise-2.3.0/tests/test_spec.py +584 -0
  570. wandb/vendor/promise-2.3.0/tests/test_thread_safety.py +115 -0
  571. wandb/vendor/promise-2.3.0/tests/utils.py +3 -0
  572. wandb/vendor/promise-2.3.0/wandb_promise/__init__.py +38 -0
  573. wandb/vendor/promise-2.3.0/wandb_promise/async_.py +135 -0
  574. wandb/vendor/promise-2.3.0/wandb_promise/compat.py +32 -0
  575. wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py +326 -0
  576. wandb/vendor/promise-2.3.0/wandb_promise/iterate_promise.py +12 -0
  577. wandb/vendor/promise-2.3.0/wandb_promise/promise.py +848 -0
  578. wandb/vendor/promise-2.3.0/wandb_promise/promise_list.py +151 -0
  579. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/__init__.py +0 -0
  580. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/version.py +83 -0
  581. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/__init__.py +0 -0
  582. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/asyncio.py +22 -0
  583. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/gevent.py +21 -0
  584. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/immediate.py +27 -0
  585. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/thread.py +18 -0
  586. wandb/vendor/promise-2.3.0/wandb_promise/utils.py +56 -0
  587. wandb/vendor/pygments/__init__.py +90 -0
  588. wandb/vendor/pygments/cmdline.py +568 -0
  589. wandb/vendor/pygments/console.py +74 -0
  590. wandb/vendor/pygments/filter.py +74 -0
  591. wandb/vendor/pygments/filters/__init__.py +350 -0
  592. wandb/vendor/pygments/formatter.py +95 -0
  593. wandb/vendor/pygments/formatters/__init__.py +153 -0
  594. wandb/vendor/pygments/formatters/_mapping.py +85 -0
  595. wandb/vendor/pygments/formatters/bbcode.py +109 -0
  596. wandb/vendor/pygments/formatters/html.py +851 -0
  597. wandb/vendor/pygments/formatters/img.py +600 -0
  598. wandb/vendor/pygments/formatters/irc.py +182 -0
  599. wandb/vendor/pygments/formatters/latex.py +482 -0
  600. wandb/vendor/pygments/formatters/other.py +160 -0
  601. wandb/vendor/pygments/formatters/rtf.py +147 -0
  602. wandb/vendor/pygments/formatters/svg.py +153 -0
  603. wandb/vendor/pygments/formatters/terminal.py +136 -0
  604. wandb/vendor/pygments/formatters/terminal256.py +309 -0
  605. wandb/vendor/pygments/lexer.py +871 -0
  606. wandb/vendor/pygments/lexers/__init__.py +329 -0
  607. wandb/vendor/pygments/lexers/_asy_builtins.py +1645 -0
  608. wandb/vendor/pygments/lexers/_cl_builtins.py +232 -0
  609. wandb/vendor/pygments/lexers/_cocoa_builtins.py +72 -0
  610. wandb/vendor/pygments/lexers/_csound_builtins.py +1346 -0
  611. wandb/vendor/pygments/lexers/_lasso_builtins.py +5327 -0
  612. wandb/vendor/pygments/lexers/_lua_builtins.py +295 -0
  613. wandb/vendor/pygments/lexers/_mapping.py +500 -0
  614. wandb/vendor/pygments/lexers/_mql_builtins.py +1172 -0
  615. wandb/vendor/pygments/lexers/_openedge_builtins.py +2547 -0
  616. wandb/vendor/pygments/lexers/_php_builtins.py +4756 -0
  617. wandb/vendor/pygments/lexers/_postgres_builtins.py +621 -0
  618. wandb/vendor/pygments/lexers/_scilab_builtins.py +3094 -0
  619. wandb/vendor/pygments/lexers/_sourcemod_builtins.py +1163 -0
  620. wandb/vendor/pygments/lexers/_stan_builtins.py +532 -0
  621. wandb/vendor/pygments/lexers/_stata_builtins.py +419 -0
  622. wandb/vendor/pygments/lexers/_tsql_builtins.py +1004 -0
  623. wandb/vendor/pygments/lexers/_vim_builtins.py +1939 -0
  624. wandb/vendor/pygments/lexers/actionscript.py +240 -0
  625. wandb/vendor/pygments/lexers/agile.py +24 -0
  626. wandb/vendor/pygments/lexers/algebra.py +221 -0
  627. wandb/vendor/pygments/lexers/ambient.py +76 -0
  628. wandb/vendor/pygments/lexers/ampl.py +87 -0
  629. wandb/vendor/pygments/lexers/apl.py +101 -0
  630. wandb/vendor/pygments/lexers/archetype.py +318 -0
  631. wandb/vendor/pygments/lexers/asm.py +641 -0
  632. wandb/vendor/pygments/lexers/automation.py +374 -0
  633. wandb/vendor/pygments/lexers/basic.py +500 -0
  634. wandb/vendor/pygments/lexers/bibtex.py +160 -0
  635. wandb/vendor/pygments/lexers/business.py +612 -0
  636. wandb/vendor/pygments/lexers/c_cpp.py +252 -0
  637. wandb/vendor/pygments/lexers/c_like.py +541 -0
  638. wandb/vendor/pygments/lexers/capnproto.py +78 -0
  639. wandb/vendor/pygments/lexers/chapel.py +102 -0
  640. wandb/vendor/pygments/lexers/clean.py +288 -0
  641. wandb/vendor/pygments/lexers/compiled.py +34 -0
  642. wandb/vendor/pygments/lexers/configs.py +833 -0
  643. wandb/vendor/pygments/lexers/console.py +114 -0
  644. wandb/vendor/pygments/lexers/crystal.py +393 -0
  645. wandb/vendor/pygments/lexers/csound.py +366 -0
  646. wandb/vendor/pygments/lexers/css.py +689 -0
  647. wandb/vendor/pygments/lexers/d.py +251 -0
  648. wandb/vendor/pygments/lexers/dalvik.py +125 -0
  649. wandb/vendor/pygments/lexers/data.py +555 -0
  650. wandb/vendor/pygments/lexers/diff.py +165 -0
  651. wandb/vendor/pygments/lexers/dotnet.py +691 -0
  652. wandb/vendor/pygments/lexers/dsls.py +878 -0
  653. wandb/vendor/pygments/lexers/dylan.py +289 -0
  654. wandb/vendor/pygments/lexers/ecl.py +125 -0
  655. wandb/vendor/pygments/lexers/eiffel.py +65 -0
  656. wandb/vendor/pygments/lexers/elm.py +121 -0
  657. wandb/vendor/pygments/lexers/erlang.py +533 -0
  658. wandb/vendor/pygments/lexers/esoteric.py +277 -0
  659. wandb/vendor/pygments/lexers/ezhil.py +69 -0
  660. wandb/vendor/pygments/lexers/factor.py +344 -0
  661. wandb/vendor/pygments/lexers/fantom.py +250 -0
  662. wandb/vendor/pygments/lexers/felix.py +273 -0
  663. wandb/vendor/pygments/lexers/forth.py +177 -0
  664. wandb/vendor/pygments/lexers/fortran.py +205 -0
  665. wandb/vendor/pygments/lexers/foxpro.py +428 -0
  666. wandb/vendor/pygments/lexers/functional.py +21 -0
  667. wandb/vendor/pygments/lexers/go.py +101 -0
  668. wandb/vendor/pygments/lexers/grammar_notation.py +213 -0
  669. wandb/vendor/pygments/lexers/graph.py +80 -0
  670. wandb/vendor/pygments/lexers/graphics.py +553 -0
  671. wandb/vendor/pygments/lexers/haskell.py +843 -0
  672. wandb/vendor/pygments/lexers/haxe.py +936 -0
  673. wandb/vendor/pygments/lexers/hdl.py +382 -0
  674. wandb/vendor/pygments/lexers/hexdump.py +103 -0
  675. wandb/vendor/pygments/lexers/html.py +602 -0
  676. wandb/vendor/pygments/lexers/idl.py +270 -0
  677. wandb/vendor/pygments/lexers/igor.py +288 -0
  678. wandb/vendor/pygments/lexers/inferno.py +96 -0
  679. wandb/vendor/pygments/lexers/installers.py +322 -0
  680. wandb/vendor/pygments/lexers/int_fiction.py +1343 -0
  681. wandb/vendor/pygments/lexers/iolang.py +63 -0
  682. wandb/vendor/pygments/lexers/j.py +146 -0
  683. wandb/vendor/pygments/lexers/javascript.py +1525 -0
  684. wandb/vendor/pygments/lexers/julia.py +333 -0
  685. wandb/vendor/pygments/lexers/jvm.py +1573 -0
  686. wandb/vendor/pygments/lexers/lisp.py +2621 -0
  687. wandb/vendor/pygments/lexers/make.py +202 -0
  688. wandb/vendor/pygments/lexers/markup.py +595 -0
  689. wandb/vendor/pygments/lexers/math.py +21 -0
  690. wandb/vendor/pygments/lexers/matlab.py +663 -0
  691. wandb/vendor/pygments/lexers/ml.py +769 -0
  692. wandb/vendor/pygments/lexers/modeling.py +358 -0
  693. wandb/vendor/pygments/lexers/modula2.py +1561 -0
  694. wandb/vendor/pygments/lexers/monte.py +204 -0
  695. wandb/vendor/pygments/lexers/ncl.py +894 -0
  696. wandb/vendor/pygments/lexers/nimrod.py +159 -0
  697. wandb/vendor/pygments/lexers/nit.py +64 -0
  698. wandb/vendor/pygments/lexers/nix.py +136 -0
  699. wandb/vendor/pygments/lexers/oberon.py +105 -0
  700. wandb/vendor/pygments/lexers/objective.py +504 -0
  701. wandb/vendor/pygments/lexers/ooc.py +85 -0
  702. wandb/vendor/pygments/lexers/other.py +41 -0
  703. wandb/vendor/pygments/lexers/parasail.py +79 -0
  704. wandb/vendor/pygments/lexers/parsers.py +835 -0
  705. wandb/vendor/pygments/lexers/pascal.py +644 -0
  706. wandb/vendor/pygments/lexers/pawn.py +199 -0
  707. wandb/vendor/pygments/lexers/perl.py +620 -0
  708. wandb/vendor/pygments/lexers/php.py +267 -0
  709. wandb/vendor/pygments/lexers/praat.py +294 -0
  710. wandb/vendor/pygments/lexers/prolog.py +306 -0
  711. wandb/vendor/pygments/lexers/python.py +939 -0
  712. wandb/vendor/pygments/lexers/qvt.py +152 -0
  713. wandb/vendor/pygments/lexers/r.py +453 -0
  714. wandb/vendor/pygments/lexers/rdf.py +270 -0
  715. wandb/vendor/pygments/lexers/rebol.py +431 -0
  716. wandb/vendor/pygments/lexers/resource.py +85 -0
  717. wandb/vendor/pygments/lexers/rnc.py +67 -0
  718. wandb/vendor/pygments/lexers/roboconf.py +82 -0
  719. wandb/vendor/pygments/lexers/robotframework.py +560 -0
  720. wandb/vendor/pygments/lexers/ruby.py +519 -0
  721. wandb/vendor/pygments/lexers/rust.py +220 -0
  722. wandb/vendor/pygments/lexers/sas.py +228 -0
  723. wandb/vendor/pygments/lexers/scripting.py +1222 -0
  724. wandb/vendor/pygments/lexers/shell.py +794 -0
  725. wandb/vendor/pygments/lexers/smalltalk.py +195 -0
  726. wandb/vendor/pygments/lexers/smv.py +79 -0
  727. wandb/vendor/pygments/lexers/snobol.py +83 -0
  728. wandb/vendor/pygments/lexers/special.py +103 -0
  729. wandb/vendor/pygments/lexers/sql.py +681 -0
  730. wandb/vendor/pygments/lexers/stata.py +108 -0
  731. wandb/vendor/pygments/lexers/supercollider.py +90 -0
  732. wandb/vendor/pygments/lexers/tcl.py +145 -0
  733. wandb/vendor/pygments/lexers/templates.py +2283 -0
  734. wandb/vendor/pygments/lexers/testing.py +207 -0
  735. wandb/vendor/pygments/lexers/text.py +25 -0
  736. wandb/vendor/pygments/lexers/textedit.py +169 -0
  737. wandb/vendor/pygments/lexers/textfmts.py +297 -0
  738. wandb/vendor/pygments/lexers/theorem.py +458 -0
  739. wandb/vendor/pygments/lexers/trafficscript.py +54 -0
  740. wandb/vendor/pygments/lexers/typoscript.py +226 -0
  741. wandb/vendor/pygments/lexers/urbi.py +133 -0
  742. wandb/vendor/pygments/lexers/varnish.py +190 -0
  743. wandb/vendor/pygments/lexers/verification.py +111 -0
  744. wandb/vendor/pygments/lexers/web.py +24 -0
  745. wandb/vendor/pygments/lexers/webmisc.py +988 -0
  746. wandb/vendor/pygments/lexers/whiley.py +116 -0
  747. wandb/vendor/pygments/lexers/x10.py +69 -0
  748. wandb/vendor/pygments/modeline.py +44 -0
  749. wandb/vendor/pygments/plugin.py +68 -0
  750. wandb/vendor/pygments/regexopt.py +92 -0
  751. wandb/vendor/pygments/scanner.py +105 -0
  752. wandb/vendor/pygments/sphinxext.py +158 -0
  753. wandb/vendor/pygments/style.py +155 -0
  754. wandb/vendor/pygments/styles/__init__.py +80 -0
  755. wandb/vendor/pygments/styles/abap.py +29 -0
  756. wandb/vendor/pygments/styles/algol.py +63 -0
  757. wandb/vendor/pygments/styles/algol_nu.py +63 -0
  758. wandb/vendor/pygments/styles/arduino.py +98 -0
  759. wandb/vendor/pygments/styles/autumn.py +65 -0
  760. wandb/vendor/pygments/styles/borland.py +51 -0
  761. wandb/vendor/pygments/styles/bw.py +49 -0
  762. wandb/vendor/pygments/styles/colorful.py +81 -0
  763. wandb/vendor/pygments/styles/default.py +73 -0
  764. wandb/vendor/pygments/styles/emacs.py +72 -0
  765. wandb/vendor/pygments/styles/friendly.py +72 -0
  766. wandb/vendor/pygments/styles/fruity.py +42 -0
  767. wandb/vendor/pygments/styles/igor.py +29 -0
  768. wandb/vendor/pygments/styles/lovelace.py +97 -0
  769. wandb/vendor/pygments/styles/manni.py +75 -0
  770. wandb/vendor/pygments/styles/monokai.py +106 -0
  771. wandb/vendor/pygments/styles/murphy.py +80 -0
  772. wandb/vendor/pygments/styles/native.py +65 -0
  773. wandb/vendor/pygments/styles/paraiso_dark.py +125 -0
  774. wandb/vendor/pygments/styles/paraiso_light.py +125 -0
  775. wandb/vendor/pygments/styles/pastie.py +75 -0
  776. wandb/vendor/pygments/styles/perldoc.py +69 -0
  777. wandb/vendor/pygments/styles/rainbow_dash.py +89 -0
  778. wandb/vendor/pygments/styles/rrt.py +33 -0
  779. wandb/vendor/pygments/styles/sas.py +44 -0
  780. wandb/vendor/pygments/styles/stata.py +40 -0
  781. wandb/vendor/pygments/styles/tango.py +141 -0
  782. wandb/vendor/pygments/styles/trac.py +63 -0
  783. wandb/vendor/pygments/styles/vim.py +63 -0
  784. wandb/vendor/pygments/styles/vs.py +38 -0
  785. wandb/vendor/pygments/styles/xcode.py +51 -0
  786. wandb/vendor/pygments/token.py +213 -0
  787. wandb/vendor/pygments/unistring.py +217 -0
  788. wandb/vendor/pygments/util.py +388 -0
  789. wandb/vendor/pynvml/__init__.py +0 -0
  790. wandb/vendor/pynvml/pynvml.py +4779 -0
  791. wandb/vendor/watchdog_0_9_0/wandb_watchdog/__init__.py +17 -0
  792. wandb/vendor/watchdog_0_9_0/wandb_watchdog/events.py +615 -0
  793. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/__init__.py +98 -0
  794. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/api.py +369 -0
  795. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents.py +172 -0
  796. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents2.py +239 -0
  797. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify.py +218 -0
  798. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_buffer.py +81 -0
  799. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_c.py +575 -0
  800. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/kqueue.py +730 -0
  801. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/polling.py +145 -0
  802. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/read_directory_changes.py +133 -0
  803. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/winapi.py +348 -0
  804. wandb/vendor/watchdog_0_9_0/wandb_watchdog/patterns.py +265 -0
  805. wandb/vendor/watchdog_0_9_0/wandb_watchdog/tricks/__init__.py +174 -0
  806. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/__init__.py +151 -0
  807. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/bricks.py +249 -0
  808. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/compat.py +29 -0
  809. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/decorators.py +198 -0
  810. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/delayed_queue.py +88 -0
  811. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/dirsnapshot.py +293 -0
  812. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/echo.py +157 -0
  813. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/event_backport.py +41 -0
  814. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/importlib2.py +40 -0
  815. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/platform.py +57 -0
  816. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/unicode_paths.py +64 -0
  817. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/win32stat.py +123 -0
  818. wandb/vendor/watchdog_0_9_0/wandb_watchdog/version.py +28 -0
  819. wandb/vendor/watchdog_0_9_0/wandb_watchdog/watchmedo.py +577 -0
  820. wandb/wandb_agent.py +588 -0
  821. wandb/wandb_controller.py +721 -0
  822. wandb/wandb_run.py +9 -0
  823. wandb-0.18.2.dist-info/METADATA +213 -0
  824. wandb-0.18.2.dist-info/RECORD +827 -0
  825. wandb-0.18.2.dist-info/WHEEL +5 -0
  826. wandb-0.18.2.dist-info/entry_points.txt +3 -0
  827. wandb-0.18.2.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,4358 @@
1
+ import ast
2
+ import base64
3
+ import datetime
4
+ import functools
5
+ import http.client
6
+ import json
7
+ import logging
8
+ import os
9
+ import re
10
+ import socket
11
+ import sys
12
+ import threading
13
+ from copy import deepcopy
14
+ from pathlib import Path
15
+ from typing import (
16
+ IO,
17
+ TYPE_CHECKING,
18
+ Any,
19
+ Callable,
20
+ Dict,
21
+ Iterable,
22
+ List,
23
+ Mapping,
24
+ MutableMapping,
25
+ Optional,
26
+ Sequence,
27
+ TextIO,
28
+ Tuple,
29
+ Union,
30
+ )
31
+
32
+ import click
33
+ import requests
34
+ import yaml
35
+ from wandb_gql import Client, gql
36
+ from wandb_gql.client import RetryError
37
+
38
+ import wandb
39
+ from wandb import env, util
40
+ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
41
+ from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
42
+ from wandb.integration.sagemaker import parse_sm_secrets
43
+ from wandb.old.settings import Settings
44
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
45
+ from wandb.sdk.lib.gql_request import GraphQLSession
46
+ from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
47
+
48
+ from ..lib import credentials, retry
49
+ from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
50
+ from ..lib.gitlib import GitRepo
51
+ from . import context
52
+ from .progress import Progress
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+ LAUNCH_DEFAULT_PROJECT = "model-registry"
57
+
58
if TYPE_CHECKING:
    # Declarations for static type checkers only: TYPE_CHECKING is False at
    # runtime, so nothing in this branch is imported or defined when the
    # module actually executes.
    if sys.version_info >= (3, 8):
        from typing import Literal, TypedDict
    else:
        from typing_extensions import Literal, TypedDict

    from .progress import ProgressFn

    # The TypedDicts are declared leaf-first so that each one can refer to the
    # types it embeds without string forward references.  Field names are
    # camelCase (hence the N815 suppressions).

    class UploadUrlParts(TypedDict):
        """One part of a multipart upload: its number and its upload URL."""

        partNumber: int  # noqa: N815
        uploadUrl: str  # noqa: N815

    class UploadPartsResponse(TypedDict):
        """The per-part upload URLs together with the upload ID."""

        uploadUrlParts: List[UploadUrlParts]  # noqa: N815
        uploadID: str  # noqa: N815

    class CreateArtifactFilesResponseFileNode(TypedDict):
        """Artifact reference embedded in a created-file entry."""

        id: str

    class CreateArtifactFilesResponseFile(TypedDict):
        """A single file entry of a create-artifact-files response."""

        id: str
        name: str
        displayName: str  # noqa: N815
        uploadUrl: Optional[str]  # noqa: N815
        uploadHeaders: Sequence[str]  # noqa: N815
        uploadMultipartUrls: UploadPartsResponse  # noqa: N815
        storagePath: str  # noqa: N815
        artifact: CreateArtifactFilesResponseFileNode

    class CreateArtifactFileSpecInput(TypedDict, total=False):
        """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""

        artifactID: str  # noqa: N815
        name: str
        md5: str
        mimetype: Optional[str]
        artifactManifestID: Optional[str]  # noqa: N815
        uploadPartsInput: Optional[List[Dict[str, object]]]  # noqa: N815

    class CompleteMultipartUploadArtifactInput(TypedDict):
        """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""

        completeMultipartAction: str  # noqa: N815
        completedParts: Dict[int, str]  # noqa: N815
        artifactID: str  # noqa: N815
        storagePath: str  # noqa: N815
        uploadID: str  # noqa: N815
        md5: str

    class CompleteMultipartUploadArtifactResponse(TypedDict):
        """Result of completing a multipart upload."""

        digest: str

    class DefaultSettings(TypedDict):
        """Keys of the `Api.default_settings` dictionary."""

        section: str
        git_remote: str
        ignore_globs: Optional[List[str]]
        base_url: Optional[str]
        root_dir: Optional[str]
        api_key: Optional[str]
        entity: Optional[str]
        project: Optional[str]
        _extra_http_headers: Optional[Mapping[str, str]]
        _proxies: Optional[Mapping[str, str]]

    # Type aliases used in annotations throughout this module.
    _Response = MutableMapping
    SweepState = Literal["RUNNING", "PAUSED", "CANCELED", "FINISHED"]
    Number = Union[int, float]
125
+
126
+ # class _MappingSupportsCopy(Protocol):
127
+ # def copy(self) -> "_MappingSupportsCopy": ...
128
+ # def keys(self) -> Iterable: ...
129
+ # def __getitem__(self, name: str) -> Any: ...
130
+
131
# Logger that receives http.client's debug output once
# check_httpclient_logger_handler() has installed the redirection.
httpclient_logger = logging.getLogger("http.client")
if os.environ.get("WANDB_DEBUG"):
    httpclient_logger.setLevel(logging.DEBUG)


def check_httpclient_logger_handler() -> None:
    """Redirect ``http.client`` debug output into the logging framework.

    Does nothing unless the ``WANDB_DEBUG`` environment variable is set, and
    does nothing if ``httpclient_logger`` already has a handler attached.
    Otherwise it replaces the ``print`` name inside the ``http.client`` module
    with a function that logs at DEBUG level, turns on ``HTTPConnection``
    debugging, and attaches the first handler of the ``"wandb"`` logger (when
    there is one) to ``httpclient_logger``.
    """
    # Only enable http.client logging if WANDB_DEBUG is set
    if not os.environ.get("WANDB_DEBUG"):
        return
    if httpclient_logger.handlers:
        return

    # Enable HTTPConnection debug logging to the logging framework
    level = logging.DEBUG

    def httpclient_log(*args: Any) -> None:
        # This stands in for print(), which accepts arbitrary objects, while
        # str.join() only accepts strings -- so coerce every argument first
        # instead of raising TypeError from inside http.client.
        httpclient_logger.log(level, " ".join(map(str, args)))

    # mask the print() built-in in the http.client module to use logging instead
    http.client.print = httpclient_log  # type: ignore[attr-defined]
    # enable debugging
    http.client.HTTPConnection.debuglevel = 1

    root_logger = logging.getLogger("wandb")
    if root_logger.handlers:
        httpclient_logger.addHandler(root_logger.handlers[0])
157
+
158
+
159
class _ThreadLocalData(threading.local):
    """Thread-local holder for an optional per-thread API context."""

    context: Optional[context.Context]

    def __init__(self) -> None:
        # Start without a context; one can be assigned to this attribute later.
        self.context = None
164
+
165
+
166
+ class Api:
167
+ """W&B Internal Api wrapper.
168
+
169
+ Note:
170
+ Settings are automatically overridden by looking for
171
+ a `wandb/settings` file in the current working directory or its parent
172
+ directory. If none can be found, we look in the current user's home
173
+ directory.
174
+
175
+ Arguments:
176
+ default_settings(dict, optional): If you aren't using a settings
177
+ file, or you wish to override the section to use in the settings file,
178
+ override the settings here.
179
+ """
180
+
181
+ HTTP_TIMEOUT = env.get_http_timeout(20)
182
+ FILE_PUSHER_TIMEOUT = env.get_file_pusher_timeout()
183
+ _global_context: context.Context
184
+ _local_data: _ThreadLocalData
185
+
186
def __init__(
    self,
    default_settings: Optional[
        Union[
            "wandb.sdk.wandb_settings.Settings",
            "wandb.sdk.internal.settings_static.SettingsStatic",
            Settings,
            dict,
        ]
    ] = None,
    load_settings: bool = True,
    retry_timedelta: datetime.timedelta = datetime.timedelta(  # noqa: B008 # okay because it's immutable
        days=7
    ),
    environ: MutableMapping = os.environ,
    retry_callback: Optional[Callable[[int, str], Any]] = None,
) -> None:
    """Set up settings, the GraphQL client and the upload helpers.

    Arguments:
        default_settings: Optional overrides merged on top of the built-in
            defaults defined below.
        load_settings: Forwarded to the `Settings` constructor.
        retry_timedelta: Retry budget handed to the GraphQL retry wrapper
            and to the upload retry wrappers.
        environ: Mapping used for environment-based configuration
            (defaults to `os.environ`).
        retry_callback: Optional callable forwarded to the GraphQL retry
            wrapper.
    """
    self._environ = environ
    self._global_context = context.Context()
    self._local_data = _ThreadLocalData()
    self.default_settings: DefaultSettings = {
        "section": "default",
        "git_remote": "origin",
        "ignore_globs": [],
        "base_url": "https://api.wandb.ai",
        "root_dir": None,
        "api_key": None,
        "entity": None,
        "project": None,
        "_extra_http_headers": None,
        "_proxies": None,
    }
    self.retry_timedelta = retry_timedelta
    # todo: Old Settings do not follow the SupportsKeysAndGetItem Protocol
    default_settings = default_settings or {}
    self.default_settings.update(default_settings)  # type: ignore
    self.retry_uploads = 10
    self._settings = Settings(
        load_settings=load_settings,
        root_dir=self.default_settings.get("root_dir"),
    )
    self.git = GitRepo(remote=self.settings("git_remote"))
    # Mutable settings set by the _file_stream_api
    self.dynamic_settings = {
        "system_sample_seconds": 2,
        "system_samples": 15,
        "heartbeat_seconds": 30,
    }

    # todo: remove these hacky hacks after settings refactor is complete
    # keeping this code here to limit scope and so that it is easy to remove later
    configured_headers = self.settings("_extra_http_headers") or json.loads(
        self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
    )
    # Work on a private copy: the mapping may be the very object the caller
    # passed in via `default_settings`, and the updates below (including the
    # Authorization header) must not leak back into it.
    self._extra_http_headers = dict(configured_headers)
    self._extra_http_headers.update(_thread_local_api_settings.headers or {})

    auth = None
    # Read the property once; every access goes through the credentials helper.
    access_token = self.access_token
    if access_token is not None:
        self._extra_http_headers["Authorization"] = f"Bearer {access_token}"
    elif _thread_local_api_settings.cookies is None:
        auth = ("api", self.api_key or "")

    proxies = self.settings("_proxies") or json.loads(
        self._environ.get("WANDB__PROXIES", "{}")
    )

    self.client = Client(
        transport=GraphQLSession(
            headers={
                "User-Agent": self.user_agent,
                "X-WANDB-USERNAME": env.get_username(env=self._environ),
                "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
                **self._extra_http_headers,
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=auth,
            url=f"{self.settings('base_url')}/graphql",
            cookies=_thread_local_api_settings.cookies,
            proxies=proxies,
        )
    )

    self.retry_callback = retry_callback
    self._retry_gql = retry.Retry(
        self.execute,
        retry_timedelta=retry_timedelta,
        check_retry_fn=util.no_retry_auth,
        retryable_exceptions=(RetryError, requests.RequestException),
        retry_callback=retry_callback,
    )
    self._current_run_id: Optional[str] = None
    self._file_stream_api = None
    self._upload_file_session = requests.Session()
    if self.FILE_PUSHER_TIMEOUT:
        # Bake the configured timeout into every PUT issued by the session.
        self._upload_file_session.put = functools.partial(  # type: ignore
            self._upload_file_session.put,
            timeout=self.FILE_PUSHER_TIMEOUT,
        )
    if proxies:
        self._upload_file_session.proxies.update(proxies)
    # This Retry class is initialized once for each Api instance, so this
    # defaults to retrying 1 million times per process or 7 days
    self.upload_file_retry = normalize_exceptions(
        retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file)
    )
    self.upload_multipart_file_chunk_retry = normalize_exceptions(
        retry.retriable(retry_timedelta=retry_timedelta)(
            self.upload_multipart_file_chunk
        )
    )
    self._client_id_mapping: Dict[str, str] = {}
    # Large file uploads to azure can optionally use their SDK
    self._azure_blob_module = util.get_module("azure.storage.blob")

    # Server capability caches; all start as None (unknown).
    self.query_types: Optional[List[str]] = None
    self.mutation_types: Optional[List[str]] = None
    self.server_info_types: Optional[List[str]] = None
    self.server_use_artifact_input_info: Optional[List[str]] = None
    self.server_create_artifact_input_info: Optional[List[str]] = None
    self.server_artifact_fields_info: Optional[List[str]] = None
    self._max_cli_version: Optional[str] = None
    self._server_settings_type: Optional[List[str]] = None
    self.fail_run_queue_item_input_info: Optional[List[str]] = None
    self.create_launch_agent_input_info: Optional[List[str]] = None
    self.server_create_run_queue_supports_drc: Optional[bool] = None
    self.server_create_run_queue_supports_priority: Optional[bool] = None
    self.server_supports_template_variables: Optional[bool] = None
    self.server_push_to_run_queue_supports_priority: Optional[bool] = None
317
+
318
def gql(self, *args: Any, **kwargs: Any) -> Any:
    """Call the retrying GraphQL executor with the current cancel event.

    All positional and keyword arguments are passed through unchanged; the
    active context's `cancel_event` is added as `retry_cancel_event`.
    """
    cancel_event = self.context.cancel_event
    return self._retry_gql(*args, retry_cancel_event=cancel_event, **kwargs)
325
+
326
def set_local_context(self, api_context: Optional[context.Context]) -> None:
    """Store `api_context` in this object's thread-local data."""
    self._local_data.context = api_context
328
+
329
def clear_local_context(self) -> None:
    """Reset the thread-local context to None."""
    self._local_data.context = None
331
+
332
@property
def context(self) -> context.Context:
    """The thread-local context when one is set, else the global context."""
    local_context = self._local_data.context
    if local_context:
        return local_context
    return self._global_context
335
+
336
def reauth(self) -> None:
    """Ensure the current api key is set in the transport."""
    # An empty string is used when no api key is available.
    basic_auth = ("api", self.api_key or "")
    self.client.transport.session.auth = basic_auth
339
+
340
def relocate(self) -> None:
    """Ensure the current api points to the right server."""
    base_url = self.settings("base_url")
    self.client.transport.url = f"{base_url}/graphql"
343
+
344
def execute(self, *args: Any, **kwargs: Any) -> "_Response":
    """Run `self.client.execute`, logging details when an HTTP error occurs.

    The `requests` HTTPError is always re-raised after its status code, body
    and any parsed backend error messages have been reported.
    """
    try:
        result = self.client.execute(*args, **kwargs)  # type: ignore
    except requests.exceptions.HTTPError as http_error:
        failed = http_error.response
        assert failed is not None
        logger.error(f"{failed.status_code} response executing GraphQL.")
        logger.error(failed.text)
        for message in parse_backend_error_messages(failed):
            wandb.termerror(f"Error while calling W&B API: {message} ({failed})")
        raise
    return result
356
+
357
def disabled(self) -> Union[str, bool]:
    """Look up "disabled" in the default settings section, with fallback False."""
    default_section = Settings.DEFAULT_SECTION
    return self._settings.get(default_section, "disabled", fallback=False)  # type: ignore
359
+
360
def set_current_run_id(self, run_id: str) -> None:
    """Remember `run_id` as the current run id."""
    self._current_run_id = run_id
362
+
363
@property
def current_run_id(self) -> Optional[str]:
    """The run id stored by `set_current_run_id`, or None if never set."""
    return self._current_run_id
366
+
367
@property
def user_agent(self) -> str:
    """The User-Agent string: a fixed prefix followed by the wandb version."""
    return "W&B Internal Client {}".format(wandb.__version__)
370
+
371
@property
def api_key(self) -> Optional[str]:
    """Resolve the API key from the available sources.

    A truthy thread-local key wins outright.  Otherwise the first truthy
    value among the environment, the netrc entry for `api_url`, the
    SageMaker secrets and the default settings is returned.
    """
    thread_local_key = _thread_local_api_settings.api_key
    if thread_local_key:
        return thread_local_key

    netrc_auth = requests.utils.get_netrc_auth(self.api_url)
    # The key is the last element of the netrc auth tuple.
    netrc_key = netrc_auth[-1] if netrc_auth else None

    # Environment should take precedence
    env_key: Optional[str] = self._environ.get(env.API_KEY)
    sagemaker_key: Optional[str] = parse_sm_secrets().get(env.API_KEY)
    default_key: Optional[str] = self.default_settings.get("api_key")
    for candidate in (env_key, netrc_key, sagemaker_key):
        if candidate:
            return candidate
    return default_key
385
+
386
@property
def access_token(self) -> Optional[str]:
    """Retrieve an access token for authentication.

    The path of an identity token file is read from the environment.  When
    it is set, the work is delegated to `credentials.access_token`, which
    is given the base URL, the identity token file and the credentials
    file path.

    Returns:
        Optional[str]: The access token if available, otherwise None if
        no identity token is supplied.

    Raises:
        AuthenticationError: If the path to the identity token is not found.
    """
    identity_token_path = self._environ.get(env.IDENTITY_TOKEN_FILE)
    if not identity_token_path:
        return None

    identity_token_file = Path(identity_token_path)
    if not identity_token_file.exists():
        raise AuthenticationError(
            f"Identity token file not found: {identity_token_file}"
        )

    server_url = self.settings("base_url")
    credentials_path = env.get_credentials_file(
        str(credentials.DEFAULT_WANDB_CREDENTIALS_FILE), self._environ
    )
    return credentials.access_token(server_url, identity_token_file, credentials_path)
414
+
415
@property
def api_url(self) -> str:
    """The `base_url` setting."""
    base_url = self.settings("base_url")
    return base_url  # type: ignore
418
+
419
@property
def app_url(self) -> str:
    """The result of `wandb.util.app_url` applied to `api_url`."""
    api_url = self.api_url
    return wandb.util.app_url(api_url)
422
+
423
@property
def default_entity(self) -> str:
    """The "entity" value of the mapping returned by `viewer()`."""
    viewer = self.viewer()
    return viewer.get("entity")  # type: ignore
426
+
427
def settings(self, key: Optional[str] = None, section: Optional[str] = None) -> Any:
    """Return the effective settings, or one of them when `key` is given.

    A copy of the default settings is overlaid with the items read from the
    settings file.  `entity`, `project`, `base_url` and `ignore_globs` are
    then each resolved through the matching `env` helper, which is given
    the value stored in the default section (falling back to the value
    collected so far) together with the environment mapping.

    Arguments:
        key (str, optional): If provided only this setting is returned
        section (str, optional): Section passed on when reading the items
            of the settings file

    Returns:
        A dict with the current settings, for example

            {
                "entity": "models",
                "base_url": "https://api.wandb.ai",
                "project": None
            }

        or, when `key` is provided, just the value stored under that key.
    """
    result = self.default_settings.copy()
    result.update(self._settings.items(section=section))  # type: ignore

    # Resolved in this fixed order; each helper sees only its own setting.
    resolvers = (
        ("entity", env.get_entity),
        ("project", env.get_project),
        ("base_url", env.get_base_url),
        ("ignore_globs", env.get_ignore),
    )
    resolved = {}
    for name, resolve in resolvers:
        stored = self._settings.get(
            Settings.DEFAULT_SECTION,
            name,
            fallback=result.get(name),
        )
        resolved[name] = resolve(stored, env=self._environ)
    result.update(resolved)

    if key is None:
        return result
    return result[key]  # type: ignore
484
+
485
def clear_setting(
    self, key: str, globally: bool = False, persist: bool = False
) -> None:
    """Clear `key` in the default settings section.

    `globally` and `persist` are passed straight through to the underlying
    settings object.
    """
    default_section = Settings.DEFAULT_SECTION
    self._settings.clear(default_section, key, globally=globally, persist=persist)
491
+
492
def set_setting(
    self, key: str, value: Any, globally: bool = False, persist: bool = False
) -> None:
    """Store ``value`` under ``key`` in the default settings section.

    After the write, an ``entity`` or ``project`` value is also pushed into
    ``self._environ`` through the ``env`` helpers, and a ``base_url`` change
    calls ``self.relocate()``.
    """
    default_section = Settings.DEFAULT_SECTION
    self._settings.set(default_section, key, value, globally=globally, persist=persist)

    if key == "base_url":
        self.relocate()
        return
    if key == "entity":
        env.set_entity(value, env=self._environ)
        return
    if key == "project":
        env.set_project(value, env=self._environ)
504
+
505
def parse_slug(
    self, slug: str, project: Optional[str] = None, run: Optional[str] = None
) -> Tuple[str, str]:
    """Parse a slug into a project and run.

    Arguments:
        slug (str): The slug to parse, either ``"project/run"`` or a bare run
        project (str, optional): The project to use when the slug has no
            ``/``; falls back to the ``project`` setting
        run (str, optional): The run to use when the slug has no ``/``; falls
            back to the slug, ``self.current_run_id`` and then the run from
            the environment

    Returns:
        A ``(project, run)`` tuple

    Raises:
        CommError: If no project is given and none is configured.
        AssertionError: If no run could be determined.
    """
    if slug and "/" in slug:
        # Only the first two components are used; anything after a second
        # "/" is ignored.
        parts = slug.split("/")
        project = parts[0]
        run = parts[1]
    else:
        project = project or self.settings().get("project")
        if project is None:
            raise CommError("No default project configured.")
        run = run or slug or self.current_run_id or env.get_run(env=self._environ)
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        if not run:
            raise AssertionError("run must be specified")
    return project, run
531
+
532
+ @normalize_exceptions
533
def server_info_introspection(self) -> Tuple[List[str], List[str], List[str]]:
    """Probe the server for the fields of ``Query``, ``Mutation`` and ``ServerInfo``.

    The three name lists are cached on ``self.query_types``,
    ``self.mutation_types`` and ``self.server_info_types``; the server is only
    queried while any of them is still None.

    Returns:
        ``(query_types, server_info_types, mutation_types)`` -- note that the
        mutation names come last.
    """
    query_string = """
        query ProbeServerCapabilities {
            QueryType: __type(name: "Query") {
                ...fieldData
            }
            MutationType: __type(name: "Mutation") {
                ...fieldData
            }
            ServerInfoType: __type(name: "ServerInfo") {
                ...fieldData
            }
        }

        fragment fieldData on __Type {
            fields {
                name
            }
        }
    """
    if (
        self.query_types is None
        or self.mutation_types is None
        or self.server_info_types is None
    ):
        query = gql(query_string)
        res = self.gql(query)

        def _field_names(alias: str) -> List[str]:
            # A ``dict.get`` default only applies to *missing* keys. An alias
            # that is present with a null value (e.g. the type is not defined
            # on the server) must be handled explicitly to avoid calling
            # ``.get`` on None.
            type_info = res.get(alias) or {}
            fields = type_info.get("fields")
            if fields is None:
                # Same fallback as before: one empty field, i.e. [""].
                fields = [{}]
            return [field.get("name", "") for field in fields]

        self.query_types = _field_names("QueryType")
        self.mutation_types = _field_names("MutationType")
        self.server_info_types = _field_names("ServerInfoType")
    return self.query_types, self.server_info_types, self.mutation_types
574
+
575
+ @normalize_exceptions
576
def server_settings_introspection(self) -> None:
    """Probe the server for the fields of the ``ServerSettings`` type.

    The list of field names is cached on ``self._server_settings_type``; the
    server is only queried while that attribute is None. A falsy response is
    cached as an empty list.
    """
    query_string = """
        query ProbeServerSettings {
            ServerSettingsType: __type(name: "ServerSettings") {
                ...fieldData
            }
        }

        fragment fieldData on __Type {
            fields {
                name
            }
        }
    """
    if self._server_settings_type is None:
        query = gql(query_string)
        res = self.gql(query)
        if res:
            # The alias may be present with a null value; a ``dict.get``
            # default would not cover that case.
            type_info = res.get("ServerSettingsType") or {}
            fields = type_info.get("fields")
            if fields is None:
                # Same fallback as before: one empty field, i.e. [""].
                fields = [{}]
            self._server_settings_type = [
                field.get("name", "") for field in fields
            ]
        else:
            self._server_settings_type = []
601
+
602
def server_use_artifact_input_introspection(self) -> List:
    """Probe the server for the input fields of ``UseArtifactInput``.

    The list of field names is cached on
    ``self.server_use_artifact_input_info``; the server is only queried while
    that attribute is None.

    Returns:
        The cached list of input field names.
    """
    query_string = """
        query ProbeServerUseArtifactInput {
            UseArtifactInputInfoType: __type(name: "UseArtifactInput") {
                name
                inputFields {
                    name
                }
            }
        }
    """

    if self.server_use_artifact_input_info is None:
        query = gql(query_string)
        res = self.gql(query)
        # The alias may be present with a null value; a ``dict.get`` default
        # would not cover that case.
        type_info = res.get("UseArtifactInputInfoType") or {}
        input_fields = type_info.get("inputFields")
        if input_fields is None:
            # Same fallback as before: one empty field, i.e. [""].
            input_fields = [{}]
        self.server_use_artifact_input_info = [
            field.get("name", "") for field in input_fields
        ]
    return self.server_use_artifact_input_info
624
+
625
+ @normalize_exceptions
626
def launch_agent_introspection(self) -> Optional[str]:
    """Probe the server for the ``LaunchAgent`` GraphQL type.

    Returns:
        The ``LaunchAgentType`` entry of the response, or None when that
        entry is missing or falsy.
    """
    probe = gql(
        """
        query LaunchAgentIntrospection {
            LaunchAgentType: __type(name: "LaunchAgent") {
                name
            }
        }
        """
    )

    response = self.gql(probe)
    launch_agent_type = response.get("LaunchAgentType")
    if not launch_agent_type:
        return None
    return launch_agent_type
639
+
640
+ @normalize_exceptions
641
def create_run_queue_introspection(self) -> Tuple[bool, bool, bool]:
    """Report which ``createRunQueue`` capabilities the server exposes.

    The two input-field flags are cached on
    ``self.server_create_run_queue_supports_drc`` and
    ``self.server_create_run_queue_supports_priority``; the
    ``CreateRunQueueInput`` probe only runs while either is None.

    Returns:
        ``(has_create_run_queue_mutation, supports_default_resource_config,
        supports_prioritization_mode)``

    Raises:
        CommError: If the probe query returns None.
    """
    _, _, mutations = self.server_info_introspection()
    query_string = """
        query ProbeCreateRunQueueInput {
            CreateRunQueueInputType: __type(name: "CreateRunQueueInput") {
                name
                inputFields {
                    name
                }
            }
        }
    """
    if (
        self.server_create_run_queue_supports_drc is None
        or self.server_create_run_queue_supports_priority is None
    ):
        query = gql(query_string)
        res = self.gql(query)
        if res is None:
            raise CommError("Could not get CreateRunQueue input from GQL.")
        # Collect the field names once. A null type or a missing
        # ``inputFields`` list simply means neither field is supported,
        # instead of raising AttributeError / KeyError.
        input_type = res.get("CreateRunQueueInputType") or {}
        field_names = {
            field.get("name") for field in (input_type.get("inputFields") or [])
        }
        self.server_create_run_queue_supports_drc = (
            "defaultResourceConfigID" in field_names
        )
        self.server_create_run_queue_supports_priority = (
            "prioritizationMode" in field_names
        )
    return (
        "createRunQueue" in mutations,
        self.server_create_run_queue_supports_drc,
        self.server_create_run_queue_supports_priority,
    )
678
+
679
+ @normalize_exceptions
680
def upsert_run_queue_introspection(self) -> bool:
    """Report whether ``upsertRunQueue`` is among the server's mutations."""
    _queries, _server_info, mutation_names = self.server_info_introspection()
    supported = "upsertRunQueue" in mutation_names
    return supported
683
+
684
+ @normalize_exceptions
685
def push_to_run_queue_introspection(self) -> Tuple[bool, bool]:
    """Report which optional ``PushToRunQueueInput`` fields the server accepts.

    The flags are cached on ``self.server_supports_template_variables`` and
    ``self.server_push_to_run_queue_supports_priority``; the probe only runs
    while either is None.

    Returns:
        ``(supports_template_variables, supports_priority)``
    """
    query_string = """
        query ProbePushToRunQueueInput {
            PushToRunQueueInputType: __type(name: "PushToRunQueueInput") {
                name
                inputFields {
                    name
                }
            }
        }
    """

    if (
        self.server_supports_template_variables is None
        or self.server_push_to_run_queue_supports_priority is None
    ):
        query = gql(query_string)
        res = self.gql(query)
        # Collect the field names once. A null type or a missing
        # ``inputFields`` list simply means neither field is supported,
        # instead of raising AttributeError / KeyError.
        input_type = res.get("PushToRunQueueInputType") or {}
        field_names = {
            field.get("name") for field in (input_type.get("inputFields") or [])
        }
        self.server_supports_template_variables = (
            "templateVariableValues" in field_names
        )
        self.server_push_to_run_queue_supports_priority = "priority" in field_names

    return (
        self.server_supports_template_variables,
        self.server_push_to_run_queue_supports_priority,
    )
720
+
721
+ @normalize_exceptions
722
def create_default_resource_config_introspection(self) -> bool:
    """Report whether ``createDefaultResourceConfig`` is among the server's mutations."""
    _queries, _server_info, mutation_names = self.server_info_introspection()
    supported = "createDefaultResourceConfig" in mutation_names
    return supported
725
+
726
+ @normalize_exceptions
727
def fail_run_queue_item_introspection(self) -> bool:
    """Report whether ``failRunQueueItem`` is among the server's mutations."""
    _queries, _server_info, mutation_names = self.server_info_introspection()
    supported = "failRunQueueItem" in mutation_names
    return supported
730
+
731
+ @normalize_exceptions
732
def fail_run_queue_item_fields_introspection(self) -> List:
    """Probe the server for the input fields of ``FailRunQueueItemInput``.

    A truthy ``self.fail_run_queue_item_input_info`` is returned as-is;
    otherwise the server is queried and the resulting list of field names is
    stored on that attribute.

    Returns:
        The list of input field names.
    """
    if self.fail_run_queue_item_input_info:
        return self.fail_run_queue_item_input_info
    query_string = """
        query ProbeServerFailRunQueueItemInput {
            FailRunQueueItemInputInfoType: __type(name:"FailRunQueueItemInput") {
                inputFields{
                    name
                }
            }
        }
    """

    query = gql(query_string)
    res = self.gql(query)

    # The alias may be present with a null value; a ``dict.get`` default
    # would not cover that case.
    type_info = res.get("FailRunQueueItemInputInfoType") or {}
    input_fields = type_info.get("inputFields")
    if input_fields is None:
        # Same fallback as before: one empty field, i.e. [""].
        input_fields = [{}]
    self.fail_run_queue_item_input_info = [
        field.get("name", "") for field in input_fields
    ]
    return self.fail_run_queue_item_input_info
755
+
756
+ @normalize_exceptions
757
def fail_run_queue_item(
    self,
    run_queue_item_id: str,
    message: str,
    stage: str,
    file_paths: Optional[List[str]] = None,
) -> bool:
    """Run the ``failRunQueueItem`` mutation for a run queue item.

    When the server's input type has a ``message`` field, ``message`` and
    ``stage`` (plus ``file_paths`` if not None) are sent as well; otherwise
    only the item id is sent.

    Returns:
        False when the server lacks the mutation, else the ``success`` value
        from the response.
    """
    if not self.fail_run_queue_item_introspection():
        return False

    variables: Dict[str, Union[str, Optional[List[str]]]] = {
        "runQueueItemId": run_queue_item_id,
    }
    supports_message = "message" in self.fail_run_queue_item_fields_introspection()
    if not supports_message:
        # Older input type: only the item id can be sent.
        mutation_text = """
        mutation failRunQueueItem($runQueueItemId: ID!) {
            failRunQueueItem(
                input: {
                    runQueueItemId: $runQueueItemId
                }
            ) {
                success
            }
        }
        """
    else:
        variables["message"] = message
        variables["stage"] = stage
        if file_paths is not None:
            variables["filePaths"] = file_paths
        mutation_text = """
        mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
            failRunQueueItem(
                input: {
                    runQueueItemId: $runQueueItemId
                    message: $message
                    stage: $stage
                    filePaths: $filePaths
                }
            ) {
                success
            }
        }
        """

    response = self.gql(gql(mutation_text), variable_values=variables)
    succeeded: bool = response["failRunQueueItem"]["success"]
    return succeeded
804
+
805
+ @normalize_exceptions
806
def update_run_queue_item_warning_introspection(self) -> bool:
    """Report whether ``updateRunQueueItemWarning`` is among the server's mutations."""
    _queries, _server_info, mutation_names = self.server_info_introspection()
    supported = "updateRunQueueItemWarning" in mutation_names
    return supported
809
+
810
+ @normalize_exceptions
811
def update_run_queue_item_warning(
    self,
    run_queue_item_id: str,
    message: str,
    stage: str,
    file_paths: Optional[List[str]] = None,
) -> bool:
    """Run the ``updateRunQueueItemWarning`` mutation for a run queue item.

    Returns:
        False when the server lacks the mutation, else the ``success`` value
        from the response.
    """
    if not self.update_run_queue_item_warning_introspection():
        return False

    warning_mutation = gql(
        """
        mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
            updateRunQueueItemWarning(
                input: {
                    runQueueItemId: $runQueueItemId
                    message: $message
                    stage: $stage
                    filePaths: $filePaths
                }
            ) {
                success
            }
        }
        """
    )
    variables = {
        "runQueueItemId": run_queue_item_id,
        "message": message,
        "stage": stage,
        "filePaths": file_paths,
    }
    response = self.gql(warning_mutation, variable_values=variables)
    succeeded: bool = response["updateRunQueueItemWarning"]["success"]
    return succeeded
847
+
848
+ @normalize_exceptions
849
def viewer(self) -> Dict[str, Any]:
    """Query the ``viewer`` node (id, entity, username, flags, team names).

    Returns:
        The ``viewer`` entry of the response, or an empty dict when it is
        missing or falsy.
    """
    viewer_query = gql(
        """
        query Viewer{
            viewer {
                id
                entity
                username
                flags
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        """
    )
    viewer_info = self.gql(viewer_query).get("viewer")
    if not viewer_info:
        return {}
    return viewer_info
871
+
872
+ @normalize_exceptions
873
def max_cli_version(self) -> Optional[str]:
    """Return the ``max_cli_version`` value reported by the server, if any.

    A non-None value is cached on ``self._max_cli_version``.

    Returns:
        The cached or freshly fetched ``max_cli_version``; None when the
        server exposes no ``serverInfo.cliVersionInfo`` or reports no value.
    """
    if self._max_cli_version is not None:
        return self._max_cli_version

    query_types, server_info_types, _ = self.server_info_introspection()
    cli_version_exists = (
        "serverInfo" in query_types and "cliVersionInfo" in server_info_types
    )
    if not cli_version_exists:
        return None

    _, server_info = self.viewer_server_info()
    # ``cliVersionInfo`` may be present but null; ``dict.get``'s default only
    # covers a missing key, so normalise None to an empty mapping first.
    cli_version_info = server_info.get("cliVersionInfo") or {}
    self._max_cli_version = cli_version_info.get("max_cli_version")
    return self._max_cli_version
889
+
890
+ @normalize_exceptions
891
def viewer_server_info(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Query the viewer together with whatever ``serverInfo`` the server offers.

    The ``serverInfo { cliVersionInfo ... }`` block is only spliced into the
    query when introspection reports both ``serverInfo`` and
    ``cliVersionInfo``; the nested ``latestLocalVersionInfo`` block is only
    spliced in when ``latestLocalVersionInfo`` is reported as well.

    Returns:
        ``(viewer, server_info)``; each is an empty dict when the matching
        response entry is missing or falsy.
    """
    local_version_fragment = """
        latestLocalVersionInfo {
            outOfDate
            latestVersionString
            versionOnThisInstanceString
        }
    """
    server_info_fragment = """
        serverInfo {
            cliVersionInfo
            _LOCAL_QUERY_
        }
    """
    viewer_template = """
    query Viewer{
        viewer {
            id
            entity
            username
            email
            flags
            teams {
                edges {
                    node {
                        name
                    }
                }
            }
        }
        _CLI_QUERY_
    }
    """
    query_types, server_info_types, _ = self.server_info_introspection()

    has_server_info = "serverInfo" in query_types
    include_cli = has_server_info and "cliVersionInfo" in server_info_types
    include_local = has_server_info and "latestLocalVersionInfo" in server_info_types

    cli_text = server_info_fragment if include_cli else ""
    local_text = local_version_fragment if include_local else ""

    # Order matters: the local placeholder lives inside the cli fragment, so
    # it can only be substituted after the cli fragment has been spliced in.
    query_text = viewer_template.replace("_CLI_QUERY_", cli_text)
    query_text = query_text.replace("_LOCAL_QUERY_", local_text)

    response = self.gql(gql(query_text))
    viewer_info = response.get("viewer") or {}
    server_info = response.get("serverInfo") or {}
    return viewer_info, server_info
944
+
945
+ @normalize_exceptions
946
def list_projects(self, entity: Optional[str] = None) -> List[Dict[str, str]]:
    """List projects in W&B scoped by entity.

    The query asks for the first 10 ``models`` of the entity.

    Arguments:
        entity (str, optional): The entity to scope the projects to; falls
            back to the ``entity`` setting.

    Returns:
        [{"id","name","description"}]
    """
    projects_query = gql(
        """
        query EntityProjects($entity: String) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        """
    )
    variables = {"entity": entity or self.settings("entity")}
    response = self.gql(projects_query, variable_values=variables)
    projects: List[Dict[str, str]] = self._flatten_edges(response["models"])
    return projects
976
+
977
+ @normalize_exceptions
978
def project(self, project: str, entity: Optional[str] = None) -> "_Response":
    """Retrieve project.

    Arguments:
        project (str): The project to get details for
        entity (str, optional): The entity to scope this project to.

    Returns:
        The ``model`` entry of the response, with the requested fields
        {"id","name","repo","dockerImage","description"}
    """
    details_query = gql(
        """
        query ProjectDetails($entity: String, $project: String) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        """
    )
    variables = {"entity": entity, "project": project}
    payload = self.gql(details_query, variable_values=variables)
    model: _Response = payload["model"]
    return model
1005
+
1006
+ @normalize_exceptions
1007
def sweep(
    self,
    sweep: str,
    specs: str,
    project: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Any]:
    """Retrieve sweep.

    Arguments:
        sweep (str): The sweep to get details for
        specs (str): history specs
        project (str, optional): The project to scope this sweep to; falls
            back to the ``project`` setting.
        entity (str, optional): The entity to scope this sweep to; falls back
            to the ``entity`` setting.

    Returns:
        The sweep node of the response, with its ``runs`` connection replaced
        by the output of ``self._flatten_edges``.

    Raises:
        ValueError: If the project or the sweep is missing from the response.
    """
    sweep_query = gql(
        """
        query SweepWithRuns($entity: String, $project: String, $sweep: String!, $specs: [JSONString!]!) {
            project(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        """
    )
    entity = entity or self.settings("entity")
    project = project or self.settings("project")
    variables = {
        "entity": entity,
        "project": project,
        "sweep": sweep,
        "specs": specs,
    }
    response = self.gql(sweep_query, variable_values=variables)

    project_node = response["project"]
    sweep_node = None if project_node is None else project_node["sweep"]
    if sweep_node is None:
        raise ValueError(f"Sweep {entity}/{project}/{sweep} not found")

    data: Dict[str, Any] = sweep_node
    if data:
        data["runs"] = self._flatten_edges(data["runs"])
    return data
1082
+
1083
+ @normalize_exceptions
1084
def list_runs(
    self, project: str, entity: Optional[str] = None
) -> List[Dict[str, str]]:
    """List runs in W&B scoped by project.

    The query asks for the first 10 ``buckets`` of the project.

    Arguments:
        project (str): The project to scope the runs to; a falsy value falls
            back to the ``project`` setting
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting

    Returns:
        The flattened run nodes; each node is requested with
        {"id","name","displayName","description"}
    """
    runs_query = gql(
        """
        query ProjectRuns($model: String!, $entity: String) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        """
    )
    variables = {
        "entity": entity or self.settings("entity"),
        "model": project or self.settings("project"),
    }
    response = self.gql(runs_query, variable_values=variables)
    buckets = response["model"]["buckets"]
    return self._flatten_edges(buckets)
1123
+
1124
+ @normalize_exceptions
1125
def run_config(
    self, project: str, run: Optional[str] = None, entity: Optional[str] = None
) -> Tuple[str, Dict[str, Any], Optional[str], Dict[str, Any]]:
    """Get the relevant configs for a run.

    Arguments:
        project (str): The project to download, (can include bucket)
        run (str, optional): The run to download
        entity (str, optional): The entity to scope this project to.

    Returns:
        ``(commit, config, patch, metadata)``: the run's commit and decoded
        JSON config, the text of the diff file (None if it was not listed)
        and the decoded JSON of the metadata file ({} if it was not listed).

    Raises:
        CommError: If the response contains no ``model``.
    """
    check_httpclient_logger_handler()

    configs_query = gql(
        """
        query RunConfigs(
            $name: String!,
            $entity: String,
            $run: String!,
            $pattern: String!,
            $includeConfig: Boolean!,
        ) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config @include(if: $includeConfig)
                    commit @include(if: $includeConfig)
                    files(pattern: $pattern) {
                        pageInfo {
                            hasNextPage
                            endCursor
                        }
                        edges {
                            node {
                                name
                                directUrl
                            }
                        }
                    }
                }
            }
        }
        """
    )

    query_vars = {
        "name": project,
        "run": run,
        "entity": entity,
        "includeConfig": True,
    }

    commit: str = ""
    config: Dict[str, Any] = {}
    patch: Optional[str] = None
    metadata: Dict[str, Any] = {}

    # If we use the `names` parameter on the `files` node, then the server
    # will helpfully give us an 'open' file handle to the files that don't
    # exist. This is so that we can upload data to it. However, in this
    # case, we just want to download that file and not upload to it, so
    # let's instead query for the files that do exist using `pattern`
    # (with no wildcards).
    #
    # Unfortunately we're unable to construct a single pattern that matches
    # our 2 files, we would need something like regex for that.
    for wanted_file in (DIFF_FNAME, METADATA_FNAME):
        query_vars["pattern"] = wanted_file
        response = self.gql(configs_query, variable_values=query_vars)
        model = response["model"]
        if model is None:
            raise CommError(f"Run {entity}/{project}/{run} not found")
        bucket: Dict = model["bucket"]

        # config and commit only need to be fetched on the first request
        if query_vars["includeConfig"]:
            commit = bucket["commit"]
            config = json.loads(bucket["config"] or "{}")
            query_vars["includeConfig"] = False

        files = bucket["files"]
        if files is None:
            continue
        for edge in files["edges"]:
            file_name = edge["node"]["name"]
            download_url = edge["node"]["directUrl"]
            download = requests.get(download_url)
            download.raise_for_status()
            if file_name == METADATA_FNAME:
                metadata = download.json()
            elif file_name == DIFF_FNAME:
                patch = download.text

    return commit, config, patch, metadata
1212
+
1213
+ @normalize_exceptions
1214
def run_resume_status(
    self, entity: str, project_name: str, name: str
) -> Optional[Dict[str, Any]]:
    """Check if a run exists and get resume information.

    On success the ``project`` setting is set to ``project_name`` and, when
    the response carries the project's entity, the ``entity`` setting is set
    to that entity's name.

    Arguments:
        entity (str): The entity to scope this project to.
        project_name (str): The project to download, (can include bucket)
        name (str): The run to download

    Returns:
        The ``bucket`` entry of the project, or None when the response has no
        ``model`` key or the model has no ``bucket`` key.
    """
    # The wandbConfig "t" key is requested so that callers can determine
    # whether a run has actually started.
    status_query = gql(
        """
        query RunResumeStatus($project: String, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                    tags
                    wandbConfig(keys: ["t"])
                }
            }
        }
        """
    )

    variables = {
        "entity": entity,
        "project": project_name,
        "name": name,
    }
    response = self.gql(status_query, variable_values=variables)

    if "model" not in response:
        return None
    model = response["model"]
    if "bucket" not in (model or {}):
        return None

    self.set_setting("project", project_name)
    if "entity" in model:
        self.set_setting("entity", model["entity"]["name"])

    bucket: Dict[str, Any] = model["bucket"]
    return bucket
1275
+
1276
+ @normalize_exceptions
1277
def check_stop_requested(
    self, project_name: str, entity_name: str, run_id: str
) -> bool:
    """Return the ``stopped`` field of a run.

    Returns:
        False when the response has no (truthy) project or run, otherwise the
        run's ``stopped`` value.
    """
    stopped_query = gql(
        """
        query RunStoppedStatus($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        """
    )

    variables = {
        "projectName": project_name,
        "entityName": entity_name,
        "runId": run_id,
    }
    response = self.gql(stopped_query, variable_values=variables)

    project_node = response.get("project", None)
    run_node = project_node.get("run", None) if project_node else None
    if not run_node:
        return False

    stopped: bool = run_node["stopped"]
    return stopped
1310
+
1311
def format_project(self, project: str) -> str:
    """Lower-case ``project``, turn each run of non-word characters into ``-``
    and strip leading/trailing ``-`` and ``_``."""
    lowered = project.lower()
    dashed = re.sub(r"\W+", "-", lowered)
    return dashed.strip("-_")
1313
+
1314
+ @normalize_exceptions
1315
def upsert_project(
    self,
    project: str,
    id: Optional[str] = None,
    description: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a new project.

    The project name is passed through ``self.format_project`` before it is
    sent.

    Arguments:
        project (str): The project to create
        id (str, optional): Sent as the ``id`` input of the mutation
        description (str, optional): A description of this project
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting.

    Returns:
        The ``model`` entry of the ``upsertModel`` response.
    """
    upsert_mutation = gql(
        """
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        """
    )
    # TODO(jhr): Commenting out 'repo' field for cling, add back
    # 'description': description, 'repo': self.git.remote_url, 'id': id})
    variables = {
        "name": self.format_project(project),
        "entity": entity or self.settings("entity"),
        "description": description,
        "id": id,
    }
    response = self.gql(upsert_mutation, variable_values=variables)
    model: Dict[str, Any] = response["upsertModel"]["model"]
    return model
1354
+
1355
+ @normalize_exceptions
1356
def entity_is_team(self, entity: str) -> bool:
    """Return the ``isTeam`` flag of ``entity``.

    Raises:
        Exception: If the response contains no entity.
    """
    team_query = gql(
        """
        query EntityIsTeam($entity: String!) {
            entity(name: $entity) {
                id
                isTeam
            }
        }
        """
    )
    variables = {"entity": entity}

    response = self.gql(team_query, variables)
    if response.get("entity") is None:
        message = (
            f"Error fetching entity {entity} "
            "check that you have access to this entity"
        )
        raise Exception(message)

    team_flag: bool = response["entity"]["isTeam"]
    return team_flag
1380
+
1381
+ @normalize_exceptions
1382
def get_project_run_queues(self, entity: str, project: str) -> List[Dict[str, str]]:
    """Return the run queues (id, name, createdBy, access) of a project.

    Raises:
        Exception: If the response contains no project.
    """
    queues_query = gql(
        """
        query ProjectRunQueues($entity: String!, $projectName: String!){
            project(entityName: $entity, name: $projectName) {
                runQueues {
                    id
                    name
                    createdBy
                    access
                }
            }
        }
        """
    )
    variables = {
        "projectName": project,
        "entity": entity,
    }

    response = self.gql(queues_query, variables)
    if response.get("project") is None:
        # circular dependency: (LAUNCH_DEFAULT_PROJECT = model-registry)
        # For that project only the entity is named in the message.
        target = entity if project == "model-registry" else f"{entity}/{project}"
        raise Exception(
            f"Error fetching run queues for {target} "
            "check that you have access to this entity and project"
        )

    run_queues: List[Dict[str, str]] = response["project"]["runQueues"]
    return run_queues
1420
+
1421
+ @normalize_exceptions
1422
def create_default_resource_config(
    self,
    entity: str,
    resource: str,
    config: str,
    template_variables: Optional[Dict[str, Union[float, int, str]]],
) -> Optional[Dict[str, Any]]:
    """Run the ``createDefaultResourceConfig`` mutation.

    ``templateVariables`` is only declared and sent when the server reports
    template-variable support; in that case a None ``template_variables`` is
    sent as ``"{}"``.

    Returns:
        The ``createDefaultResourceConfig`` entry of the response.

    Raises:
        Exception: If the server lacks the mutation.
        UnsupportedError: If ``template_variables`` is given but the server
            does not support template variables.
    """
    if not self.create_default_resource_config_introspection():
        # Previously raised with no message at all, which made the failure
        # impossible to diagnose; the exception type is unchanged.
        raise Exception(
            "createDefaultResourceConfig is not supported by this version of "
            "wandb server. Consider updating to the latest version."
        )
    supports_template_vars, _ = self.push_to_run_queue_introspection()

    mutation_params = """
        $entityName: String!,
        $resource: String!,
        $config: JSONString!
    """
    mutation_inputs = """
        entityName: $entityName,
        resource: $resource,
        config: $config
    """

    if supports_template_vars:
        mutation_params += ", $templateVariables: JSONString"
        mutation_inputs += ", templateVariables: $templateVariables"
    elif template_variables is not None:
        raise UnsupportedError(
            "server does not support template variables, please update server instance to >=0.46"
        )

    variable_values = {
        "entityName": entity,
        "resource": resource,
        "config": config,
    }
    if supports_template_vars:
        if template_variables is not None:
            variable_values["templateVariables"] = json.dumps(template_variables)
        else:
            variable_values["templateVariables"] = "{}"

    query = gql(
        f"""
    mutation createDefaultResourceConfig(
        {mutation_params}
    ) {{
        createDefaultResourceConfig(
            input: {{
                {mutation_inputs}
            }}
        ) {{
            defaultResourceConfigID
            success
        }}
    }}
    """
    )

    result: Optional[Dict[str, Any]] = self.gql(query, variable_values)[
        "createDefaultResourceConfig"
    ]
    return result
1485
+
1486
+ @normalize_exceptions
1487
def create_run_queue(
    self,
    entity: str,
    project: str,
    queue_name: str,
    access: str,
    prioritization_mode: Optional[str] = None,
    config_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Run the ``createRunQueue`` mutation.

    ``prioritizationMode`` is only declared and sent when the server's input
    type supports it.

    Returns:
        The ``createRunQueue`` entry of the response.

    Raises:
        UnsupportedError: If the server lacks the mutation, or if
            ``config_id`` / ``prioritization_mode`` is given but the server
            does not support the corresponding input field.
    """
    (
        can_create,
        supports_drc,
        supports_prioritization,
    ) = self.create_run_queue_introspection()

    if not can_create:
        raise UnsupportedError(
            "run queue creation is not supported by this version of "
            "wandb server. Consider updating to the latest version."
        )
    if config_id is not None and not supports_drc:
        raise UnsupportedError(
            "default resource configurations are not supported by this version "
            "of wandb server. Consider updating to the latest version."
        )
    if prioritization_mode is not None and not supports_prioritization:
        raise UnsupportedError(
            "launch prioritization is not supported by this version of "
            "wandb server. Consider updating to the latest version."
        )

    variables = {
        "entity": entity,
        "project": project,
        "queueName": queue_name,
        "access": access,
    }
    if supports_prioritization:
        variables["prioritizationMode"] = prioritization_mode
        mutation_text = """
        mutation createRunQueue(
            $entity: String!,
            $project: String!,
            $queueName: String!,
            $access: RunQueueAccessType!,
            $prioritizationMode: RunQueuePrioritizationMode,
            $defaultResourceConfigID: ID,
        ) {
            createRunQueue(
                input: {
                    entityName: $entity,
                    projectName: $project,
                    queueName: $queueName,
                    access: $access,
                    prioritizationMode: $prioritizationMode
                    defaultResourceConfigID: $defaultResourceConfigID
                }
            ) {
                success
                queueID
            }
        }
        """
    else:
        mutation_text = """
        mutation createRunQueue(
            $entity: String!,
            $project: String!,
            $queueName: String!,
            $access: RunQueueAccessType!,
            $defaultResourceConfigID: ID,
        ) {
            createRunQueue(
                input: {
                    entityName: $entity,
                    projectName: $project,
                    queueName: $queueName,
                    access: $access,
                    defaultResourceConfigID: $defaultResourceConfigID
                }
            ) {
                success
                queueID
            }
        }
        """
    variables["defaultResourceConfigID"] = config_id

    response = self.gql(gql(mutation_text), variables)
    created: Optional[Dict[str, Any]] = response["createRunQueue"]
    return created
1589
+
1590
+ @normalize_exceptions
1591
def upsert_run_queue(
    self,
    queue_name: str,
    entity: str,
    resource_type: str,
    resource_config: dict,
    project: str = LAUNCH_DEFAULT_PROJECT,
    prioritization_mode: Optional[str] = None,
    template_variables: Optional[dict] = None,
    external_links: Optional[dict] = None,
) -> Optional[Dict[str, Any]]:
    """Run the ``upsertRunQueue`` mutation.

    ``resource_config`` is always JSON-encoded; ``template_variables`` and
    ``external_links`` are JSON-encoded when truthy and sent as None
    otherwise.

    Returns:
        The ``upsertRunQueue`` entry of the response.

    Raises:
        UnsupportedError: If the server lacks the mutation.
    """
    if not self.upsert_run_queue_introspection():
        raise UnsupportedError(
            "upserting run queues is not supported by this version of "
            "wandb server. Consider updating to the latest version."
        )

    upsert_mutation = gql(
        """
        mutation upsertRunQueue(
            $entityName: String!
            $projectName: String!
            $queueName: String!
            $resourceType: String!
            $resourceConfig: JSONString!
            $templateVariables: JSONString
            $prioritizationMode: RunQueuePrioritizationMode
            $externalLinks: JSONString
            $clientMutationId: String
        ) {
            upsertRunQueue(
                input: {
                    entityName: $entityName
                    projectName: $projectName
                    queueName: $queueName
                    resourceType: $resourceType
                    resourceConfig: $resourceConfig
                    templateVariables: $templateVariables
                    prioritizationMode: $prioritizationMode
                    externalLinks: $externalLinks
                    clientMutationId: $clientMutationId
                }
            ) {
                success
                configSchemaValidationErrors
            }
        }
        """
    )

    encoded_config = json.dumps(resource_config)
    encoded_template_vars = (
        json.dumps(template_variables) if template_variables else None
    )
    encoded_links = json.dumps(external_links) if external_links else None

    variables = {
        "entityName": entity,
        "projectName": project,
        "queueName": queue_name,
        "resourceType": resource_type,
        "resourceConfig": encoded_config,
        "templateVariables": encoded_template_vars,
        "prioritizationMode": prioritization_mode,
        "externalLinks": encoded_links,
    }
    response: Dict[str, Any] = self.gql(upsert_mutation, variables)
    return response["upsertRunQueue"]
1653
+
1654
@normalize_exceptions
def push_to_run_queue_by_name(
    self,
    entity: str,
    project: str,
    queue_name: str,
    run_spec: str,
    template_variables: Optional[Dict[str, Union[int, float, str]]],
    priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Queryless mutation, should be used before legacy fallback method.

    Pushes a run spec onto a queue addressed by entity/project/queue name
    using the ``pushToRunQueueByName`` mutation.

    Arguments:
        entity (str): Entity that owns the queue
        project (str): Project that holds the queue
        queue_name (str): Name of the queue to push to
        run_spec (str): JSON-encoded run spec
        template_variables (dict, optional): Values sent JSON-encoded as
            ``templateVariableValues``
        priority (int, optional): Priority sent with the queue item

    Returns:
        The ``pushToRunQueueByName`` payload. When the payload carries a
        ``runSpec`` it is decoded from JSON. Returns None when the payload
        is empty or the request fails.

    Raises:
        UnsupportedError: if ``priority`` or ``template_variables`` is given
            but the server capability flags say it is unsupported.
    """
    self.push_to_run_queue_introspection()

    mutation_params = """
        $entityName: String!,
        $projectName: String!,
        $queueName: String!,
        $runSpec: JSONString!
    """

    mutation_input = """
        entityName: $entityName,
        projectName: $projectName,
        queueName: $queueName,
        runSpec: $runSpec
    """

    variables: Dict[str, Any] = {
        "entityName": entity,
        "projectName": project,
        "queueName": queue_name,
        "runSpec": run_spec,
    }

    # Optional arguments are only declared in the mutation when they are
    # actually supplied, so the document stays valid on older servers.
    if self.server_push_to_run_queue_supports_priority:
        if priority is not None:
            variables["priority"] = priority
            mutation_params += ", $priority: Int"
            mutation_input += ", priority: $priority"
    else:
        if priority is not None:
            raise UnsupportedError(
                "server does not support priority, please update server instance to >=0.46"
            )

    if self.server_supports_template_variables:
        if template_variables is not None:
            variables.update(
                {"templateVariableValues": json.dumps(template_variables)}
            )
            mutation_params += ", $templateVariableValues: JSONString"
            mutation_input += ", templateVariableValues: $templateVariableValues"
    else:
        if template_variables is not None:
            raise UnsupportedError(
                "server does not support template variables, please update server instance to >=0.46"
            )

    mutation = gql(
        f"""
        mutation pushToRunQueueByName(
            {mutation_params}
        ) {{
            pushToRunQueueByName(
                input: {{
                    {mutation_input}
                }}
            ) {{
                runQueueItemId
                runSpec
            }}
        }}
        """
    )

    try:
        result: Optional[Dict[str, Any]] = self.gql(
            mutation, variables, check_retry_fn=util.no_retry_4xx
        ).get("pushToRunQueueByName")
        if not result:
            return None

        if result.get("runSpec"):
            # The server returns the spec as a JSON string; hand callers
            # the decoded object instead.
            decoded_spec = json.loads(str(result["runSpec"]))
            result["runSpec"] = decoded_spec

        return result
    except Exception as e:
        # Deliberate best-effort: any failure other than the server not
        # knowing the ``runSpec`` payload field yields None so the caller
        # can decide whether to fall back to the legacy method.
        if (
            'Cannot query field "runSpec" on type "PushToRunQueueByNamePayload"'
            not in str(e)
        ):
            return None

    # The server rejected the ``runSpec`` payload field: retry with a
    # mutation that only asks for the queue item id.
    mutation_no_runspec = gql(
        """
        mutation pushToRunQueueByName(
            $entityName: String!,
            $projectName: String!,
            $queueName: String!,
            $runSpec: JSONString!,
        ) {
            pushToRunQueueByName(
                input: {
                    entityName: $entityName,
                    projectName: $projectName,
                    queueName: $queueName,
                    runSpec: $runSpec
                }
            ) {
                runQueueItemId
            }
        }
        """
    )

    try:
        result = self.gql(
            mutation_no_runspec, variables, check_retry_fn=util.no_retry_4xx
        ).get("pushToRunQueueByName")
    except Exception:
        result = None

    return result
1777
+
1778
@normalize_exceptions
def push_to_run_queue(
    self,
    queue_name: str,
    launch_spec: Dict[str, str],
    template_variables: Optional[dict],
    project_queue: str,
    priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Push a launch spec onto a run queue.

    The by-name mutation is tried first. If it yields nothing and no
    priority was requested, the legacy path resolves the queue id (creating
    the "default" queue when it is missing) and pushes by id.

    Arguments:
        queue_name (str): Name of the queue to push to
        launch_spec (dict): Launch spec; must contain "entity" unless
            "queue_entity" is set
        template_variables (dict, optional): Template variable values
        project_queue (str): Project that holds the queue
        priority (int, optional): Priority of the queue item

    Returns:
        The push payload, or None when the item could not be queued.

    Raises:
        UnsupportedError: if template variables are given but the server
            does not support them.
        CommError: if the legacy mutation returns an empty payload.
    """
    self.push_to_run_queue_introspection()
    entity = launch_spec.get("queue_entity") or launch_spec["entity"]
    run_spec = json.dumps(launch_spec)

    by_name_result = self.push_to_run_queue_by_name(
        entity, project_queue, queue_name, run_spec, template_variables, priority
    )
    if by_name_result:
        return by_name_result

    if priority is not None:
        # Cannot proceed with legacy method if priority is set
        return None

    # Legacy Method
    candidates = []
    for queue in self.get_project_run_queues(entity, project_queue):
        if queue["name"] != queue_name:
            continue
        # ensure user has access to queue
        # TODO: User created queues in the UI have USER access
        if (
            queue["access"] in ["PROJECT", "USER"]
            or queue["createdBy"] == self.default_entity
        ):
            candidates.append(queue)

    if len(candidates) > 1:
        wandb.termerror(
            f"Unable to push to run queue {queue_name}. More than one queue found with this name."
        )
        return None

    if candidates:
        queue_id = candidates[0]["id"]
    elif queue_name != "default":
        if project_queue == "model-registry":
            _msg = f"Unable to push to run queue {queue_name}. Queue not found."
        else:
            _msg = f"Unable to push to run queue {project_queue}/{queue_name}. Queue not found."
        wandb.termwarn(_msg)
        return None
    else:
        # in the case of a missing default queue. create it
        wandb.termlog(
            f"No default queue existing for entity: {entity} in project: {project_queue}, creating one."
        )
        created = self.create_run_queue(
            launch_spec["entity"],
            project_queue,
            queue_name,
            access="PROJECT",
        )
        if created is None or created.get("queueID") is None:
            wandb.termerror(
                f"Unable to create default queue for entity: {entity} on project: {project_queue}. Run could not be added to a queue"
            )
            return None
        queue_id = created["queueID"]

    spec_json = json.dumps(launch_spec)
    variables = {"queueID": queue_id, "runSpec": spec_json}

    mutation_params = """
        $queueID: ID!,
        $runSpec: JSONString!
    """
    mutation_input = """
        queueID: $queueID,
        runSpec: $runSpec
    """

    if self.server_supports_template_variables:
        if template_variables is not None:
            variables["templateVariableValues"] = json.dumps(template_variables)
            mutation_params += ", $templateVariableValues: JSONString"
            mutation_input += ", templateVariableValues: $templateVariableValues"
    elif template_variables is not None:
        raise UnsupportedError(
            "server does not support template variables, please update server instance to >=0.46"
        )

    mutation = gql(
        f"""
        mutation pushToRunQueue(
            {mutation_params}
        ) {{
            pushToRunQueue(
                input: {{{mutation_input}}}
            ) {{
                runQueueItemId
            }}
        }}
        """
    )

    response = self.gql(mutation, variable_values=variables)
    if not response.get("pushToRunQueue"):
        raise CommError(f"Error pushing run queue item to queue {queue_name}.")

    pushed: Optional[Dict[str, Any]] = response["pushToRunQueue"]
    return pushed
1893
+
1894
@normalize_exceptions
def pop_from_run_queue(
    self,
    queue_name: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    agent_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Pop the next item from a run queue.

    Arguments:
        queue_name (str): Name of the queue to pop from
        entity (str, optional): Entity that owns the queue
        project (str, optional): Project that holds the queue
        agent_id (str, optional): Launch agent id sent as ``launchAgentId``

    Returns:
        The ``popFromRunQueue`` payload (``runQueueItemId`` and ``runSpec``
        are requested), passed through as the server returned it.
    """
    variables = {
        "entity": entity,
        "project": project,
        "queueName": queue_name,
        "launchAgentId": agent_id,
    }
    mutation = gql(
        """
        mutation popFromRunQueue($entity: String!, $project: String!, $queueName: String!, $launchAgentId: ID) {
            popFromRunQueue(input: {
                entityName: $entity,
                projectName: $project,
                queueName: $queueName,
                launchAgentId: $launchAgentId
            }) {
                runQueueItemId
                runSpec
            }
        }
        """
    )
    response = self.gql(mutation, variable_values=variables)
    popped: Optional[Dict[str, Any]] = response["popFromRunQueue"]
    return popped
1928
+
1929
@normalize_exceptions
def ack_run_queue_item(self, item_id: str, run_id: Optional[str] = None) -> bool:
    """Acknowledge a run queue item.

    Arguments:
        item_id (str): Id of the run queue item to acknowledge
        run_id (str, optional): Run id, sent as ``runName`` after being
            passed through ``str()``

    Returns:
        The ``success`` flag of the ``ackRunQueueItem`` payload.

    Raises:
        CommError: if the server reports the acknowledgement as
            unsuccessful.
    """
    mutation = gql(
        """
        mutation ackRunQueueItem($itemId: ID!, $runId: String!) {
            ackRunQueueItem(input: { runQueueItemId: $itemId, runName: $runId }) {
                success
            }
        }
        """
    )
    variables = {"itemId": item_id, "runId": str(run_id)}
    response = self.gql(mutation, variable_values=variables)

    acknowledged: bool = response["ackRunQueueItem"]["success"]
    if acknowledged:
        return acknowledged
    raise CommError(
        "Error acking run queue item. Item may have already been acknowledged by another process"
    )
1949
+
1950
@normalize_exceptions
def create_launch_agent_fields_introspection(self) -> List:
    """Return the input field names of the server's ``CreateLaunchAgentInput``.

    The result is stored on ``self.create_launch_agent_input_info`` and
    reused on later calls while it is non-empty.

    Returns:
        A list of field names. When the server does not know the type, or
        reports no ``inputFields``, the list holds a single empty string.
    """
    if self.create_launch_agent_input_info:
        return self.create_launch_agent_input_info

    query_string = """
        query ProbeServerCreateLaunchAgentInput {
            CreateLaunchAgentInputInfoType: __type(name:"CreateLaunchAgentInput") {
                inputFields{
                    name
                }
            }
        }
    """

    query = gql(query_string)
    res = self.gql(query)

    # ``__type`` resolves to null for a type the server does not define, so
    # the key can be present with a None value; ``dict.get`` with a default
    # would not cover that case.
    type_info = res.get("CreateLaunchAgentInputInfoType") or {}
    input_fields = type_info.get("inputFields")
    if input_fields is None:
        input_fields = [{}]

    self.create_launch_agent_input_info = [
        field.get("name", "") for field in input_fields
    ]
    return self.create_launch_agent_input_info
1974
+
1975
@normalize_exceptions
def create_launch_agent(
    self,
    entity: str,
    project: str,
    queues: List[str],
    agent_config: Dict[str, Any],
    version: str,
    gorilla_agent_support: bool,
) -> dict:
    """Register a launch agent that polls the given queues.

    If the project has no queues, a "default" queue is created first.

    Arguments:
        entity (str): Entity the agent belongs to
        project (str): Project whose queues the agent polls
        queues (list): Names of the queues to poll
        agent_config (dict): Agent configuration, sent JSON-encoded when the
            server accepts an ``agentConfig`` input field
        version (str): Version string, sent when the server accepts a
            ``version`` input field
        gorilla_agent_support (bool): Whether the server supports launch
            agents

    Returns:
        The ``createLaunchAgent`` payload, or
        ``{"success": True, "launchAgentId": None}`` when
        ``gorilla_agent_support`` is false.

    Raises:
        CommError: if the default queue cannot be created, or if not every
            requested queue name is found in the project.
    """
    project_queues = self.get_project_run_queues(entity, project)
    if not project_queues:
        # create default queue if it doesn't already exist
        default = self.create_run_queue(entity, project, "default", access="PROJECT")
        if default is None or default.get("queueID") is None:
            raise CommError(
                f"Unable to create default queue for {entity}/{project}. No queues for agent to poll"
            )
        project_queues = [{"id": default["queueID"], "name": "default"}]

    # filter to poll specified queues
    polling_queue_ids = []
    for queue in project_queues:
        if queue["name"] in queues:
            polling_queue_ids.append(queue["id"])

    if len(polling_queue_ids) != len(queues):
        requested = ", ".join(queues)
        available = ",".join([q["name"] for q in project_queues])
        raise CommError(
            f"Could not start launch agent: Not all of requested queues ({requested}) found. "
            f"Available queues for this project: {available}"
        )

    if not gorilla_agent_support:
        # if gorilla doesn't support launch agents, no agent is registered
        # server-side and the returned id is None
        return {
            "success": True,
            "launchAgentId": None,
        }

    variable_values = {
        "entity": entity,
        "project": project,
        "queues": polling_queue_ids,
        "hostname": socket.gethostname(),
    }

    mutation_params = """
        $entity: String!,
        $project: String!,
        $queues: [ID!]!,
        $hostname: String!
    """

    mutation_input = """
        entityName: $entity,
        projectName: $project,
        runQueues: $queues,
        hostname: $hostname
    """

    # Optional input fields are only sent when introspection lists them.
    if "agentConfig" in self.create_launch_agent_fields_introspection():
        mutation_input += ", agentConfig: $agentConfig"
        mutation_params += ", $agentConfig: JSONString"
        variable_values["agentConfig"] = json.dumps(agent_config)
    if "version" in self.create_launch_agent_fields_introspection():
        mutation_input += ", version: $version"
        mutation_params += ", $version: String"
        variable_values["version"] = version

    mutation = gql(
        f"""
        mutation createLaunchAgent(
            {mutation_params}
        ) {{
            createLaunchAgent(
                input: {{
                    {mutation_input}
                }}
            ) {{
                launchAgentId
            }}
        }}
        """
    )
    response = self.gql(mutation, variable_values)
    created: dict = response["createLaunchAgent"]
    return created
2063
+
2064
@normalize_exceptions
def update_launch_agent_status(
    self,
    agent_id: str,
    status: str,
    gorilla_agent_support: bool,
) -> dict:
    """Report a launch agent's status to the server.

    Arguments:
        agent_id (str): Id of the launch agent
        status (str): Status string sent as ``agentStatus``
        gorilla_agent_support (bool): Whether the server supports launch
            agents

    Returns:
        The ``updateLaunchAgent`` payload, or ``{"success": True}`` without
        contacting the server when ``gorilla_agent_support`` is false.
    """
    if not gorilla_agent_support:
        # if gorilla doesn't support launch agents, this is a no-op
        return {"success": True}

    mutation = gql(
        """
        mutation updateLaunchAgent($agentId: ID!, $agentStatus: String){
            updateLaunchAgent(
                input: {
                    launchAgentId: $agentId
                    agentStatus: $agentStatus
                }
            ) {
                success
            }
        }
        """
    )
    response = self.gql(mutation, {"agentId": agent_id, "agentStatus": status})
    updated: dict = response["updateLaunchAgent"]
    return updated
2097
+
2098
@normalize_exceptions
def get_launch_agent(self, agent_id: str, gorilla_agent_support: bool) -> dict:
    """Fetch a launch agent record by id.

    Arguments:
        agent_id (str): Id of the launch agent
        gorilla_agent_support (bool): Whether the server supports launch
            agents

    Returns:
        The ``launchAgent`` object from the server. When
        ``gorilla_agent_support`` is false, a stub with ``id`` None, an
        empty ``name`` and ``stopPolling`` False is returned instead.
    """
    if not gorilla_agent_support:
        stub = {"id": None, "name": "", "stopPolling": False}
        return stub

    query = gql(
        """
        query LaunchAgent($agentId: ID!) {
            launchAgent(id: $agentId) {
                id
                name
                runQueues
                hostname
                agentStatus
                stopPolling
                heartbeatAt
            }
        }
        """
    )
    response = self.gql(query, {"agentId": agent_id})
    agent: dict = response["launchAgent"]
    return agent
2126
+
2127
@normalize_exceptions
def upsert_run(
    self,
    id: Optional[str] = None,
    name: Optional[str] = None,
    project: Optional[str] = None,
    host: Optional[str] = None,
    group: Optional[str] = None,
    tags: Optional[List[str]] = None,
    config: Optional[dict] = None,
    description: Optional[str] = None,
    entity: Optional[str] = None,
    state: Optional[str] = None,
    display_name: Optional[str] = None,
    notes: Optional[str] = None,
    repo: Optional[str] = None,
    job_type: Optional[str] = None,
    program_path: Optional[str] = None,
    commit: Optional[str] = None,
    sweep_name: Optional[str] = None,
    summary_metrics: Optional[str] = None,
    num_retries: Optional[int] = None,
) -> Tuple[dict, bool, Optional[List]]:
    """Create or update a run through the ``upsertBucket`` mutation.

    After the call, the "project" and "entity" settings are updated from the
    project and entity names the server returned, when present.

    Arguments:
        id (str, optional): The existing run to update
        name (str, optional): The name of the run to create
        project (str, optional): The name of the project; falls back to a
            name derived from ``program_path``
        host (str, optional): The name of the host; sent as null when the
            "anonymous" setting is "true"
        group (str, optional): Name of the group this run is a part of
        tags (list, optional): A list of tags to apply to the run
        config (dict, optional): The latest config params, sent JSON-encoded
        description (str, optional): A description; empty or whitespace-only
            values are sent as null
        entity (str, optional): The entity to scope this project to; falls
            back to the "entity" setting
        state (str, optional): State of the program
        display_name (str, optional): The display name
        notes (str, optional): Notes about this run
        repo (str, optional): Url of the program's repository
        job_type (str, optional): Type of job, e.g 'train'
        program_path (str, optional): Path to the program
        commit (str, optional): The Git SHA to associate the run with
        sweep_name (str, optional): The name of the sweep this run is a
            part of
        summary_metrics (str, optional): The JSON summary metrics
        num_retries (int, optional): Number of retries, forwarded to the
            request only when given

    Returns:
        A tuple ``(bucket, inserted, server_messages)`` where
        ``server_messages`` is None unless server settings are supported.
    """
    query_string = """
    mutation UpsertBucket(
        $id: String,
        $name: String,
        $project: String,
        $entity: String,
        $groupName: String,
        $description: String,
        $displayName: String,
        $notes: String,
        $commit: String,
        $config: JSONString,
        $host: String,
        $debug: Boolean,
        $program: String,
        $repo: String,
        $jobType: String,
        $state: String,
        $sweep: String,
        $tags: [String!],
        $summaryMetrics: JSONString,
    ) {
        upsertBucket(input: {
            id: $id,
            name: $name,
            groupName: $groupName,
            modelName: $project,
            entityName: $entity,
            description: $description,
            displayName: $displayName,
            notes: $notes,
            config: $config,
            commit: $commit,
            host: $host,
            debug: $debug,
            jobProgram: $program,
            jobRepo: $repo,
            jobType: $jobType,
            state: $state,
            sweep: $sweep,
            tags: $tags,
            summaryMetrics: $summaryMetrics,
        }) {
            bucket {
                id
                name
                displayName
                description
                config
                sweepName
                project {
                    id
                    name
                    entity {
                        id
                        name
                    }
                }
                historyLineCount
            }
            inserted
            _Server_Settings_
        }
    }
    """
    self.server_settings_introspection()

    # Only ask for server messages when introspection found the type.
    if self._server_settings_type:
        server_settings_string = """
        serverSettings {
                serverMessages{
                    utfText
                    plainText
                    htmlText
                    messageType
                    messageLevel
                }
        }
        """
    else:
        server_settings_string = ""

    mutation = gql(query_string.replace("_Server_Settings_", server_settings_string))

    config_str = None
    if config:
        config_str = json.dumps(config)
    if not description or description.isspace():
        description = None

    kwargs = {}
    if num_retries is not None:
        kwargs["num_retries"] = num_retries

    variable_values = {
        "id": id,
        "entity": entity or self.settings("entity"),
        "name": name,
        "project": project or util.auto_project_name(program_path),
        "groupName": group,
        "tags": tags,
        "description": description,
        "config": config_str,
        "commit": commit,
        "displayName": display_name,
        "notes": notes,
        "host": None if self.settings().get("anonymous") == "true" else host,
        "debug": env.is_debug(env=self._environ),
        "repo": repo,
        "program": program_path,
        "jobType": job_type,
        "state": state,
        "sweep": sweep_name,
        "summaryMetrics": summary_metrics,
    }

    # retry conflict errors for 2 minutes, default to no_auth_retry
    check_retry_fn = util.make_check_retry_fn(
        check_fn=util.check_retry_conflict_or_gone,
        check_timedelta=datetime.timedelta(minutes=2),
        fallback_retry_fn=util.no_retry_auth,
    )

    response = self.gql(
        mutation,
        variable_values=variable_values,
        check_retry_fn=check_retry_fn,
        **kwargs,
    )
    payload = response["upsertBucket"]

    run_obj: Dict[str, Dict[str, Dict[str, str]]] = payload["bucket"]
    project_obj: Dict[str, Dict[str, str]] = run_obj.get("project", {})
    if project_obj:
        # Remember where the server actually placed the run.
        self.set_setting("project", project_obj["name"])
        entity_obj = project_obj.get("entity", {})
        if entity_obj:
            self.set_setting("entity", entity_obj["name"])

    server_messages = None
    if self._server_settings_type:
        server_settings = payload.get("serverSettings", {})
        server_messages = server_settings.get("serverMessages", [])

    return payload["bucket"], payload["inserted"], server_messages
2325
+
2326
@normalize_exceptions
def rewind_run(
    self,
    run_name: str,
    metric_name: str,
    metric_value: float,
    program_path: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    num_retries: Optional[int] = None,
) -> dict:
    """Rewind a run to an earlier state through the ``rewindRun`` mutation.

    After the call, the "project" and "entity" settings are updated from the
    project and entity names the server returned, when present.

    Arguments:
        run_name (str): The name of the run to rewind
        metric_name (str): The name of the metric to rewind to
        metric_value (float): The value of the metric to rewind to
        program_path (str, optional): Path to the program, used to derive a
            project name when ``project`` is not given
        entity (str, optional): The entity to scope this project to; falls
            back to the "entity" setting
        project (str, optional): The name of the project
        num_retries (int, optional): Number of retries, forwarded to the
            request only when given

    Returns:
        The ``rewoundRun`` object, or an empty dict when the response has
        no such key:

            {
                "id": "run_id",
                "name": "run_name",
                "displayName": "run_display_name",
                "description": "run_description",
                "config": "stringified_run_config_json",
                "sweepName": "run_sweep_name",
                "project": {
                    "id": "project_id",
                    "name": "project_name",
                    "entity": {
                        "id": "entity_id",
                        "name": "entity_name"
                    }
                },
                "historyLineCount": 100,
            }
    """
    mutation = gql(
        """
    mutation RewindRun($runName: String!, $entity: String, $project: String, $metricName: String!, $metricValue: Float!) {
        rewindRun(input: {runName: $runName, entityName: $entity, projectName: $project, metricName: $metricName, metricValue: $metricValue}) {
            rewoundRun {
                id
                name
                displayName
                description
                config
                sweepName
                project {
                    id
                    name
                    entity {
                        id
                        name
                    }
                }
                historyLineCount
            }
        }
    }
    """
    )

    kwargs = {}
    if num_retries is not None:
        kwargs["num_retries"] = num_retries

    variable_values = {
        "runName": run_name,
        "entity": entity or self.settings("entity"),
        "project": project or util.auto_project_name(program_path),
        "metricName": metric_name,
        "metricValue": metric_value,
    }

    # retry conflict errors for 2 minutes, default to no_auth_retry
    check_retry_fn = util.make_check_retry_fn(
        check_fn=util.check_retry_conflict_or_gone,
        check_timedelta=datetime.timedelta(minutes=2),
        fallback_retry_fn=util.no_retry_auth,
    )

    response = self.gql(
        mutation,
        variable_values=variable_values,
        check_retry_fn=check_retry_fn,
        **kwargs,
    )

    rewind_payload = response.get("rewindRun", {})
    run_obj: Dict[str, Dict[str, Dict[str, str]]] = rewind_payload.get(
        "rewoundRun", {}
    )
    project_obj: Dict[str, Dict[str, str]] = run_obj.get("project", {})
    if project_obj:
        # Remember where the server says the run lives.
        self.set_setting("project", project_obj["name"])
        entity_obj = project_obj.get("entity", {})
        if entity_obj:
            self.set_setting("entity", entity_obj["name"])

    return run_obj
2432
+
2433
@normalize_exceptions
def get_run_info(
    self,
    entity: str,
    project: str,
    name: str,
) -> dict:
    """Fetch the ``runInfo`` record of a run.

    Arguments:
        entity (str): Entity that owns the project
        project (str): Project that holds the run
        name (str): Name of the run

    Returns:
        The ``runInfo`` object (program, args, os, python, colab,
        executable, codeSaved, cpuCount, gpuCount, gpu and git are
        requested).

    Raises:
        CommError: if the project or the run is missing from the response.
    """
    query = gql(
        """
        query RunInfo($project: String!, $entity: String!, $name: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $name) {
                    runInfo {
                        program
                        args
                        os
                        python
                        colab
                        executable
                        codeSaved
                        cpuCount
                        gpuCount
                        gpu
                        git {
                            remote
                            commit
                        }
                    }
                }
            }
        }
        """
    )
    res = self.gql(query, {"project": project, "entity": entity, "name": name})

    project_data = res.get("project")
    if project_data is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this project exists and you have access to this entity and project"
        )

    run_data = project_data.get("run")
    if run_data is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this run id exists"
        )

    run_info: dict = run_data["runInfo"]
    return run_info
2482
+
2483
@normalize_exceptions
def get_run_state(self, entity: str, project: str, name: str) -> str:
    """Fetch the ``state`` field of a run.

    Arguments:
        entity (str): Entity that owns the project
        project (str): Project that holds the run
        name (str): Name of the run

    Returns:
        The run's ``state`` value as returned by the server.

    Raises:
        CommError: if the project or the run is missing from the response.
    """
    query = gql(
        """
        query RunState(
            $project: String!,
            $entity: String!,
            $name: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $name) {
                    state
                }
            }
        }
        """
    )
    res = self.gql(query, {"project": project, "entity": entity, "name": name})

    project_data = res.get("project")
    run_data = None if project_data is None else project_data.get("run")
    if run_data is None:
        raise CommError(f"Error fetching run state for {entity}/{project}/{name}.")

    run_state: str = run_data["state"]
    return run_state
2509
+
2510
@normalize_exceptions
def create_run_files_introspection(self) -> bool:
    """Report whether the server exposes the ``createRunFiles`` mutation.

    Returns:
        True when "createRunFiles" is among the mutations returned by
        ``server_info_introspection``.
    """
    # Only the third element (the mutations) of the introspection result
    # is needed.
    available_mutations = self.server_info_introspection()[2]
    return "createRunFiles" in available_mutations
2514
+
2515
@normalize_exceptions
def upload_urls(
    self,
    project: str,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    description: Optional[str] = None,
) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
    """Generate temporary resumable upload urls.

    Uses the ``createRunFiles`` mutation when the server has it, otherwise
    delegates to ``legacy_upload_urls``.

    Arguments:
        project (str): The project to upload to
        files (list or dict): The filenames to upload
        run (str, optional): The run to upload to; falls back to the
            current run id
        entity (str, optional): The entity to scope this project to; falls
            back to the "entity" setting
        description (str, optional): description, only used by the legacy
            path

    Returns:
        (run_id, upload_headers, file_info)
            run_id: id of run we uploaded files to
            upload_headers: A list of headers to use when uploading files.
            file_info: A dict mapping each filename to its file record.
                {
                    "weights.h5": { "name": "weights.h5", "uploadUrl": "https://weights.url" },
                    "model.json": { "name": "model.json", "uploadUrl": "https://model.json" }
                }

    Raises:
        CommError: if the server returns no run id.
    """
    run_name = run or self.current_run_id
    assert run_name, "run must be specified"
    entity = entity or self.settings("entity")
    assert entity, "entity must be specified"

    if not self.create_run_files_introspection():
        # Server does not have createRunFiles: use the legacy query.
        return self.legacy_upload_urls(project, files, run, entity, description)

    query = gql(
        """
        mutation CreateRunFiles($entity: String!, $project: String!, $run: String!, $files: [String!]!) {
            createRunFiles(input: {entityName: $entity, projectName: $project, runName: $run, files: $files}) {
                runID
                uploadHeaders
                files {
                    name
                    uploadUrl
                }
            }
        }
        """
    )

    variables = {
        "project": project,
        "run": run_name,
        "entity": entity,
        # Iterating a dict yields its keys, so both accepted shapes reduce
        # to a list of filenames.
        "files": list(files),
    }
    query_result = self.gql(query, variable_values=variables)

    result = query_result["createRunFiles"]
    run_id = result["runID"]
    if not run_id:
        raise CommError(
            f"Error uploading files to {entity}/{project}/{run_name}. Check that this project exists and you have access to this entity and project"
        )

    file_name_urls = {}
    for file_record in result["files"]:
        file_name_urls[file_record["name"]] = file_record
    return run_id, result["uploadHeaders"], file_name_urls
2589
+
2590
def legacy_upload_urls(
    self,
    project: str,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    description: Optional[str] = None,
) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
    """Generate temporary resumable upload urls.

    A new mutation createRunFiles was introduced after 0.15.4.
    This function is used to support older versions.

    Arguments:
        project (str): The project to upload to
        files (list or dict): The filenames to upload
        run (str, optional): The run to upload to; falls back to the
            current run id
        entity (str, optional): The entity to scope this project to; falls
            back to the "entity" setting
        description (str, optional): Sent as the bucket ``desc`` argument

    Returns:
        (run_id, upload_headers, file_info), where file_info maps each
        filename to its file record.

    Raises:
        CommError: if the response holds no run for the given name.
    """
    query = gql(
        """
        query RunUploadUrls($name: String!, $files: [String]!, $entity: String, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        uploadHeaders
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """
    )
    run_id = run or self.current_run_id
    assert run_id, "run must be specified"
    entity = entity or self.settings("entity")

    variables = {
        "name": project,
        "run": run_id,
        "entity": entity,
        "files": list(files),
        "description": description,
    }
    query_result = self.gql(query, variable_values=variables)

    run_obj = query_result["model"]["bucket"]
    if not run_obj:
        raise CommError(f"Run does not exist {entity}/{project}/{run_id}.")

    run_files = run_obj["files"]
    for edge in run_files["edges"]:
        node = edge["node"]
        # we previously used "url" field but now use "uploadUrl"
        # replace the "url" field with "uploadUrl for downstream compatibility
        if "url" in node and "uploadUrl" not in node:
            node["uploadUrl"] = node.pop("url")

    file_info = {}
    for record in self._flatten_edges(run_obj["files"]):
        file_info[record["name"]] = record
    return run_obj["id"], run_obj["files"]["uploadHeaders"], file_info
2653
+
2654
@normalize_exceptions
def download_urls(
    self,
    project: str,
    run: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Dict[str, str]]:
    """Generate download urls.

    Arguments:
        project (str): The project to download
        run (str): The run to download from; falls back to the current
            run id
        entity (str, optional): The entity to scope this project to; falls
            back to the "entity" setting

    Returns:
        A dict mapping each filename to its file record

            {
                'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
            }

    Raises:
        CommError: if the response holds no project or no run.
    """
    query = gql(
        """
        query RunDownloadUrls($name: String!, $entity: String, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """
    )
    run = run or self.current_run_id
    assert run, "run must be specified"
    entity = entity or self.settings("entity")
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run,
            "entity": entity,
        },
    )

    # Either level can come back as null; indexing into a null bucket
    # would fail with a TypeError instead of the intended error.
    model = query_result["model"]
    bucket = model.get("bucket") if model else None
    if not bucket:
        raise CommError(f"Run does not exist {entity}/{project}/{run}.")

    files = self._flatten_edges(bucket["files"])
    return {file["name"]: file for file in files if file}
2711
+
2712
@normalize_exceptions
def download_url(
    self,
    project: str,
    file_name: str,
    run: Optional[str] = None,
    entity: Optional[str] = None,
) -> Optional[Dict[str, str]]:
    """Generate the download url for a single file of a run.

    Arguments:
        project (str): The project to download from
        file_name (str): The name of the file to download
        run (str, optional): The run to download from. Defaults to the current run id
        entity (str, optional): The entity to scope this project to. Defaults to the
            "entity" setting

    Returns:
        The metadata dict of the file, or None when the project, the run or the
        file is not found (a file entry without "updatedAt" counts as not found)

            { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }

    """
    query = gql(
        """
    query RunDownloadUrl($name: String!, $fileName: String!, $entity: String, $run: String!) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run) {
                files(names: [$fileName]) {
                    edges {
                        node {
                            name
                            url
                            md5
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    """
    )
    run = run or self.current_run_id
    assert run, "run must be specified"
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run,
            "fileName": file_name,
            "entity": entity or self.settings("entity"),
        },
    )
    model = query_result["model"]
    if not model:
        return None
    # A null `bucket` means the run was not found; report that the same way
    # as a missing project instead of failing on the subscript below.
    bucket = model.get("bucket")
    if not bucket:
        return None
    files = self._flatten_edges(bucket["files"])
    return files[0] if files and files[0].get("updatedAt") else None
2770
+
2771
@normalize_exceptions
def download_file(self, url: str) -> Tuple[int, requests.Response]:
    """Initiate a streaming download.

    Arguments:
        url (str): The url to download

    Returns:
        A tuple of the content length (0 when the server sends no
        content-length header) and the streaming response

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    check_httpclient_logger_handler()

    # Work on a copy: adding the Authorization header below must not leak
    # into the thread-local header dict shared with other requests.
    http_headers = dict(_thread_local_api_settings.headers or {})

    auth = None
    if self.access_token is not None:
        http_headers["Authorization"] = f"Bearer {self.access_token}"
    elif _thread_local_api_settings.cookies is None:
        # No token and no cookies: fall back to basic auth with the api key.
        auth = ("api", self.api_key or "")

    response = requests.get(
        url,
        auth=auth,
        cookies=_thread_local_api_settings.cookies or {},
        headers=http_headers,
        stream=True,
    )
    response.raise_for_status()
    return int(response.headers.get("content-length", 0)), response
2800
+
2801
@normalize_exceptions
def download_write_file(
    self,
    metadata: Dict[str, str],
    out_dir: Optional[str] = None,
) -> Tuple[str, Optional[requests.Response]]:
    """Download a file from a run and write it to wandb/.

    Arguments:
        metadata (obj): The metadata object for the file to download. Comes from Api.download_urls().
        out_dir (str, optional): The directory to write the file to. Defaults to the
            "wandb_dir" setting

    Returns:
        A tuple of the file's local path and the streaming response. The streaming response is None if the file
        already existed and was up-to-date.
    """
    filename = metadata["name"]
    path = os.path.join(out_dir or self.settings("wandb_dir"), filename)
    # Compare against the destination path: that is where the file is written
    # and the path handed back to the caller, not `filename` relative to cwd.
    if self.file_current(path, B64MD5(metadata["md5"])):
        return path, None

    _, response = self.download_file(metadata["url"])

    with util.fsync_open(path, "wb") as file:
        for data in response.iter_content(chunk_size=1024):
            file.write(data)

    return path, response
2829
+
2830
def upload_file_azure(
    self, url: str, file: Any, extra_headers: Dict[str, str]
) -> None:
    """Upload a file to azure.

    Arguments:
        url: The blob url to upload to
        file: The data to upload; must support `len()`
        extra_headers: Headers for the upload; "Content-MD5" (base64) and
            "Content-Type" are forwarded as blob content settings

    Raises:
        requests.exceptions.RequestException: If azure answered with an error
            response; the status code and headers are copied onto `response`.
        requests.exceptions.ConnectionError: If the azure error carries no response.
    """
    from azure.core.exceptions import AzureError  # type: ignore

    # Configure the client without retries so our existing logic can handle them
    client = self._azure_blob_module.BlobClient.from_blob_url(
        url, retry_policy=self._azure_blob_module.LinearRetry(retry_total=0)
    )
    try:
        encoded_md5 = extra_headers.get("Content-MD5")
        md5: Optional[bytes] = (
            base64.b64decode(encoded_md5) if encoded_md5 is not None else None
        )
        content_settings = self._azure_blob_module.ContentSettings(
            content_md5=md5,
            content_type=extra_headers.get("Content-Type"),
        )
        client.upload_blob(
            file,
            max_concurrency=4,
            length=len(file),
            overwrite=True,
            content_settings=content_settings,
        )
    except AzureError as e:
        # The `response` attribute can exist but be None (no HTTP exchange
        # happened), so test the value rather than the attribute's presence.
        azure_response = getattr(e, "response", None)
        if azure_response is None:
            raise requests.exceptions.ConnectionError(e.message) from e
        response = requests.models.Response()
        response.status_code = azure_response.status_code
        response.headers = azure_response.headers
        raise requests.exceptions.RequestException(
            e.message, response=response
        ) from e
2864
+
2865
def upload_multipart_file_chunk(
    self,
    url: str,
    upload_chunk: bytes,
    extra_headers: Optional[Dict[str, str]] = None,
) -> Optional[requests.Response]:
    """PUT one chunk of a multipart upload, flagging retryable failures.

    Arguments:
        url: The url to upload the chunk to
        upload_chunk: The bytes of the chunk
        extra_headers: A dictionary of extra headers to send with the request

    Returns:
        The `requests` library response object

    Raises:
        retry.TransientError: For failures considered retryable (selected
            status codes, timeouts, connection errors, S3 "RequestTimeout").
    """
    check_httpclient_logger_handler()
    try:
        if env.is_debug(env=self._environ):
            logger.debug("upload_file: %s", url)
        response = self._upload_file_session.put(
            url, data=upload_chunk, headers=extra_headers
        )
        if env.is_debug(env=self._environ):
            logger.debug("upload_file: %s complete", url)
        response.raise_for_status()
    except requests.exceptions.RequestException as exc:
        logger.error(f"upload_file exception {url}: {exc}")
        sent_headers = "" if exc.request is None else exc.request.headers
        logger.error(f"upload_file request headers: {sent_headers!r}")
        if exc.response is None:
            body, status = "", 0
        else:
            body, status = exc.response.content, exc.response.status_code
        logger.error(f"upload_file response body: {body!r}")

        # S3 reports retryable request timeouts out-of-band
        s3_timeout = status == 400 and "RequestTimeout" in str(body)
        network_failure = isinstance(
            exc,
            (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
        )
        # Retry errors from cloud storage or local network issues
        retryable_status = status in (308, 408, 409, 429, 500, 502, 503, 504)
        if retryable_status or network_failure or s3_timeout:
            transient = retry.TransientError(exc=exc)
            raise transient.with_traceback(sys.exc_info()[2])
        wandb._sentry.reraise(exc)
    return response
2916
+
2917
def upload_file(
    self,
    url: str,
    file: IO[bytes],
    callback: Optional["ProgressFn"] = None,
    extra_headers: Optional[Dict[str, str]] = None,
) -> Optional[requests.Response]:
    """Upload an open file to W&B, flagging retryable failures.

    Arguments:
        url: The url to upload to
        file: The open binary file to upload
        callback: A callback which is passed the number of
            bytes uploaded since the last time it was called, used to report progress
        extra_headers: A dictionary of extra headers to send with the request

    Returns:
        The `requests` library response object, or None when the upload went
        through the azure SDK

    Raises:
        retry.TransientError: For failures considered retryable (selected
            status codes, timeouts, connection errors, S3 "RequestTimeout").
    """
    check_httpclient_logger_handler()
    extra_headers = extra_headers.copy() if extra_headers else {}
    response: Optional[requests.Response] = None
    progress = Progress(file, callback=callback)
    wants_azure = "x-ms-blob-type" in extra_headers
    try:
        if wants_azure and self._azure_blob_module:
            self.upload_file_azure(url, progress, extra_headers)
        else:
            if wants_azure:
                wandb.termwarn(
                    "Azure uploads over 256MB require the azure SDK, install with pip install wandb[azure]",
                    repeat=False,
                )
            if env.is_debug(env=self._environ):
                logger.debug("upload_file: %s", url)
            response = self._upload_file_session.put(
                url, data=progress, headers=extra_headers
            )
            if env.is_debug(env=self._environ):
                logger.debug("upload_file: %s complete", url)
            response.raise_for_status()
    except requests.exceptions.RequestException as exc:
        logger.error(f"upload_file exception {url}: {exc}")
        sent_headers = "" if exc.request is None else exc.request.headers
        logger.error(f"upload_file request headers: {sent_headers}")
        if exc.response is None:
            body, status = "", 0
        else:
            body, status = exc.response.content, exc.response.status_code
        logger.error(f"upload_file response body: {body!r}")

        # S3 reports retryable request timeouts out-of-band
        s3_timeout = (
            "x-amz-meta-md5" in extra_headers
            and status == 400
            and "RequestTimeout" in str(body)
        )
        # We need to rewind the file for the next retry (the file passed in is `seek`'ed to 0)
        progress.rewind()

        network_failure = isinstance(
            exc,
            (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
        )
        # Retry errors from cloud storage or local network issues
        retryable_status = status in (308, 408, 409, 429, 500, 502, 503, 504)
        if retryable_status or network_failure or s3_timeout:
            transient = retry.TransientError(exc=exc)
            raise transient.with_traceback(sys.exc_info()[2])
        wandb._sentry.reraise(exc)

    return response
2987
+
2988
@normalize_exceptions
def register_agent(
    self,
    host: str,
    sweep_id: Optional[str] = None,
    project_name: Optional[str] = None,
    entity: Optional[str] = None,
) -> dict:
    """Register a new sweep agent with the server.

    Arguments:
        host (str): hostname
        sweep_id (str): sweep id
        project_name: (str): project that contains the sweep. Defaults to the
            "project" setting
        entity: (str): entity that contains the sweep. Defaults to the
            "entity" setting

    Returns:
        The `agent` object of the `createAgent` response (it carries the agent id).
    """
    create_agent = gql(
        """
    mutation CreateAgent(
        $host: String!
        $projectName: String,
        $entityName: String,
        $sweep: String!
    ) {
        createAgent(input: {
            host: $host,
            projectName: $projectName,
            entityName: $entityName,
            sweep: $sweep,
        }) {
            agent {
                id
            }
        }
    }
    """
    )
    entity_name = self.settings("entity") if entity is None else entity
    project = self.settings("project") if project_name is None else project_name

    variables = {
        "host": host,
        "entityName": entity_name,
        "projectName": project,
        "sweep": sweep_id,
    }
    payload = self.gql(
        create_agent,
        variable_values=variables,
        check_retry_fn=util.no_retry_4xx,
    )
    agent: dict = payload["createAgent"]["agent"]
    return agent
3042
+
3043
def agent_heartbeat(
    self, agent_id: str, metrics: dict, run_states: dict
) -> List[Dict[str, Any]]:
    """Notify server about agent state, receive commands.

    Arguments:
        agent_id (str): agent_id
        metrics (dict): system metrics
        run_states (dict): run_id: state mapping
    Returns:
        List of commands to execute. An empty list is returned when the
        request fails; the failure is logged instead of raised.

    Raises:
        ValueError: If `agent_id` is None.
    """
    mutation = gql(
        """
    mutation Heartbeat(
        $id: ID!,
        $metrics: JSONString,
        $runState: JSONString
    ) {
        agentHeartbeat(input: {
            id: $id,
            metrics: $metrics,
            runState: $runState
        }) {
            agent {
                id
            }
            commands
        }
    }
    """
    )

    if agent_id is None:
        raise ValueError("Cannot call heartbeat with an unregistered agent.")

    try:
        response = self.gql(
            mutation,
            variable_values={
                "id": agent_id,
                "metrics": json.dumps(metrics),
                "runState": json.dumps(run_states),
            },
            timeout=60,
        )
    except Exception as e:
        # GQL raises exceptions with stringified python dictionaries :/
        try:
            message = ast.literal_eval(e.args[0])["message"]
        except (
            IndexError,
            KeyError,
            TypeError,
            ValueError,
            SyntaxError,
            MemoryError,
            RecursionError,
        ):
            # Not every failure carries such a payload; fall back to the
            # plain exception text so the handler itself never raises.
            message = str(e)
        logger.error("Error communicating with W&B: %s", message)
        return []
    else:
        result: List[Dict[str, Any]] = json.loads(
            response["agentHeartbeat"]["commands"]
        )
        return result
3099
+
3100
@staticmethod
def _validate_config_and_fill_distribution(config: dict) -> dict:
    """Return a copy of a sweep config with missing distributions filled in.

    Every parameter that has both "min" and "max" but no "distribution" gets
    "int_uniform" when both bounds are ints and "uniform" when both are floats.
    The passed-in config is never modified.

    Raises:
        ValueError: If such a parameter's bounds are not both ints or both floats.
    """
    # verify that parameters are well specified.
    # TODO(dag): deprecate this in favor of jsonschema validation once
    # apiVersion 2 is released and local controller is integrated with
    # wandb/client.

    # Deep-copy so the caller's dict is untouched, then cast to a plain dict:
    # a sweepconfig is a dict subclass that does not serialize cleanly to
    # yaml and breaks graphql.
    config = dict(deepcopy(config))

    if "parameters" not in config:
        # still shows an anaconda warning, but doesn't error
        return config

    for name, spec in config["parameters"].items():
        if "min" not in spec or "max" not in spec:
            continue
        if "distribution" in spec:
            continue

        low, high = spec["min"], spec["max"]
        if isinstance(low, int) and isinstance(high, int):
            spec["distribution"] = "int_uniform"
        elif isinstance(low, float) and isinstance(high, float):
            spec["distribution"] = "uniform"
        else:
            raise ValueError(
                "Parameter {} is ambiguous, please specify bounds as both floats (for a float_"
                "uniform distribution) or ints (for an int_uniform distribution).".format(
                    name
                )
            )
    return config
3140
+
3141
@normalize_exceptions
def upsert_sweep(
    self,
    config: dict,
    controller: Optional[str] = None,
    launch_scheduler: Optional[str] = None,
    scheduler: Optional[str] = None,
    obj_id: Optional[str] = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    state: Optional[str] = None,
    prior_runs: Optional[List[str]] = None,
    template_variable_values: Optional[Dict[str, Any]] = None,
) -> Tuple[str, List[str]]:
    """Create or update a sweep.

    Several variants of the mutation are tried, newest first, until the
    server accepts one.

    Arguments:
        config (dict): sweep config (will be converted to yaml)
        controller (str): controller to use
        launch_scheduler (str): launch scheduler to use
        scheduler (str): scheduler to use
        obj_id (str): object id
        project (str): project to use
        entity (str): entity to use
        state (str): state
        prior_runs (list): IDs of existing runs to add to the sweep
        template_variable_values (dict): template variable values

    Returns:
        A tuple of the sweep name and the config validation warnings
        (an empty list when the response has none).
    """
    project_query = """
        project {
            id
            name
            entity {
                id
                name
            }
        }
    """
    mutation_str = """
    mutation UpsertSweep(
        $id: ID,
        $config: String,
        $description: String,
        $entityName: String,
        $projectName: String,
        $controller: JSONString,
        $scheduler: JSONString,
        $state: String,
        $priorRunsFilters: JSONString,
    ) {
        upsertSweep(input: {
            id: $id,
            config: $config,
            description: $description,
            entityName: $entityName,
            projectName: $projectName,
            controller: $controller,
            scheduler: $scheduler,
            state: $state,
            priorRunsFilters: $priorRunsFilters,
        }) {
            sweep {
                name
                _PROJECT_QUERY_
            }
            configValidationWarnings
        }
    }
    """
    controller_decl = "$controller: JSONString,"
    controller_input = "controller: $controller,"
    with_project = mutation_str.replace("_PROJECT_QUERY_", project_query)

    # TODO(jhr): we need protocol versioning to know schema is not supported
    # for now we will just try both new and old query
    mutation_5 = gql(
        with_project.replace(
            controller_decl,
            controller_decl
            + "$launchScheduler: JSONString, $templateVariableValues: JSONString,",
        ).replace(
            controller_input,
            controller_input
            + "launchScheduler: $launchScheduler,templateVariableValues: $templateVariableValues,",
        )
    )
    # launchScheduler was introduced in core v0.14.0
    mutation_4 = gql(
        with_project.replace(
            controller_decl, controller_decl + "$launchScheduler: JSONString,"
        ).replace(
            controller_input, controller_input + "launchScheduler: $launchScheduler"
        )
    )
    # mutation 3 maps to backend that can support CLI version of at least 0.10.31
    mutation_3 = gql(with_project)
    mutation_2 = gql(with_project.replace("configValidationWarnings", ""))
    mutation_1 = gql(
        mutation_str.replace("_PROJECT_QUERY_", "").replace(
            "configValidationWarnings", ""
        )
    )

    # TODO(dag): replace this with a query for protocol versioning
    mutations = [mutation_5, mutation_4, mutation_3, mutation_2, mutation_1]

    config = self._validate_config_and_fill_distribution(config)

    # Silly, but attr-dicts like EasyDicts don't serialize correctly to yaml.
    # This sanitizes them with a round trip pass through json to get a regular dict.
    plain_config = json.loads(json.dumps(config))
    config_str = yaml.dump(plain_config, Dumper=util.NonOctalStringDumper)

    filters = None
    if prior_runs:
        filters = json.dumps({"$or": [{"name": r} for r in prior_runs]})

    last_error: Optional[Exception] = None
    for candidate in mutations:
        try:
            variables = {
                "id": obj_id,
                "config": config_str,
                "description": config.get("description"),
                "entityName": entity or self.settings("entity"),
                "projectName": project or self.settings("project"),
                "controller": controller,
                "launchScheduler": launch_scheduler,
                "templateVariableValues": json.dumps(template_variable_values),
                "scheduler": scheduler,
                "priorRunsFilters": filters,
            }
            if state:
                variables["state"] = state

            response = self.gql(
                candidate,
                variable_values=variables,
                check_retry_fn=util.no_retry_4xx,
            )
        except UsageError:
            # Usage errors are surfaced immediately; no fallback is attempted.
            raise
        except Exception as exc:
            # graphql schema exception is generic, so remember it and
            # fall through to the next (older) mutation
            last_error = exc
        else:
            last_error = None
            break
    if last_error:
        raise last_error

    upserted = response["upsertSweep"]
    sweep: Dict[str, Dict[str, Dict]] = upserted["sweep"]
    project_obj: Dict[str, Dict] = sweep.get("project", {})
    if project_obj:
        self.set_setting("project", project_obj["name"])
        entity_obj: dict = project_obj.get("entity", {})
        if entity_obj:
            self.set_setting("entity", entity_obj["name"])

    warnings = upserted.get("configValidationWarnings", [])
    return upserted["sweep"]["name"], warnings
3307
+
3308
@normalize_exceptions
def create_anonymous_api_key(self) -> str:
    """Ask the server for a new anonymous entity and return its API key."""
    create_key = gql(
        """
    mutation CreateAnonymousApiKey {
        createAnonymousEntity(input: {}) {
            apiKey {
                name
            }
        }
    }
    """
    )

    payload = self.gql(create_key, variable_values={})
    api_key = payload["createAnonymousEntity"]["apiKey"]
    return str(api_key["name"])
3326
+
3327
@staticmethod
def file_current(fname: str, md5: B64MD5) -> bool:
    """Return True when `fname` is an existing file whose base64 md5 equals `md5`."""
    if not os.path.isfile(fname):
        return False
    return md5_file_b64(fname) == md5
3331
+
3332
@normalize_exceptions
def pull(
    self, project: str, run: Optional[str] = None, entity: Optional[str] = None
) -> "List[requests.Response]":
    """Download the files of a run from W&B.

    Arguments:
        project (str): The project to download
        run (str, optional): The run to download from
        entity (str, optional): The entity to scope this project to

    Returns:
        The `requests` library response objects of the files that were
        actually downloaded (files that produce no response are left out)
    """
    project, run = self.parse_slug(project, run=run)
    file_metadata = self.download_urls(project, run, entity)

    downloaded = []
    for metadata in file_metadata.values():
        _, response = self.download_write_file(metadata)
        if response:
            downloaded.append(response)
    return downloaded
3355
+
3356
def get_project(self) -> str:
    """Return the project from the default settings, else from `settings("project")`."""
    configured = self.default_settings.get("project")
    if configured:
        return configured
    return self.settings("project")
3359
+
3360
@normalize_exceptions
def push(
    self,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    description: Optional[str] = None,
    force: bool = True,
    progress: Union[TextIO, bool] = False,
) -> "List[Optional[requests.Response]]":
    """Uploads multiple files to W&B.

    Arguments:
        files (list or dict): The filenames to upload, when dict the values are open files
        run (str, optional): The run to upload to. Defaults to the current run id
        entity (str, optional): The entity to scope this project to
        project (str, optional): The name of the project to upload to. Defaults to the one in settings.
        description (str, optional): The description of the changes (not used by this method)
        force (bool, optional): Whether to prevent push if git has uncommitted changes
            (not used by this method)
        progress (callable, or stream): If callable, will be called with (chunk_bytes,
            total_bytes) as argument else if True, renders a progress bar to stream.

    Returns:
        A list of `requests.Response` objects

    Raises:
        CommError: If no project is configured.
    """
    if project is None:
        project = self.get_project()
    if project is None:
        raise CommError("No project configured.")
    if run is None:
        run = self.current_run_id

    # TODO(adrian): we use a retriable version of self.upload_file() so
    # will never retry self.upload_urls() here. Instead, maybe we should
    # make push itself retriable.
    _, upload_headers, result = self.upload_urls(
        project,
        files,
        run,
        entity,
    )
    extra_headers = {}
    for upload_header in upload_headers:
        key, val = upload_header.split(":", 1)
        extra_headers[key] = val
    responses = []
    for file_name, file_info in result.items():
        file_url = file_info["uploadUrl"]

        # If the upload URL is relative, fill it in with the base URL,
        # since it's a proxied file store like the on-prem VM.
        if file_url.startswith("/"):
            file_url = f"{self.api_url}{file_url}"

        try:
            # To handle Windows paths
            # TODO: this doesn't handle absolute paths...
            normal_name = os.path.join(*file_name.split("/"))
            open_file = (
                files[file_name]
                if isinstance(files, dict)
                else open(normal_name, "rb")
            )
        except OSError:
            print(f"{file_name} does not exist")
            continue

        # Close the file even when the upload raises, so the handle is not leaked.
        try:
            if progress is False:
                # Use the resolved `file_url` here as well; the raw
                # "uploadUrl" may be relative and would not be reachable.
                responses.append(
                    self.upload_file_retry(
                        file_url, open_file, extra_headers=extra_headers
                    )
                )
            elif callable(progress):
                responses.append(  # type: ignore
                    self.upload_file_retry(
                        file_url, open_file, progress, extra_headers=extra_headers
                    )
                )
            else:
                length = os.fstat(open_file.fileno()).st_size
                with click.progressbar(
                    file=progress,  # type: ignore
                    length=length,
                    label=f"Uploading file: {file_name}",
                    fill_char=click.style("&", fg="green"),
                ) as bar:
                    responses.append(
                        self.upload_file_retry(
                            file_url,
                            open_file,
                            lambda bites, _: bar.update(bites),
                            extra_headers=extra_headers,
                        )
                    )
        finally:
            open_file.close()
    return responses
3458
+
3459
def link_artifact(
    self,
    client_id: str,
    server_id: str,
    portfolio_name: str,
    entity: str,
    project: str,
    aliases: Sequence[str],
) -> Dict[str, Any]:
    """Link an artifact to a portfolio and return the `linkArtifact` response object.

    The artifact is addressed by `server_id` when that is truthy, otherwise by
    `client_id`. Each alias is sent paired with `portfolio_name` as its
    collection name.
    """
    template = """
        mutation LinkArtifact(
            $artifactPortfolioName: String!,
            $entityName: String!,
            $projectName: String!,
            $aliases: [ArtifactAliasInput!],
            ID_TYPE
        ) {
            linkArtifact(input: {
                artifactPortfolioName: $artifactPortfolioName,
                entityName: $entityName,
                projectName: $projectName,
                aliases: $aliases,
                ID_VALUE
            }) {
                versionIndex
            }
        }
    """

    # Pick which id the mutation declares; the server id wins when both are set.
    # With neither id the placeholders are left in the template untouched.
    substitutions: Tuple[Tuple[str, str], ...] = ()
    if server_id:
        substitutions = (
            ("ID_TYPE", "$artifactID: ID"),
            ("ID_VALUE", "artifactID: $artifactID"),
        )
    elif client_id:
        substitutions = (
            ("ID_TYPE", "$clientID: ID"),
            ("ID_VALUE", "clientID: $clientID"),
        )
    for placeholder, replacement in substitutions:
        template = template.replace(placeholder, replacement)

    alias_inputs = [
        {"alias": alias, "artifactCollectionName": portfolio_name}
        for alias in aliases
    ]
    variable_values = {
        "clientID": client_id,
        "artifactID": server_id,
        "artifactPortfolioName": portfolio_name,
        "entityName": entity,
        "projectName": project,
        "aliases": alias_inputs,
    }

    response = self.gql(gql(template), variable_values=variable_values)
    link_artifact: Dict[str, Any] = response["linkArtifact"]
    return link_artifact
3515
+
3516
def use_artifact(
    self,
    artifact_id: str,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    run_name: Optional[str] = None,
    use_as: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Record that a run uses an artifact.

    Entity, project and run fall back to the "entity" setting, the "project"
    setting and the current run id. The `usedAs` input is only declared in the
    mutation when the server's introspection result lists it.

    Returns:
        The artifact object from the response, or None when it is empty.
    """
    query_template = """
    mutation UseArtifact(
        $entityName: String!,
        $projectName: String!,
        $runName: String!,
        $artifactID: ID!,
        _USED_AS_TYPE_
    ) {
        useArtifact(input: {
            entityName: $entityName,
            projectName: $projectName,
            runName: $runName,
            artifactID: $artifactID,
            _USED_AS_VALUE_
        }) {
            artifact {
                id
                digest
                description
                state
                createdAt
                metadata
            }
        }
    }
    """

    if "usedAs" in self.server_use_artifact_input_introspection():
        used_as_type, used_as_value = "$usedAs: String", "usedAs: $usedAs"
    else:
        used_as_type, used_as_value = "", ""
    query = gql(
        query_template.replace("_USED_AS_TYPE_", used_as_type).replace(
            "_USED_AS_VALUE_", used_as_value
        )
    )

    entity_name = entity_name or self.settings("entity")
    project_name = project_name or self.settings("project")
    run_name = run_name or self.current_run_id

    variables = {
        "entityName": entity_name,
        "projectName": project_name,
        "runName": run_name,
        "artifactID": artifact_id,
        "usedAs": use_as,
    }
    response = self.gql(query, variable_values=variables)

    artifact: Optional[Dict[str, Any]] = response["useArtifact"]["artifact"]
    if not artifact:
        return None
    return artifact
3582
+
3583
def create_artifact_type(
    self,
    artifact_type_name: str,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    description: Optional[str] = None,
) -> Optional[str]:
    """Create an artifact type and return the id reported by the server.

    Entity and project fall back to the "entity" and "project" settings.
    """
    create_type = gql(
        """
    mutation CreateArtifactType(
        $entityName: String!,
        $projectName: String!,
        $artifactTypeName: String!,
        $description: String
    ) {
        createArtifactType(input: {
            entityName: $entityName,
            projectName: $projectName,
            name: $artifactTypeName,
            description: $description
        }) {
            artifactType {
                id
            }
        }
    }
    """
    )
    entity_name = entity_name or self.settings("entity")
    project_name = project_name or self.settings("project")
    variables = {
        "entityName": entity_name,
        "projectName": project_name,
        "artifactTypeName": artifact_type_name,
        "description": description,
    }
    payload = self.gql(create_type, variable_values=variables)
    created = payload["createArtifactType"]["artifactType"]
    type_id: Optional[str] = created["id"]
    return type_id
3624
+
3625
def server_artifact_introspection(self) -> List[str]:
    """Return the field names of the server's `Artifact` type.

    The server is queried on the first call only; the result is stored on
    `self.server_artifact_fields_info` and reused afterwards. An empty list
    is returned when the server reports no such type or no fields.
    """
    query_string = """
        query ProbeServerArtifact {
            ArtifactInfoType: __type(name:"Artifact") {
                fields {
                    name
                }
            }
        }
    """

    if self.server_artifact_fields_info is None:
        query = gql(query_string)
        res = self.gql(query)
        # An explicit null in the response bypasses `.get(key, default)`,
        # so normalise with `or` before drilling further down.
        type_info = res.get("ArtifactInfoType") or {}
        input_fields = type_info.get("fields") or []
        self.server_artifact_fields_info = [
            field["name"] for field in input_fields if "name" in field
        ]

    return self.server_artifact_fields_info
3645
+
3646
def server_create_artifact_introspection(self) -> List[str]:
    """Return the input field names of the server's `CreateArtifactInput` type.

    The server is queried on the first call only; the result is stored on
    `self.server_create_artifact_input_info` and reused afterwards. An empty
    list is returned when the server reports no such type or no input fields.
    """
    query_string = """
        query ProbeServerCreateArtifactInput {
            CreateArtifactInputInfoType: __type(name:"CreateArtifactInput") {
                inputFields{
                    name
                }
            }
        }
    """

    if self.server_create_artifact_input_info is None:
        query = gql(query_string)
        res = self.gql(query)
        # An explicit null in the response bypasses `.get(key, default)`,
        # so normalise with `or` before drilling further down.
        type_info = res.get("CreateArtifactInputInfoType") or {}
        input_fields = type_info.get("inputFields") or []
        self.server_create_artifact_input_info = [
            field["name"] for field in input_fields if "name" in field
        ]

    return self.server_create_artifact_input_info
3668
+
3669
def _get_create_artifact_mutation(
    self,
    fields: List,
    history_step: Optional[int],
    distributed_id: Optional[str],
) -> str:
    """Build the CreateArtifact mutation text for the inputs the server supports.

    Arguments:
        fields: Names of the input fields the server accepts
        history_step: Adds the historyStep input when supported and not 0/None
        distributed_id: Adds the distributedID input when truthy

    Returns:
        The mutation string with the optional variable declarations and input
        values filled in.
    """
    declarations: List[str] = []
    assignments: List[str] = []

    def add_variable(name: str, gql_type: str) -> None:
        # Declare `$name` with its type and pass it through as an input value.
        declarations.append(f"${name}: {gql_type},")
        assignments.append(f"{name}: ${name},")

    if "historyStep" in fields and history_step not in [0, None]:
        add_variable("historyStep", "Int64!")

    if distributed_id:
        add_variable("distributedID", "String")

    if "clientID" in fields:
        add_variable("clientID", "ID")

    if "sequenceClientID" in fields:
        add_variable("sequenceClientID", "ID")

    if "enableDigestDeduplication" in fields:
        # Constant input value: no variable declaration needed.
        assignments.append("enableDigestDeduplication: true,")

    if "ttlDurationSeconds" in fields:
        add_variable("ttlDurationSeconds", "Int64")

    if "tags" in fields:
        add_variable("tags", "[TagInput!]")

    query_template = """
        mutation CreateArtifact(
            $artifactTypeName: String!,
            $artifactCollectionNames: [String!],
            $entityName: String!,
            $projectName: String!,
            $runName: String,
            $description: String,
            $digest: String!,
            $aliases: [ArtifactAliasInput!],
            $metadata: JSONString,
            _CREATE_ARTIFACT_ADDITIONAL_TYPE_
        ) {
            createArtifact(input: {
                artifactTypeName: $artifactTypeName,
                artifactCollectionNames: $artifactCollectionNames,
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                description: $description,
                digest: $digest,
                digestAlgorithm: MANIFEST_MD5,
                aliases: $aliases,
                metadata: $metadata,
                _CREATE_ARTIFACT_ADDITIONAL_VALUE_
            }) {
                artifact {
                    id
                    state
                    artifactSequence {
                        id
                        latestArtifact {
                            id
                            versionIndex
                        }
                    }
                }
            }
        }
    """

    types = "".join(declarations)
    values = "".join(assignments)
    with_types = query_template.replace("_CREATE_ARTIFACT_ADDITIONAL_TYPE_", types)
    return with_types.replace("_CREATE_ARTIFACT_ADDITIONAL_VALUE_", values)
3749
+
3750
def create_artifact(
    self,
    artifact_type_name: str,
    artifact_collection_name: str,
    digest: str,
    client_id: Optional[str] = None,
    sequence_client_id: Optional[str] = None,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    run_name: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[Dict] = None,
    ttl_duration_seconds: Optional[int] = None,
    aliases: Optional[List[Dict[str, str]]] = None,
    tags: Optional[List[Dict[str, str]]] = None,
    distributed_id: Optional[str] = None,
    is_user_created: Optional[bool] = False,
    history_step: Optional[int] = None,
) -> Tuple[Dict, Dict]:
    """Send the ``CreateArtifact`` mutation and return the created artifact.

    The mutation text is tailored to the input fields the server reports.
    Entity and project fall back to the current settings. The run name
    falls back to ``self.current_run_id`` unless ``is_user_created`` is
    truthy.

    Returns:
        A pair ``(artifact, latest)``: the ``artifact`` object from the
        response, and the ``latestArtifact`` entry of its
        ``artifactSequence`` (fetched with ``.get``, so it is ``None`` when
        that key is absent).
    """
    create_input_fields = self.server_create_artifact_introspection()
    artifact_type_fields = self.server_artifact_introspection()

    if ttl_duration_seconds and "ttlIsInherited" not in artifact_type_fields:
        wandb.termwarn(
            "Server not compatible with setting Artifact TTLs, please upgrade the server to use Artifact TTL"
        )
        # ttlDurationSeconds is only usable if ttlIsInherited is also present
        ttl_duration_seconds = None
    if tags and "tags" not in artifact_type_fields:
        wandb.termwarn(
            "Server not compatible with Artifact tags. "
            "To use Artifact tags, please upgrade the server to v0.85 or higher."
        )

    mutation_text = self._get_create_artifact_mutation(
        create_input_fields, history_step, distributed_id
    )

    entity_name = entity_name or self.settings("entity")
    project_name = project_name or self.settings("project")
    if not is_user_created:
        run_name = run_name or self.current_run_id

    # Falsy metadata (None or an empty dict) is sent as null.
    if metadata:
        metadata_json = json.dumps(util.make_safe_for_json(metadata))
    else:
        metadata_json = None

    variables = {
        "entityName": entity_name,
        "projectName": project_name,
        "runName": run_name,
        "artifactTypeName": artifact_type_name,
        "artifactCollectionNames": [artifact_collection_name],
        "clientID": client_id,
        "sequenceClientID": sequence_client_id,
        "digest": digest,
        "description": description,
        "aliases": list(aliases or []),
        "tags": list(tags or []),
        "metadata": metadata_json,
        "ttlDurationSeconds": ttl_duration_seconds,
        "distributedID": distributed_id,
        "historyStep": history_step,
    }
    response = self.gql(gql(mutation_text), variable_values=variables)

    created = response["createArtifact"]["artifact"]
    latest_in_sequence = created["artifactSequence"].get("latestArtifact")
    return created, latest_in_sequence
3820
+
3821
def commit_artifact(self, artifact_id: str) -> "_Response":
    """Send the ``CommitArtifact`` mutation for ``artifact_id``.

    The request is issued with ``timeout=60``.

    Returns:
        The response as returned by ``self.gql``. The mutation selects the
        committed artifact's ``id`` and ``digest``.
    """
    commit_mutation = gql(
        """
        mutation CommitArtifact(
            $artifactID: ID!,
        ) {
            commitArtifact(input: {
                artifactID: $artifactID,
            }) {
                artifact {
                    id
                    digest
                }
            }
        }
        """
    )
    variables = {"artifactID": artifact_id}
    result: "_Response" = self.gql(
        commit_mutation, variable_values=variables, timeout=60
    )
    return result
3845
+
3846
def complete_multipart_upload_artifact(
    self,
    artifact_id: str,
    storage_path: str,
    completed_parts: List[Dict[str, Any]],
    upload_id: Optional[str],
    complete_multipart_action: str = "Complete",
) -> Optional[str]:
    """Send the ``CompleteMultipartUploadArtifact`` mutation.

    Args:
        artifact_id: Sent as ``artifactID``.
        storage_path: Sent as ``storagePath``.
        completed_parts: Sent as ``completedParts``.
        upload_id: Sent as ``uploadID``. NOTE(review): the mutation declares
            this variable as non-null ``String!`` although the annotation
            allows ``None`` — confirm with callers.
        complete_multipart_action: Sent as ``completeMultipartAction``.

    Returns:
        The ``digest`` value from the mutation response.
    """
    complete_mutation = gql(
        """
        mutation CompleteMultipartUploadArtifact(
            $completeMultipartAction: CompleteMultipartAction!,
            $completedParts: [UploadPartsInput!]!,
            $artifactID: ID!
            $storagePath: String!
            $uploadID: String!
        ) {
            completeMultipartUploadArtifact(
                input: {
                    completeMultipartAction: $completeMultipartAction,
                    completedParts: $completedParts,
                    artifactID: $artifactID,
                    storagePath: $storagePath
                    uploadID: $uploadID
                }
            ) {
                digest
            }
        }
        """
    )
    variables = {
        "completeMultipartAction": complete_multipart_action,
        "artifactID": artifact_id,
        "storagePath": storage_path,
        "completedParts": completed_parts,
        "uploadID": upload_id,
    }
    payload = self.gql(complete_mutation, variable_values=variables)
    uploaded_digest: Optional[str] = payload["completeMultipartUploadArtifact"][
        "digest"
    ]
    return uploaded_digest
3889
+
3890
def create_artifact_manifest(
    self,
    name: str,
    digest: str,
    artifact_id: Optional[str],
    base_artifact_id: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    run: Optional[str] = None,
    include_upload: bool = True,
    type: str = "FULL",
) -> Tuple[str, Dict[str, Any]]:
    """Send the ``CreateArtifactManifest`` mutation.

    The ``$type`` variable and the ``type`` input argument are only written
    into the mutation text when ``type`` is not ``"FULL"``. Entity and
    project fall back to the current settings, and the run falls back to
    ``self.current_run_id``.

    Returns:
        A pair ``(manifest_id, file)`` taken from the ``artifactManifest``
        object in the response.
    """
    if type != "FULL":
        type_declaration = "$type: ArtifactManifestType = FULL"
        type_argument = "type: $type"
    else:
        type_declaration = ""
        type_argument = ""

    # Literal GraphQL braces are doubled because the text goes through
    # str.format.
    mutation_text = """
        mutation CreateArtifactManifest(
            $name: String!,
            $digest: String!,
            $artifactID: ID!,
            $baseArtifactID: ID,
            $entityName: String!,
            $projectName: String!,
            $runName: String!,
            $includeUpload: Boolean!,
            {}
        ) {{
            createArtifactManifest(input: {{
                name: $name,
                digest: $digest,
                artifactID: $artifactID,
                baseArtifactID: $baseArtifactID,
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                {}
            }}) {{
                artifactManifest {{
                    id
                    file {{
                        id
                        name
                        displayName
                        uploadUrl @include(if: $includeUpload)
                        uploadHeaders @include(if: $includeUpload)
                    }}
                }}
            }}
        }}
        """.format(type_declaration, type_argument)
    mutation = gql(mutation_text)

    entity_name = entity or self.settings("entity")
    project_name = project or self.settings("project")
    run_name = run or self.current_run_id

    variables = {
        "name": name,
        "digest": digest,
        "artifactID": artifact_id,
        "baseArtifactID": base_artifact_id,
        "entityName": entity_name,
        "projectName": project_name,
        "runName": run_name,
        "includeUpload": include_upload,
        "type": type,
    }
    response = self.gql(mutation, variable_values=variables)

    manifest = response["createArtifactManifest"]["artifactManifest"]
    return (manifest["id"], manifest["file"])
3965
+
3966
def update_artifact_manifest(
    self,
    artifact_manifest_id: str,
    base_artifact_id: Optional[str] = None,
    digest: Optional[str] = None,
    include_upload: Optional[bool] = True,
) -> Tuple[str, Dict[str, Any]]:
    """Send the ``UpdateArtifactManifest`` mutation.

    Args:
        artifact_manifest_id: Sent as ``artifactManifestID``.
        base_artifact_id: Sent as ``baseArtifactID``.
        digest: Sent as ``digest``.
        include_upload: Sent as ``includeUpload``. It controls the
            ``@include`` directives on ``uploadUrl`` and ``uploadHeaders``.

    Returns:
        A pair ``(manifest_id, file)`` taken from the ``artifactManifest``
        object in the response.
    """
    update_mutation = gql(
        """
        mutation UpdateArtifactManifest(
            $artifactManifestID: ID!,
            $digest: String,
            $baseArtifactID: ID,
            $includeUpload: Boolean!,
        ) {
            updateArtifactManifest(input: {
                artifactManifestID: $artifactManifestID,
                digest: $digest,
                baseArtifactID: $baseArtifactID,
            }) {
                artifactManifest {
                    id
                    file {
                        id
                        name
                        displayName
                        uploadUrl @include(if: $includeUpload)
                        uploadHeaders @include(if: $includeUpload)
                    }
                }
            }
        }
        """
    )
    variables = {
        "artifactManifestID": artifact_manifest_id,
        "digest": digest,
        "baseArtifactID": base_artifact_id,
        "includeUpload": include_upload,
    }
    response = self.gql(update_mutation, variable_values=variables)

    manifest = response["updateArtifactManifest"]["artifactManifest"]
    return (manifest["id"], manifest["file"])
4015
+
4016
def update_artifact_metadata(
    self, artifact_id: str, metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """Set the metadata of the given artifact version.

    Args:
        artifact_id: Sent as ``artifactID``.
        metadata: Serialized to a JSON string and sent as ``metadata``.

    Returns:
        The ``artifact`` object from the response (the mutation selects its
        ``id``).
    """
    mutation = gql(
        """
        mutation UpdateArtifact(
            $artifactID: ID!,
            $metadata: JSONString,
        ) {
            updateArtifact(input: {
                artifactID: $artifactID,
                metadata: $metadata,
            }) {
                artifact {
                    id
                }
            }
        }
        """
    )
    # Sanitize before serializing, the same way create_artifact does, so
    # both paths produce metadata JSON under the same rules.
    metadata_json = json.dumps(util.make_safe_for_json(metadata))
    response = self.gql(
        mutation,
        variable_values={
            "artifactID": artifact_id,
            "metadata": metadata_json,
        },
    )
    return response["updateArtifact"]["artifact"]
4045
+
4046
def _resolve_client_id(
    self,
    client_id: str,
) -> Optional[str]:
    """Look up the server ID recorded for ``client_id``.

    IDs already present in ``self._client_id_mapping`` are returned without
    a request. Otherwise the ``clientIDMapping`` query is issued, and a
    non-``None`` result is stored in the mapping before it is returned.

    Returns:
        The server ID, or ``None`` when the response, its
        ``clientIDMapping`` entry, or that entry's ``serverID`` is missing.
    """
    if client_id in self._client_id_mapping:
        return self._client_id_mapping[client_id]

    lookup_query = gql(
        """
        query ClientIDMapping($clientID: ID!) {
            clientIDMapping(clientID: $clientID) {
                serverID
            }
        }
        """
    )
    response = self.gql(lookup_query, variable_values={"clientID": client_id})
    if response is None:
        return None

    mapping_entry = response.get("clientIDMapping")
    if mapping_entry is None:
        return None

    resolved = mapping_entry.get("serverID")
    if resolved is not None:
        # Only successful lookups are stored; misses are queried again.
        self._client_id_mapping[client_id] = resolved
    return resolved
4076
+
4077
def server_create_artifact_file_spec_input_introspection(self) -> List:
    """Return the input field names reported for ``CreateArtifactFileSpecInput``.

    An introspection query is issued on every call; the result is not
    stored on the instance.

    Returns:
        One entry per reported input field, using ``""`` for a field with no
        ``name``. When the server reports no field list at all (including an
        unknown type) the result is ``[""]``.
    """
    query_string = """
        query ProbeServerCreateArtifactFileSpecInput {
            CreateArtifactFileSpecInputInfoType: __type(name:"CreateArtifactFileSpecInput") {
                inputFields{
                    name
                }
            }
        }
    """

    query = gql(query_string)
    res = self.gql(query)

    # `dict.get(key, default)` only falls back when the key is absent. An
    # introspection result for an unknown type is an explicit null, so treat
    # None the same as a missing entry instead of calling `.get` on it.
    type_info = res.get("CreateArtifactFileSpecInputInfoType")
    if type_info is None:
        type_info = {}
    input_fields = type_info.get("inputFields")
    if input_fields is None:
        # Same placeholder the lookup has always fallen back to.
        input_fields = [{}]

    create_artifact_file_spec_input_info = [
        field.get("name", "") for field in input_fields
    ]
    return create_artifact_file_spec_input_info
4097
+
4098
@normalize_exceptions
def create_artifact_files(
    self, artifact_files: Iterable["CreateArtifactFileSpecInput"]
) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
    """Send the ``CreateArtifactFiles`` mutation for the given file specs.

    The multipart-upload fields are only selected when the server's
    ``CreateArtifactFileSpecInput`` reports an ``uploadPartsInput`` field.

    Returns:
        The returned file nodes keyed by their ``displayName``. If several
        nodes share a display name, the last one wins.
    """
    query_template = """
        mutation CreateArtifactFiles(
            $storageLayout: ArtifactStorageLayout!
            $artifactFiles: [CreateArtifactFileSpecInput!]!
        ) {
            createArtifactFiles(input: {
                artifactFiles: $artifactFiles,
                storageLayout: $storageLayout,
            }) {
                files {
                    edges {
                        node {
                            id
                            name
                            displayName
                            uploadUrl
                            uploadHeaders
                            _MULTIPART_UPLOAD_FIELDS_
                            artifact {
                                id
                            }
                        }
                    }
                }
            }
        }
    """
    multipart_upload_url_query = """
        storagePath
        uploadMultipartUrls {
            uploadID
            uploadUrlParts {
                partNumber
                uploadUrl
            }
        }
    """

    # TODO: we should use constants here from interface/artifacts.py
    # but probably don't want the dependency. We're going to remove
    # this setting in a future release, so I'm just hard-coding the strings.
    storage_layout = "V1" if env.get_use_v1_artifacts() else "V2"

    spec_input_fields = self.server_create_artifact_file_spec_input_introspection()
    if "uploadPartsInput" in spec_input_fields:
        multipart_selection = multipart_upload_url_query
    else:
        multipart_selection = ""
    mutation_text = query_template.replace(
        "_MULTIPART_UPLOAD_FIELDS_", multipart_selection
    )

    response = self.gql(
        gql(mutation_text),
        variable_values={
            "storageLayout": storage_layout,
            "artifactFiles": list(artifact_files),
        },
    )

    file_edges = response["createArtifactFiles"]["files"]["edges"]
    nodes = (edge["node"] for edge in file_edges)
    return {node["displayName"]: node for node in nodes}
4171
+
4172
@normalize_exceptions
def notify_scriptable_run_alert(
    self,
    title: str,
    text: str,
    level: Optional[str] = None,
    wait_duration: Optional["Number"] = None,
) -> bool:
    """Send the ``NotifyScriptableRunAlert`` mutation for the current run.

    The entity and project come from the current settings, and the run is
    ``self.current_run_id``.

    Args:
        title: Sent as ``title``.
        text: Sent as ``text``.
        level: Sent as ``severity``.
        wait_duration: Sent as ``waitDuration``.

    Returns:
        The ``success`` value from the mutation response.
    """
    alert_mutation = gql(
        """
        mutation NotifyScriptableRunAlert(
            $entityName: String!,
            $projectName: String!,
            $runName: String!,
            $title: String!,
            $text: String!,
            $severity: AlertSeverity = INFO,
            $waitDuration: Duration
        ) {
            notifyScriptableRunAlert(input: {
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                title: $title,
                text: $text,
                severity: $severity,
                waitDuration: $waitDuration
            }) {
                success
            }
        }
        """
    )
    variables = {
        "entityName": self.settings("entity"),
        "projectName": self.settings("project"),
        "runName": self.current_run_id,
        "title": title,
        "text": text,
        "severity": level,
        "waitDuration": wait_duration,
    }
    payload = self.gql(alert_mutation, variable_values=variables)
    delivered: bool = payload["notifyScriptableRunAlert"]["success"]
    return delivered
4220
+
4221
def get_sweep_state(
    self, sweep: str, entity: Optional[str] = None, project: Optional[str] = None
) -> "SweepState":
    """Return the ``state`` entry of the sweep fetched via ``self.sweep``.

    The sweep is requested with ``specs="{}"``.
    """
    sweep_info = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
    current_state: "SweepState" = sweep_info["state"]
    return current_state
4228
+
4229
def set_sweep_state(
    self,
    sweep: str,
    state: "SweepState",
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Move a sweep to ``state`` through the ``UpsertSweep`` mutation.

    The sweep's current state is fetched first and the transition is
    checked before any mutation is sent.

    Args:
        sweep: Passed to ``self.sweep`` to look the sweep up.
        state: One of ``"RUNNING"``, ``"PAUSED"``, ``"CANCELED"`` or
            ``"FINISHED"``.
        entity: Entity name; falls back to the ``entity`` setting.
        project: Project name; falls back to the ``project`` setting.

    Raises:
        AssertionError: If ``state`` is not one of the four accepted values.
        Exception: If pausing a sweep that is neither paused nor running, or
            if requesting any state other than ``"RUNNING"`` for a sweep
            that is not running, paused or pending.
    """
    assert state in ("RUNNING", "PAUSED", "CANCELED", "FINISHED")
    sweep_info = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
    current = sweep_info["state"].upper()

    if state == "PAUSED" and current not in ("PAUSED", "RUNNING"):
        raise Exception("Cannot pause {} sweep.".format(current.lower()))
    if state != "RUNNING" and current not in ("RUNNING", "PAUSED", "PENDING"):
        raise Exception("Sweep already {}.".format(current.lower()))

    upsert_mutation = gql(
        """
        mutation UpsertSweep(
            $id: ID,
            $state: String,
            $entityName: String,
            $projectName: String
        ) {
            upsertSweep(input: {
                id: $id,
                state: $state,
                entityName: $entityName,
                projectName: $projectName
            }){
                sweep {
                    name
                }
            }
        }
        """
    )
    variables = {
        "id": sweep_info["id"],
        "state": state,
        "entityName": entity or self.settings("entity"),
        "projectName": project or self.settings("project"),
    }
    self.gql(upsert_mutation, variable_values=variables)
4274
+
4275
def stop_sweep(
    self,
    sweep: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Finish the sweep to stop running new runs and let currently running runs finish.

    Delegates to ``set_sweep_state`` with ``state="FINISHED"``.
    """
    target_state = "FINISHED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
4285
+
4286
def cancel_sweep(
    self,
    sweep: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Cancel the sweep to kill all running runs and stop running new runs.

    Delegates to ``set_sweep_state`` with ``state="CANCELED"``.
    """
    target_state = "CANCELED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
4296
+
4297
def pause_sweep(
    self,
    sweep: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Pause the sweep to temporarily stop running new runs.

    Delegates to ``set_sweep_state`` with ``state="PAUSED"``.
    """
    target_state = "PAUSED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
4307
+
4308
def resume_sweep(
    self,
    sweep: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Resume the sweep to continue running new runs.

    Delegates to ``set_sweep_state`` with ``state="RUNNING"``.
    """
    target_state = "RUNNING"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
4318
+
4319
def _status_request(self, url: str, length: int) -> requests.Response:
    """Ask google how much we've uploaded.

    Issues a PUT to ``url`` with ``Content-Length: 0`` and a
    ``Content-Range`` header of ``bytes */<length>``, and returns the
    response object unchanged.
    """
    check_httpclient_logger_handler()
    status_headers = {
        "Content-Length": "0",
        "Content-Range": "bytes */%i" % length,
    }
    return requests.put(url=url, headers=status_headers)
4326
+
4327
def _flatten_edges(self, response: "_Response") -> List[Dict]:
    """Return an array from the nested graphql relay structure.

    Collects the ``node`` entry of every item under ``response["edges"]``,
    preserving order.
    """
    nodes = []
    for edge in response["edges"]:
        nodes.append(edge["node"])
    return nodes
4330
+
4331
@normalize_exceptions
def stop_run(
    self,
    run_id: str,
) -> bool:
    """Send the ``stopRun`` mutation for ``run_id``.

    Returns:
        The ``success`` entry of the ``stopRun`` response object. It is
        read with ``.get``, so it is ``None`` if the key is absent.
    """
    stop_mutation = gql(
        """
        mutation stopRun($id: ID!) {
            stopRun(input: {
                id: $id
            }) {
                clientMutationId
                success
            }
        }
        """
    )
    payload = self.gql(stop_mutation, variable_values={"id": run_id})
    stopped: bool = payload["stopRun"].get("success")
    return stopped