ob-metaflow-extensions 1.1.45rc3__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

Files changed (128) hide show
  1. metaflow_extensions/outerbounds/__init__.py +1 -7
  2. metaflow_extensions/outerbounds/config/__init__.py +35 -0
  3. metaflow_extensions/outerbounds/plugins/__init__.py +186 -57
  4. metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
  5. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  6. metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
  7. metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  33. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  34. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  35. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  36. metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
  37. metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
  38. metaflow_extensions/outerbounds/plugins/auth_server.py +28 -8
  39. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  40. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  41. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
  42. metaflow_extensions/outerbounds/plugins/card_utilities/__init__.py +0 -0
  43. metaflow_extensions/outerbounds/plugins/card_utilities/async_cards.py +142 -0
  44. metaflow_extensions/outerbounds/plugins/card_utilities/extra_components.py +545 -0
  45. metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +70 -0
  46. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
  47. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
  48. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  49. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
  50. metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
  51. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  52. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +391 -0
  53. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +188 -0
  54. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
  55. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
  56. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +79 -0
  57. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  58. metaflow_extensions/outerbounds/plugins/nim/card.py +140 -0
  59. metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py +101 -0
  60. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +379 -0
  61. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  62. metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
  63. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +3 -0
  64. metaflow_extensions/outerbounds/plugins/nvcf/exceptions.py +94 -0
  65. metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +178 -0
  66. metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +417 -0
  67. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +280 -0
  68. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +242 -0
  69. metaflow_extensions/outerbounds/plugins/nvcf/utils.py +6 -0
  70. metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
  71. metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
  72. metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
  73. metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
  74. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
  75. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
  76. metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
  77. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
  78. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  79. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  80. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
  81. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  82. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  83. metaflow_extensions/outerbounds/plugins/perimeters.py +19 -5
  84. metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py +70 -0
  85. metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py +88 -0
  86. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  87. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  88. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  89. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  90. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  91. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  92. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  93. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  94. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  95. metaflow_extensions/outerbounds/plugins/secrets/__init__.py +0 -0
  96. metaflow_extensions/outerbounds/plugins/secrets/secrets.py +204 -0
  97. metaflow_extensions/outerbounds/plugins/snowflake/__init__.py +3 -0
  98. metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +378 -0
  99. metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
  100. metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +309 -0
  101. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +277 -0
  102. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +150 -0
  103. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +273 -0
  104. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
  105. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +241 -0
  106. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
  107. metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py +50 -0
  108. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  109. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  110. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  111. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  112. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  113. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  114. metaflow_extensions/outerbounds/profilers/gpu.py +131 -47
  115. metaflow_extensions/outerbounds/remote_config.py +53 -16
  116. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +138 -2
  117. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  118. metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
  119. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  120. metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py +1 -0
  121. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  122. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  123. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  124. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
  125. ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
  126. ob_metaflow_extensions-1.1.45rc3.dist-info/RECORD +0 -19
  127. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
  128. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,378 @@
1
+ """
2
+ This library is an abstraction layer for connecting to snowflake using Outerbounds
3
+ OIDC tokens. It expects that a security integration that authenticates tokens minted
4
+ by Outerbounds has already been configured in the target snowflake account.
5
+ """
6
+ from metaflow.metaflow_config import SERVICE_URL
7
+ from metaflow.metaflow_config_funcs import init_config
8
+ from typing import Dict
9
+ from os import environ
10
+ import sys
11
+ import json
12
+ import requests
13
+ import random
14
+ import time
15
+
16
+
17
class OuterboundsSnowflakeConnectorException(Exception):
    """Error type raised by this module's Snowflake helpers (configuration, API and connection failures)."""
19
+
20
+
21
class OuterboundsSnowflakeIntegrationSpecApiResponse:
    """
    Thin read-only view over an integration spec mapping.

    Every property is a plain key lookup on the wrapped object, so a missing key
    surfaces as that object's own lookup error (``KeyError`` for a dict).
    """

    def __init__(self, response):
        # The payload is stored as-is; nothing is copied or validated.
        self.response = response

    @property
    def account(self):
        """Value stored under the ``"account"`` key."""
        spec = self.response
        return spec["account"]

    @property
    def user(self):
        """Value stored under the ``"user"`` key."""
        spec = self.response
        return spec["user"]

    @property
    def default_role(self):
        """Value stored under the ``"default_role"`` key."""
        spec = self.response
        return spec["default_role"]

    @property
    def warehouse(self):
        """Value stored under the ``"warehouse"`` key."""
        spec = self.response
        return spec["warehouse"]

    @property
    def database(self):
        """Value stored under the ``"database"`` key."""
        spec = self.response
        return spec["database"]
44
+
45
+
46
def get_snowflake_token(user: str = "", role: str = "", integration: str = "") -> str:
    """
    Exchange the Outerbounds source token for a Snowflake-compatible OIDC token.

    The returned token can then be used to connect to Snowflake.

    user: str
        The user the token will be minted for. Defaults to the integration's user.
    role: str
        The role the token will be scoped to. Defaults to the integration's default role.
    integration: str
        The name of the snowflake integration to use. If not set, an existing integration
        will be used provided that only one exists per perimeter. If integration is not
        set and more than one exists, then we raise an exception.
    """
    provisioner = SnowflakeIntegrationProvisioner(integration)

    # The integration spec is only fetched when at least one value must be defaulted.
    if not (user and role and integration):
        spec = provisioner.get_snowflake_integration_spec()
        user = user or spec.user
        role = role or spec.default_role
        # Fetching the spec also resolves the integration name on the provisioner.
        integration = integration or provisioner.get_integration_name()

    token_url = provisioner.get_snowflake_token_url()
    request_body = json.dumps(
        {
            "perimeterName": provisioner.get_perimeter(),
            "snowflakeUser": user,
            "snowflakeRole": role,
            "integrationName": integration,
        }
    )
    auth_headers = provisioner.get_service_auth_header()

    reply = _api_server_get(
        token_url, data=request_body, headers=auth_headers, conn_error_retries=5
    )
    reply.raise_for_status()
    return reply.json()["token"]
84
+
85
+
86
def get_oauth_connection_params(
    user: str = "", role: str = "", integration: str = "", **kwargs
) -> Dict:
    """
    Build OAuth connection parameters for Snowflake using an Outerbounds integration.

    The returned dict can be used with both snowflake-connector-python and
    snowflake-snowpark-python.

    user: str
        The user name used to authenticate with snowflake
    role: str
        The role to request when connecting with snowflake
    integration: str
        The name of the snowflake integration to use. If not set, an existing integration
        will be used provided that only one exists in the current perimeter.
    kwargs: dict
        Additional arguments to include in the connection parameters

    Returns:
        Dict with connection parameters including OAuth token
    """
    # Password auth and OAuth are mutually exclusive here.
    if "password" in kwargs:
        raise OuterboundsSnowflakeConnectorException(
            "Password should not be set when using Outerbounds OAuth authentication."
        )

    provisioner = SnowflakeIntegrationProvisioner(integration)
    needs_defaults = not all(
        name in kwargs for name in ("account", "warehouse", "database")
    )

    if needs_defaults or not (user and role and integration):
        spec = provisioner.get_snowflake_integration_spec()
        user = user or spec.user
        role = role or spec.default_role
        integration = integration or provisioner.get_integration_name()

        if "account" not in kwargs:
            kwargs["account"] = spec.account

        if "warehouse" not in kwargs:
            kwargs["warehouse"] = spec.warehouse

        # The integration's database is only applied when the caller is also using
        # the integration's warehouse; a different warehouse leaves the database unset.
        if "database" not in kwargs:
            if kwargs["warehouse"] == spec.warehouse:
                kwargs["database"] = spec.database

    # Mint the token last, once user / role / integration are fully resolved.
    kwargs.update(
        token=get_snowflake_token(user=user, role=role, integration=integration),
        authenticator="oauth",
        role=role,
        user=user,
    )
    return kwargs
151
+
152
+
153
def connect(user: str = "", role: str = "", integration: str = "", **kwargs):
    """
    Connect to snowflake using the token minted by Outerbounds.

    user: str
        The user name used to authenticate with snowflake
    role: str
        The role to request when connecting with snowflake
    integration: str
        The name of the snowflake integration to use. If not set, an existing integration
        will be used provided that only one exists in the current perimeter. If integration
        is not set and more than one exists in the current perimeter, then we raise an
        exception.
    kwargs: dict
        Additional arguments to pass to the python snowflake connector

    Returns the connection object produced by ``snowflake.connector.connect``.

    Raises OuterboundsSnowflakeConnectorException if the connector package cannot be
    imported or if establishing the connection fails; the original error is attached
    as ``__cause__``.
    """
    # Resolve defaults and mint the OAuth token.
    connection_params = get_oauth_connection_params(
        user=user, role=role, integration=integration, **kwargs
    )

    # The import is isolated in its own `try` so that only a genuinely missing package
    # produces the installation hint, and aliased so it does not shadow this function.
    try:
        from snowflake.connector import connect as snowflake_connect
    except ImportError as ie:
        raise OuterboundsSnowflakeConnectorException(
            f"Error importing snowflake connector: {ie}.\nPlease make sure the 'snowflake-connector-python' package has been installed by running 'pip install -U \"outerbounds[snowflake]\"' or using the Metaflow decorators @pypi or @conda."
        ) from ie

    try:
        return snowflake_connect(**connection_params)
    except Exception as e:
        raise OuterboundsSnowflakeConnectorException(
            f"Error connecting to snowflake: {e}"
        ) from e
184
+
185
+
186
def _api_server_get(*args, conn_error_retries=2, **kwargs):
    """
    Issue ``requests.get(*args, **kwargs)``, retrying on connection errors.

    There are two categories of errors that we need to handle when dealing with any API server.
    1. HTTP errors. These are errors that are returned from the API server.
        - How to handle retries for this case will be application specific.
    2. Errors when the API server may not be reachable (DNS resolution / network issues)
        - In this scenario, we know that something external to the API server is going wrong causing the issue.
        - Failing prematurely in this case might not be the best course of action since critical user jobs might crash on intermittent issues.
        - So in this case, we can just plainly retry the request.

    This function handles the second case. It's a simple wrapper to handle the retry logic for connection errors.
    If this function is provided a `conn_error_retries` of 5, then the last retry will have waited 32 seconds.
    Generally this is a safe enough number of retries after which we can assume that something is really broken. Until then,
    there can be intermittent issues that would resolve themselves if we retry gracefully.

    conn_error_retries: int
        Number of retries performed after the initial attempt. Values below zero are
        treated as zero, i.e. a single attempt.

    Returns the ``requests.Response`` of the first attempt that does not raise a
    connection error.

    Raises ``requests.exceptions.ConnectionError`` once every attempt has failed, so
    callers never receive ``None`` in place of a response.
    """
    max_retries = max(conn_error_retries, 0)
    # One jitter value per call, added to every backoff delay.
    noise = random.uniform(-0.5, 0.5)

    for attempt in range(max_retries + 1):
        try:
            return requests.get(*args, **kwargs)
        except requests.exceptions.ConnectionError:
            if attempt >= max_retries:
                # Out of retries: report and surface the original connection error.
                print(
                    "[@snowflake] Failed to connect to the API server. ",
                    file=sys.stderr,
                )
                raise
            # Exponential backoff with 2^(attempt+1) seconds
            time.sleep((2 ** (attempt + 1)) + noise)
217
+
218
+
219
class Snowflake:
    """
    Context manager around a connection opened with ``connect``.

    Entering the ``with`` block yields the underlying connection object (not this
    wrapper); leaving it closes the connection.
    """

    def __init__(
        self, user: str = "", role: str = "", integration: str = "", **kwargs
    ) -> None:
        # The connection is opened eagerly, at construction time.
        self.cn = connect(user, role, integration, **kwargs)

    def __enter__(self):
        connection = self.cn
        return connection

    def __exit__(self, exception_type, exception_value, traceback):
        # Returns None, so an exception raised inside the `with` body still propagates.
        self.cn.close()

    def close(self):
        """Close the underlying connection."""
        self.cn.close()
233
+
234
+
235
class SnowflakeIntegrationProvisioner:
    """
    Resolves everything needed to mint a Snowflake token for one integration.

    Settings are read from the config returned by ``init_config()``; some of them
    fall back to process environment variables when absent from that config.
    """

    # HTTP status codes for which `_make_request` retries the request.
    _RETRYABLE_STATUS_CODES = (409,)
    # Upper bound on requests per `_make_request` call (initial attempt + retries).
    _MAX_REQUEST_ATTEMPTS = 3

    def __init__(self, integration_name: str) -> None:
        self.conf = init_config()
        # May be empty; it is then resolved by `get_snowflake_integration_spec`.
        self.integration_name = integration_name

    def get_snowflake_integration_spec(
        self,
    ) -> "OuterboundsSnowflakeIntegrationSpecApiResponse":
        """
        Fetch the spec of the configured integration from the integrations API.

        When no integration name was given, all snowflake integrations of the
        perimeter are listed and exactly one must exist. As a side effect,
        ``self.integration_name`` is set to the name reported by the API.

        Raises OuterboundsSnowflakeConnectorException when no integration or more
        than one integration is found, or when the API reports an error.
        """
        integrations_url = self._get_integration_url()
        perimeter = self.get_perimeter()
        headers = {"Content-Type": "application/json", "Connection": "keep-alive"}
        request_payload = {
            "perimeter_name": perimeter,
        }

        # if integration is not set, list all integrations
        if not self.integration_name:
            list_snowflake_integrations_url = f"{integrations_url}/snowflake"
            response = self._make_request(
                list_snowflake_integrations_url, headers, request_payload
            )
            snowflake_integrations = response.get("integrations", [])
            if not snowflake_integrations:
                raise OuterboundsSnowflakeConnectorException(
                    "No snowflake integrations found. Please make sure you have created a Snowflake integration on the Outerbounds UI first."
                )

            if len(snowflake_integrations) > 1:
                raise OuterboundsSnowflakeConnectorException(
                    "Multiple snowflake integrations found. Please specify a specific integration name you would like to use."
                )

            only_integration = snowflake_integrations[0]
            self.integration_name = only_integration["integration_name"]
            return OuterboundsSnowflakeIntegrationSpecApiResponse(
                only_integration["integration_spec"]
            )

        get_snowflake_integration_url = (
            f"{integrations_url}/snowflake/{self.integration_name}"
        )
        response = self._make_request(
            get_snowflake_integration_url, headers, request_payload
        )
        self.integration_name = response["integration_name"]
        return OuterboundsSnowflakeIntegrationSpecApiResponse(
            response["integration_spec"]
        )

    def get_integration_name(self) -> str:
        """Return the integration name as currently known (may be empty before a spec fetch)."""
        return self.integration_name

    def get_perimeter(self) -> str:
        """
        Return the perimeter from the config, falling back to the OBP_PERIMETER env var.

        Raises OuterboundsSnowflakeConnectorException when no non-empty value is found.
        """
        if "OBP_PERIMETER" in self.conf:
            perimeter = self.conf["OBP_PERIMETER"]
        else:
            # if the perimeter is not in metaflow config, try to get it from the environment
            perimeter = environ.get("OBP_PERIMETER", "")

        if not perimeter:
            raise OuterboundsSnowflakeConnectorException(
                "No perimeter set. Please make sure to run `outerbounds configure <...>` command which can be found on the Ourebounds UI or reach out to your Outerbounds support team."
            )

        return perimeter

    def get_snowflake_token_url(self) -> str:
        """
        Return the URL of the token-minting endpoint.

        The host comes from ``OBP_AUTH_SERVER`` in the config when present. Otherwise it
        is derived from ``SERVICE_URL`` by replacing the first hostname label with ``auth``.

        Raises OuterboundsSnowflakeConnectorException when the host cannot be derived.
        """
        if "OBP_AUTH_SERVER" in self.conf:
            auth_host = self.conf["OBP_AUTH_SERVER"]
        else:
            from urllib.parse import urlparse

            hostname = urlparse(SERVICE_URL or "").hostname
            # A usable hostname needs at least two labels so the first can be swapped.
            if not hostname or "." not in hostname:
                raise OuterboundsSnowflakeConnectorException(
                    "Unable to determine the auth server. Please make sure to run `outerbounds configure <...>` or reach out to your Outerbounds support team."
                )
            auth_host = "auth." + hostname.split(".", 1)[1]

        return "https://" + auth_host + "/generate/snowflake"

    def get_service_auth_header(self) -> Dict:
        """
        Return the headers used to authenticate against the token endpoint.

        Uses ``METAFLOW_SERVICE_AUTH_KEY`` from the config when present; otherwise the
        JSON object stored in the ``METAFLOW_SERVICE_HEADERS`` environment variable.

        Raises OuterboundsSnowflakeConnectorException when neither source is available.
        """
        if "METAFLOW_SERVICE_AUTH_KEY" in self.conf:
            return {"x-api-key": self.conf["METAFLOW_SERVICE_AUTH_KEY"]}

        raw_headers = environ.get("METAFLOW_SERVICE_HEADERS")
        if not raw_headers:
            raise OuterboundsSnowflakeConnectorException(
                "No service authentication headers found. Please make sure to run `outerbounds configure <...>` or reach out to your Outerbounds support team."
            )
        return json.loads(raw_headers)

    def _get_integration_url(self) -> str:
        """
        Return the integrations API base URL from the config or the OBP_INTEGRATIONS_URL env var.

        Raises OuterboundsSnowflakeConnectorException when no non-empty value is found.
        """
        if "OBP_INTEGRATIONS_URL" in self.conf:
            integrations_url = self.conf["OBP_INTEGRATIONS_URL"]
        else:
            # if the integrations url is not in metaflow config, try to get it from the environment
            integrations_url = environ.get("OBP_INTEGRATIONS_URL", "")

        if not integrations_url:
            raise OuterboundsSnowflakeConnectorException(
                "No integrations url set. Please notify your Outerbounds support team about this issue."
            )

        return integrations_url

    def _make_request(self, url, headers: Dict, payload: Dict) -> Dict:
        """
        GET ``url`` with a JSON-encoded payload and return the decoded JSON body.

        The request is repeated, with a short linear backoff, while the server answers
        with a retryable status code, up to ``_MAX_REQUEST_ATTEMPTS`` requests in total.
        The last response is then checked by `_handle_error_response`.
        """
        try:
            from metaflow.metaflow_config import SERVICE_HEADERS
        except ImportError:
            # Without the Metaflow service headers only the caller's headers are sent.
            SERVICE_HEADERS = None
        request_headers = {**headers, **(SERVICE_HEADERS or {})}

        json_payload = json.dumps(payload)
        last_attempt = self._MAX_REQUEST_ATTEMPTS - 1
        for attempt in range(self._MAX_REQUEST_ATTEMPTS):
            response = _api_server_get(
                url, data=json_payload, headers=request_headers, conn_error_retries=5
            )
            if response.status_code not in self._RETRYABLE_STATUS_CODES:
                break

            if attempt < last_attempt:  # Don't sleep after the last attempt
                time.sleep(0.5 * (attempt + 1))

        self._handle_error_response(response)
        return response.json()

    @staticmethod
    def _handle_error_response(response: "requests.Response"):
        """
        Raise OuterboundsSnowflakeConnectorException if ``response`` represents an error.

        Returns None for non-error responses. A ``statusCode`` nested under the body's
        ``error`` object takes precedence over the HTTP status code.
        """
        if response.status_code >= 500:
            raise OuterboundsSnowflakeConnectorException(
                f"Server error: {response.text}. Please reach out to your Outerbounds support team."
            )

        # Error bodies are not guaranteed to be JSON objects; fall back to the raw text
        # so that the caller still gets a connector exception with useful content.
        try:
            body = response.json()
        except ValueError:
            body = None
        shown_body = body if body is not None else response.text

        error = body.get("error") if isinstance(body, dict) else None
        if not isinstance(error, dict):
            error = {}

        status_code = error.get("statusCode", response.status_code)
        if status_code == 404:
            raise OuterboundsSnowflakeConnectorException(
                f"Integration not found: {shown_body}"
            )

        if status_code >= 400:
            details = error.get("details")
            if isinstance(details, dict) and "kind" in details and "message" in details:
                raise OuterboundsSnowflakeConnectorException(
                    f"status_code={status_code}\t*{details['kind']}*\n{details['message']}"
                )
            raise OuterboundsSnowflakeConnectorException(
                f"status_code={status_code} Unexpected error: {shown_body}"
            )
@@ -0,0 +1,309 @@
1
+ import os
2
+ import re
3
+ import shlex
4
+ import atexit
5
+ import json
6
+ import math
7
+ import time
8
+ import hashlib
9
+
10
+ from metaflow import util
11
+
12
+ from metaflow.metaflow_config import (
13
+ SERVICE_INTERNAL_URL,
14
+ SERVICE_HEADERS,
15
+ DEFAULT_METADATA,
16
+ DATASTORE_SYSROOT_S3,
17
+ DATATOOLS_S3ROOT,
18
+ KUBERNETES_SANDBOX_INIT_SCRIPT,
19
+ OTEL_ENDPOINT,
20
+ DEFAULT_SECRETS_BACKEND_TYPE,
21
+ AWS_SECRETS_MANAGER_DEFAULT_REGION,
22
+ S3_SERVER_SIDE_ENCRYPTION,
23
+ S3_ENDPOINT_URL,
24
+ )
25
+ from metaflow.metaflow_config_funcs import config_values
26
+
27
+ from metaflow.mflog import (
28
+ export_mflog_env_vars,
29
+ bash_capture_logs,
30
+ BASH_SAVE_LOGS,
31
+ get_log_tailer,
32
+ tail_logs,
33
+ )
34
+
35
+ from .snowpark_client import SnowparkClient
36
+ from .snowpark_exceptions import SnowparkException, SnowparkKilledException
37
+ from .snowpark_job import SnowparkJob
38
+
39
# Redirect structured logs to $PWD/.logs/
# "$PWD" is kept as literal text here: LOGS_DIR is interpolated into the
# `bash -c` command line built for the task rather than resolved in this process.
LOGS_DIR = "$PWD/.logs"
STDOUT_FILE = "mflog_stdout"
STDERR_FILE = "mflog_stderr"
# Full paths passed as stdout_path / stderr_path when exporting the mflog env vars.
STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)
45
+
46
+
47
class Snowpark(object):
    """Builds, launches and follows a single Metaflow task executed as a Snowpark job."""

    def __init__(
        self,
        datastore,
        metadata,
        environment,
        client_credentials,
    ):
        self.datastore = datastore
        self.metadata = metadata
        self.environment = environment
        self.snowpark_client = SnowparkClient(**client_credentials)

        # On interpreter exit, kill the job if one was launched through `launch_job`.
        def _kill_job_at_exit():
            if hasattr(self, "job"):
                self.job.kill()

        atexit.register(_kill_job_at_exit)

    def _job_name(self, user, flow_name, run_id, step_name, task_id, retry_count):
        """
        Derive a job name of the form ``<sanitized flow-step prefix>-<8 hex chars>``.

        The hex suffix is the start of the MD5 digest over all six identifiers, so two
        tasks with the same flow and step still get different names.
        """

        def _text(value):
            # None is rendered as an empty string for these identifiers.
            return "" if value is None else str(value)

        identity = (
            f"{user}-{flow_name}-{_text(run_id)}-"
            f"{step_name}-{_text(task_id)}-{_text(retry_count)}"
        )
        digest = hashlib.md5(identity.encode("utf-8")).hexdigest()[:8]

        # Lower-case, replace anything outside [a-z0-9] with "-", cap the length at 54
        # characters and drop leading dashes.
        prefix = re.sub(r"[^a-z0-9]", "-", f"{flow_name}-{step_name}".lower())
        prefix = prefix[:54].lstrip("-")
        return f"{prefix}-{digest}"

    def _command(self, environment, code_package_url, step_name, step_cmds, task_spec):
        """Return the argv list (``bash -c "<script>"``) that runs the task in the container."""
        datastore_type = self.datastore.TYPE

        mflog_expr = export_mflog_env_vars(
            datastore_type=datastore_type,
            stdout_path=STDOUT_PATH,
            stderr_path=STDERR_PATH,
            **task_spec,
        )
        init_expr = " && ".join(
            environment.get_package_commands(code_package_url, datastore_type)
        )
        bootstrap_and_step = (
            environment.bootstrap_commands(step_name, datastore_type) + step_cmds
        )
        step_expr = bash_capture_logs(" && ".join(bootstrap_and_step))

        # construct an entry point that
        # 1) initializes the mflog environment (mflog_expr)
        # 2) bootstraps a metaflow environment (init_expr)
        # 3) executes a task (step_expr)

        # the `true` command is to make sure that the generated command
        # plays well with docker containers which have entrypoint set as
        # eval $@
        script = "true && mkdir -p %s && %s && %s && %s; " % (
            LOGS_DIR,
            mflog_expr,
            init_expr,
            step_expr,
        )
        # after the task has finished, we save its exit code (fail/success)
        # and persist the final logs. The whole entrypoint should exit
        # with the exit code (c) of the task.
        #
        # Note that if step_expr OOMs, this tail expression is never executed.
        # We lose the last logs in this scenario.
        script += "c=$?; %s; exit $c" % BASH_SAVE_LOGS
        # For supporting sandboxes, ensure that a custom script is executed before
        # anything else is executed. The script is passed in as an env var.
        script = (
            '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
            % script
        )
        return shlex.split('bash -c "%s"' % script)

    def create_job(
        self,
        step_name,
        step_cli,
        task_spec,
        code_package_sha,
        code_package_url,
        code_package_ds,
        image=None,
        stage=None,
        compute_pool=None,
        volume_mounts=None,
        external_integration=None,
        cpu=None,
        gpu=None,
        memory=None,
        run_time_limit=None,
        env=None,
        attrs=None,
    ) -> "SnowparkJob":
        """
        Assemble (without launching) the SnowparkJob for one task.

        Environment variables are applied in this order: the fixed Metaflow settings,
        ``METAFLOW_CONDA_*`` / ``METAFLOW_DEBUG_*`` config values, optional settings that
        are not None, and finally the caller-supplied ``env`` entries.
        """
        env = {} if env is None else env
        attrs = {} if attrs is None else attrs

        job_name = self._job_name(
            attrs.get("metaflow.user"),
            attrs.get("metaflow.flow_name"),
            attrs.get("metaflow.run_id"),
            attrs.get("metaflow.step_name"),
            attrs.get("metaflow.task_id"),
            attrs.get("metaflow.retry_count"),
        )
        command = self._command(
            self.environment, code_package_url, step_name, [step_cli], task_spec
        )

        snowpark_job = SnowparkJob(
            client=self.snowpark_client,
            name=job_name,
            command=command,
            step_name=step_name,
            step_cli=step_cli,
            task_spec=task_spec,
            code_package_sha=code_package_sha,
            code_package_url=code_package_url,
            code_package_ds=code_package_ds,
            image=image,
            stage=stage,
            compute_pool=compute_pool,
            volume_mounts=volume_mounts,
            external_integration=external_integration,
            cpu=cpu,
            gpu=gpu,
            memory=memory,
            run_time_limit=run_time_limit,
            env=env,
            attrs=attrs,
        )

        base_environment = [
            ("METAFLOW_CODE_SHA", code_package_sha),
            ("METAFLOW_CODE_URL", code_package_url),
            ("METAFLOW_CODE_DS", code_package_ds),
            ("METAFLOW_USER", attrs["metaflow.user"]),
            ("METAFLOW_SERVICE_URL", SERVICE_INTERNAL_URL),
            ("METAFLOW_SERVICE_HEADERS", json.dumps(SERVICE_HEADERS)),
            ("METAFLOW_DATASTORE_SYSROOT_S3", DATASTORE_SYSROOT_S3),
            ("METAFLOW_DATATOOLS_S3ROOT", DATATOOLS_S3ROOT),
            ("METAFLOW_DEFAULT_DATASTORE", self.datastore.TYPE),
            ("METAFLOW_DEFAULT_METADATA", DEFAULT_METADATA),
            ("METAFLOW_SNOWPARK_WORKLOAD", 1),
            ("METAFLOW_RUNTIME_ENVIRONMENT", "snowpark"),
            ("METAFLOW_INIT_SCRIPT", KUBERNETES_SANDBOX_INIT_SCRIPT),
            ("METAFLOW_OTEL_ENDPOINT", OTEL_ENDPOINT),
            (
                "SNOWFLAKE_WAREHOUSE",
                self.snowpark_client.connection_parameters.get("warehouse"),
            ),
        ]
        # Rebind on every call, exactly as a fluent `.environment_variable(...)` chain would.
        for var_name, var_value in base_environment:
            snowpark_job = snowpark_job.environment_variable(var_name, var_value)

        for key, value in config_values():
            if key.startswith(("METAFLOW_CONDA_", "METAFLOW_DEBUG_")):
                snowpark_job.environment_variable(key, value)

        # Settings that are only forwarded when configured.
        optional_environment = (
            ("METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE", DEFAULT_SECRETS_BACKEND_TYPE),
            (
                "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
                AWS_SECRETS_MANAGER_DEFAULT_REGION,
            ),
            ("METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION),
            ("METAFLOW_S3_ENDPOINT_URL", S3_ENDPOINT_URL),
        )
        for var_name, var_value in optional_environment:
            if var_value is not None:
                snowpark_job.environment_variable(var_name, var_value)

        # Caller-supplied variables go last.
        for var_name, var_value in env.items():
            snowpark_job.environment_variable(var_name, var_value)

        return snowpark_job

    def launch_job(self, **kwargs):
        """Create and start the job described by ``kwargs`` (see `create_job`); stores it on ``self.job``."""
        prepared = self.create_job(**kwargs)
        self.job = prepared.create().execute()

    def wait(self, stdout_location, stderr_location, echo=None):
        """
        Block until the launched job has finished, streaming its logs through ``echo``.

        Raises SnowparkException when the job reports failure and SnowparkKilledException
        when log tailing ended while the job still reports itself as running.
        """

        def update_delay(secs_since_start):
            # this sigmoid function reaches
            # - 0.1 after 11 minutes
            # - 0.5 after 15 minutes
            # - 1.0 after 23 minutes
            # in other words, the user will see very frequent updates
            # during the first 10 minutes
            sigmoid = 1.0 / (1.0 + math.exp(-0.01 * secs_since_start + 9.0))
            return 0.5 + sigmoid * 30.0

        def wait_for_launch(job):
            def announce(current_status):
                echo(
                    "Task is starting (%s)..." % current_status,
                    "stderr",
                    job_id=job.id,
                )

            status = job.status
            announce(status)
            last_echo = time.time()
            started_at = time.time()
            while job.is_waiting:
                latest_status = job.status
                # Report on every status change, and at least once every 30 seconds.
                if status != latest_status or (time.time() - last_echo) > 30:
                    status = latest_status
                    announce(status)
                    last_echo = time.time()
                time.sleep(update_delay(time.time() - started_at))

        stdout_tail = get_log_tailer(stdout_location, self.datastore.TYPE)
        stderr_tail = get_log_tailer(stderr_location, self.datastore.TYPE)

        # 1) Loop until the job has started
        wait_for_launch(self.job)

        # 2) Tail logs until the job has finished
        tail_logs(
            prefix=b"[%s] " % util.to_bytes(self.job.id),
            stdout_tail=stdout_tail,
            stderr_tail=stderr_tail,
            echo=echo,
            has_log_updates=lambda: self.job.is_running,
        )

        if self.job.has_failed:
            reason = self.job.message
            if reason is None:
                reason = "Task crashed."
            raise SnowparkException(
                "%s This could be a transient error. Use @retry to retry." % reason
            )

        if self.job.is_running:
            # Kill the job if it is still running by throwing an exception.
            raise SnowparkKilledException("Task failed!")

        echo(
            "Task finished with message '%s'." % self.job.message,
            "stderr",
            job_id=self.job.id,
        )