ob-metaflow-extensions 1.1.78__py2.py3-none-any.whl → 1.1.80__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ob-metaflow-extensions might be problematic.

Files changed (19)
  1. metaflow_extensions/outerbounds/config/__init__.py +28 -0
  2. metaflow_extensions/outerbounds/plugins/__init__.py +17 -3
  3. metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
  4. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +268 -0
  5. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +160 -0
  6. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
  7. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
  8. metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
  9. metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +299 -0
  10. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +271 -0
  11. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +123 -0
  12. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +264 -0
  13. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
  14. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +235 -0
  15. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
  16. {ob_metaflow_extensions-1.1.78.dist-info → ob_metaflow_extensions-1.1.80.dist-info}/METADATA +2 -2
  17. {ob_metaflow_extensions-1.1.78.dist-info → ob_metaflow_extensions-1.1.80.dist-info}/RECORD +19 -6
  18. {ob_metaflow_extensions-1.1.78.dist-info → ob_metaflow_extensions-1.1.80.dist-info}/WHEEL +0 -0
  19. {ob_metaflow_extensions-1.1.78.dist-info → ob_metaflow_extensions-1.1.80.dist-info}/top_level.txt +0 -0
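
Judging only by the new module names (snowpark_decorator.py, snowpark_cli.py, snowpark.py), this release appears to add a step-level Snowpark decorator alongside the usual Metaflow plumbing. Below is a minimal usage sketch, assuming the decorator is exposed as @snowpark and accepts arguments mirroring the CLI options in snowpark_cli.py (image, stage, compute_pool, cpu, memory); the import path and argument names are assumptions and are not confirmed by this diff.

    # Hypothetical usage sketch -- decorator name and arguments are inferred from
    # the CLI options in snowpark_cli.py, not confirmed by this diff.
    from metaflow import FlowSpec, step, snowpark  # 'snowpark' import path is assumed


    class HelloSnowparkFlow(FlowSpec):
        @snowpark(
            image="myrepo/myimage:latest",    # assumed: mirrors --image
            stage="MY_STAGE",                 # assumed: mirrors --stage
            compute_pool="MY_COMPUTE_POOL",   # assumed: mirrors --compute-pool
            cpu=2,
            memory=8000,
        )
        @step
        def start(self):
            print("running inside Snowpark Container Services")
            self.next(self.end)

        @step
        def end(self):
            pass


    if __name__ == "__main__":
        HelloSnowparkFlow()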
metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py
@@ -0,0 +1,299 @@
+ import os
+ import shlex
+ import atexit
+ import json
+ import math
+ import time
+
+ from metaflow import util
+
+ from metaflow.metaflow_config import (
+     SERVICE_INTERNAL_URL,
+     SERVICE_HEADERS,
+     DEFAULT_METADATA,
+     DATASTORE_SYSROOT_S3,
+     DATATOOLS_S3ROOT,
+     KUBERNETES_SANDBOX_INIT_SCRIPT,
+     OTEL_ENDPOINT,
+     DEFAULT_SECRETS_BACKEND_TYPE,
+     AWS_SECRETS_MANAGER_DEFAULT_REGION,
+     S3_SERVER_SIDE_ENCRYPTION,
+     S3_ENDPOINT_URL,
+ )
+ from metaflow.metaflow_config_funcs import config_values
+
+ from metaflow.mflog import (
+     export_mflog_env_vars,
+     bash_capture_logs,
+     BASH_SAVE_LOGS,
+     get_log_tailer,
+     tail_logs,
+ )
+
+ from .snowpark_client import SnowparkClient
+ from .snowpark_exceptions import SnowparkException
+ from .snowpark_job import SnowparkJob
+
+ # Redirect structured logs to $PWD/.logs/
+ LOGS_DIR = "$PWD/.logs"
+ STDOUT_FILE = "mflog_stdout"
+ STDERR_FILE = "mflog_stderr"
+ STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
+ STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)
+
+
+ class Snowpark(object):
+     def __init__(
+         self,
+         datastore,
+         metadata,
+         environment,
+         client_credentials,
+     ):
+         self.datastore = datastore
+         self.metadata = metadata
+         self.environment = environment
+         self.snowpark_client = SnowparkClient(**client_credentials)
+         atexit.register(lambda: self.job.kill() if hasattr(self, "job") else None)
+
+     def _job_name(self, user, flow_name, run_id, step_name, task_id, retry_count):
+         return "{user}-{flow_name}-{run_id}-{step_name}-{task_id}-{retry_count}".format(
+             user=user,
+             flow_name=flow_name,
+             run_id=str(run_id) if run_id is not None else "",
+             step_name=step_name,
+             task_id=str(task_id) if task_id is not None else "",
+             retry_count=str(retry_count) if retry_count is not None else "",
+         )
+
+     def _command(self, environment, code_package_url, step_name, step_cmds, task_spec):
+         mflog_expr = export_mflog_env_vars(
+             datastore_type=self.datastore.TYPE,
+             stdout_path=STDOUT_PATH,
+             stderr_path=STDERR_PATH,
+             **task_spec
+         )
+         init_cmds = environment.get_package_commands(
+             code_package_url, self.datastore.TYPE
+         )
+         init_expr = " && ".join(init_cmds)
+         step_expr = bash_capture_logs(
+             " && ".join(
+                 environment.bootstrap_commands(step_name, self.datastore.TYPE)
+                 + step_cmds
+             )
+         )
+
+         # construct an entry point that
+         # 1) initializes the mflog environment (mflog_expr)
+         # 2) bootstraps a metaflow environment (init_expr)
+         # 3) executes a task (step_expr)
+
+         # the `true` command is to make sure that the generated command
+         # plays well with docker containers which have entrypoint set as
+         # eval $@
+         cmd_str = "true && mkdir -p %s && %s && %s && %s; " % (
+             LOGS_DIR,
+             mflog_expr,
+             init_expr,
+             step_expr,
+         )
+         # after the task has finished, we save its exit code (fail/success)
+         # and persist the final logs. The whole entrypoint should exit
+         # with the exit code (c) of the task.
+         #
+         # Note that if step_expr OOMs, this tail expression is never executed.
+         # We lose the last logs in this scenario.
+         cmd_str += "c=$?; %s; exit $c" % BASH_SAVE_LOGS
+         # For supporting sandboxes, ensure that a custom script is executed before
+         # anything else is executed. The script is passed in as an env var.
+         cmd_str = (
+             '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
+             % cmd_str
+         )
+         return shlex.split('bash -c "%s"' % cmd_str)
+
+     def create_job(
+         self,
+         step_name,
+         step_cli,
+         task_spec,
+         code_package_sha,
+         code_package_url,
+         code_package_ds,
+         image=None,
+         stage=None,
+         compute_pool=None,
+         volume_mounts=None,
+         external_integration=None,
+         cpu=None,
+         gpu=None,
+         memory=None,
+         run_time_limit=None,
+         env=None,
+         attrs=None,
+     ) -> SnowparkJob:
+         if env is None:
+             env = {}
+         if attrs is None:
+             attrs = {}
+
+         job_name = self._job_name(
+             attrs.get("metaflow.user"),
+             attrs.get("metaflow.flow_name"),
+             attrs.get("metaflow.run_id"),
+             attrs.get("metaflow.step_name"),
+             attrs.get("metaflow.task_id"),
+             attrs.get("metaflow.retry_count"),
+         )
+
+         snowpark_job = (
+             SnowparkJob(
+                 client=self.snowpark_client,
+                 name=job_name,
+                 command=self._command(
+                     self.environment, code_package_url, step_name, [step_cli], task_spec
+                 ),
+                 step_name=step_name,
+                 step_cli=step_cli,
+                 task_spec=task_spec,
+                 code_package_sha=code_package_sha,
+                 code_package_url=code_package_url,
+                 code_package_ds=code_package_ds,
+                 image=image,
+                 stage=stage,
+                 compute_pool=compute_pool,
+                 volume_mounts=volume_mounts,
+                 external_integration=external_integration,
+                 cpu=cpu,
+                 gpu=gpu,
+                 memory=memory,
+                 run_time_limit=run_time_limit,
+                 env=env,
+                 attrs=attrs,
+             )
+             .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
+             .environment_variable("METAFLOW_CODE_URL", code_package_url)
+             .environment_variable("METAFLOW_CODE_DS", code_package_ds)
+             .environment_variable("METAFLOW_USER", attrs["metaflow.user"])
+             .environment_variable("METAFLOW_SERVICE_URL", SERVICE_INTERNAL_URL)
+             .environment_variable(
+                 "METAFLOW_SERVICE_HEADERS", json.dumps(SERVICE_HEADERS)
+             )
+             .environment_variable("METAFLOW_DATASTORE_SYSROOT_S3", DATASTORE_SYSROOT_S3)
+             .environment_variable("METAFLOW_DATATOOLS_S3ROOT", DATATOOLS_S3ROOT)
+             .environment_variable("METAFLOW_DEFAULT_DATASTORE", self.datastore.TYPE)
+             .environment_variable("METAFLOW_DEFAULT_METADATA", DEFAULT_METADATA)
+             .environment_variable("METAFLOW_SNOWPARK_WORKLOAD", 1)
+             .environment_variable("METAFLOW_RUNTIME_ENVIRONMENT", "snowpark")
+             .environment_variable(
+                 "METAFLOW_INIT_SCRIPT", KUBERNETES_SANDBOX_INIT_SCRIPT
+             )
+             .environment_variable("METAFLOW_OTEL_ENDPOINT", OTEL_ENDPOINT)
+             .environment_variable(
+                 "SNOWFLAKE_WAREHOUSE",
+                 self.snowpark_client.connection_parameters.get("warehouse"),
+             )
+         )
+
+         for k, v in config_values():
+             if k.startswith("METAFLOW_CONDA_") or k.startswith("METAFLOW_DEBUG_"):
+                 snowpark_job.environment_variable(k, v)
+
+         if DEFAULT_SECRETS_BACKEND_TYPE is not None:
+             snowpark_job.environment_variable(
+                 "METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE", DEFAULT_SECRETS_BACKEND_TYPE
+             )
+         if AWS_SECRETS_MANAGER_DEFAULT_REGION is not None:
+             snowpark_job.environment_variable(
+                 "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
+                 AWS_SECRETS_MANAGER_DEFAULT_REGION,
+             )
+         if S3_SERVER_SIDE_ENCRYPTION is not None:
+             snowpark_job.environment_variable(
+                 "METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION
+             )
+         if S3_ENDPOINT_URL is not None:
+             snowpark_job.environment_variable(
+                 "METAFLOW_S3_ENDPOINT_URL", S3_ENDPOINT_URL
+             )
+
+         for name, value in env.items():
+             snowpark_job.environment_variable(name, value)
+
+         return snowpark_job
+
+     def launch_job(self, **kwargs):
+         self.job = self.create_job(**kwargs).create().execute()
+
+     def wait(self, stdout_location, stderr_location, echo=None):
+         def update_delay(secs_since_start):
+             # this sigmoid function reaches
+             # - 0.1 after 11 minutes
+             # - 0.5 after 15 minutes
+             # - 1.0 after 23 minutes
+             # in other words, the user will see very frequent updates
+             # during the first 10 minutes
+             sigmoid = 1.0 / (1.0 + math.exp(-0.01 * secs_since_start + 9.0))
+             return 0.5 + sigmoid * 30.0
+
+         def wait_for_launch(job):
+             status = job.status
+
+             echo(
+                 "Task is starting (%s)..." % status,
+                 "stderr",
+                 job_id=job.id,
+             )
+             t = time.time()
+             start_time = time.time()
+             while job.is_waiting:
+                 new_status = job.status
+                 if status != new_status or (time.time() - t) > 30:
+                     status = new_status
+                     echo(
+                         "Task is starting (%s)..." % status,
+                         "stderr",
+                         job_id=job.id,
+                     )
+                     t = time.time()
+                 time.sleep(update_delay(time.time() - start_time))
+
+         _make_prefix = lambda: b"[%s] " % util.to_bytes(self.job.id)
+
+         stdout_tail = get_log_tailer(stdout_location, self.datastore.TYPE)
+         stderr_tail = get_log_tailer(stderr_location, self.datastore.TYPE)
+
+         # 1) Loop until the job has started
+         wait_for_launch(self.job)
+
+         # 2) Tail logs until the job has finished
+         tail_logs(
+             prefix=_make_prefix(),
+             stdout_tail=stdout_tail,
+             stderr_tail=stderr_tail,
+             echo=echo,
+             has_log_updates=lambda: self.job.is_running,
+         )
+
+         if self.job.has_failed:
+             msg = next(
+                 msg
+                 for msg in [
+                     self.job.message,
+                     "Task crashed.",
+                 ]
+                 if msg is not None
+             )
+             raise SnowparkException(
+                 "%s " "This could be a transient error. " "Use @retry to retry." % msg
+             )
+         else:
+             if self.job.is_running:
+                 # Kill the job if it is still running by throwing an exception.
+                 raise SnowparkException("Task failed!")
+             echo(
+                 "Task finished with message '%s'." % self.job.message,
+                 "stderr",
+                 job_id=self.job.id,
+             )
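
The update_delay helper in snowpark.py above documents specific checkpoints for its sigmoid (roughly 0.1 after 11 minutes, 0.5 after 15, 1.0 after 23). Below is a small standalone sketch that re-evaluates the same formula at a few elapsed times, handy for sanity-checking the polling cadence; the function body is copied from the diff, the driver loop is illustrative.

    import math

    def update_delay(secs_since_start):
        # same formula as Snowpark.wait() in snowpark.py above
        sigmoid = 1.0 / (1.0 + math.exp(-0.01 * secs_since_start + 9.0))
        return 0.5 + sigmoid * 30.0

    for minutes in (1, 5, 11, 15, 23, 30):
        print(minutes, round(update_delay(minutes * 60), 2))
    # prints approximately: 1 -> 0.51, 5 -> 0.57, 11 -> 3.0, 15 -> 15.5, 23 -> 30.26, 30 -> 30.5
    # i.e. the status poll interval stays near half a second for the first ~10 minutes
    # and ramps up to roughly 30 seconds after ~20 minutes.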
metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py
@@ -0,0 +1,271 @@
+ import os
+ import sys
+ import time
+ import traceback
+
+ from metaflow._vendor import click
+ from metaflow import util
+ from metaflow import R
+ from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY
+ from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
+ from metaflow.metadata.util import sync_local_metadata_from_datastore
+ from metaflow.mflog import TASK_LOG_SOURCE
+
+ from .snowpark_exceptions import SnowparkKilledException
+ from .snowpark import Snowpark
+
+
+ @click.group()
+ def cli():
+     pass
+
+
+ @cli.group(help="Commands related to Snowpark.")
+ def snowpark():
+     pass
+
+
+ @snowpark.command(
+     help="Execute a single task using Snowpark. This command calls the "
+     "top-level step command inside a Snowpark job with the given options. "
+     "Typically you do not call this command directly; it is used internally by "
+     "Metaflow."
+ )
+ @click.argument("step-name")
+ @click.argument("code-package-sha")
+ @click.argument("code-package-url")
+ @click.option(
+     "--executable", help="Executable requirement for Snowpark Container Services."
+ )
+ # below are snowpark specific options
+ @click.option(
+     "--account",
+     help="Account ID for Snowpark Container Services.",
+ )
+ @click.option(
+     "--user",
+     help="Username for Snowpark Container Services.",
+ )
+ @click.option(
+     "--password",
+     help="Password for Snowpark Container Services.",
+ )
+ @click.option(
+     "--role",
+     help="Role for Snowpark Container Services.",
+ )
+ @click.option(
+     "--database",
+     help="Database for Snowpark Container Services.",
+ )
+ @click.option(
+     "--warehouse",
+     help="Warehouse for Snowpark Container Services.",
+ )
+ @click.option(
+     "--schema",
+     help="Schema for Snowpark Container Services.",
+ )
+ @click.option(
+     "--image",
+     help="Docker image requirement for Snowpark Container Services. In name:version format.",
+ )
+ @click.option("--stage", help="Stage requirement for Snowpark Container Services.")
+ @click.option(
+     "--compute-pool", help="Compute Pool requirement for Snowpark Container Services."
+ )
+ @click.option("--volume-mounts", multiple=True)
+ @click.option(
+     "--external-integration",
+     multiple=True,
+     help="External Integration requirement for Snowpark Container Services.",
+ )
+ @click.option("--cpu", help="CPU requirement for Snowpark Container Services.")
+ @click.option("--gpu", help="GPU requirement for Snowpark Container Services.")
+ @click.option("--memory", help="Memory requirement for Snowpark Container Services.")
+ # TODO: secrets, volumes for snowpark
+ # TODO: others to consider: ubf-context, num-parallel
+ @click.option("--run-id", help="Passed to the top-level 'step'.")
+ @click.option("--task-id", help="Passed to the top-level 'step'.")
+ @click.option("--input-paths", help="Passed to the top-level 'step'.")
+ @click.option("--split-index", help="Passed to the top-level 'step'.")
+ @click.option("--clone-path", help="Passed to the top-level 'step'.")
+ @click.option("--clone-run-id", help="Passed to the top-level 'step'.")
+ @click.option(
+     "--tag", multiple=True, default=None, help="Passed to the top-level 'step'."
+ )
+ @click.option("--namespace", default=None, help="Passed to the top-level 'step'.")
+ @click.option("--retry-count", default=0, help="Passed to the top-level 'step'.")
+ @click.option(
+     "--max-user-code-retries", default=0, help="Passed to the top-level 'step'."
+ )
+ # TODO: this is not used anywhere as of now...
+ @click.option(
+     "--run-time-limit",
+     default=5 * 24 * 60 * 60,  # Default is set to 5 days
+     help="Run time limit in seconds for Snowpark container.",
+ )
+ @click.pass_context
+ def step(
+     ctx,
+     step_name,
+     code_package_sha,
+     code_package_url,
+     executable=None,
+     account=None,
+     user=None,
+     password=None,
+     role=None,
+     database=None,
+     warehouse=None,
+     schema=None,
+     image=None,
+     stage=None,
+     compute_pool=None,
+     volume_mounts=None,
+     external_integration=None,
+     cpu=None,
+     gpu=None,
+     memory=None,
+     run_time_limit=None,
+     **kwargs
+ ):
+     def echo(msg, stream="stderr", job_id=None, **kwargs):
+         msg = util.to_unicode(msg)
+         if job_id:
+             msg = "[%s] %s" % (job_id, msg)
+         ctx.obj.echo_always(msg, err=(stream == sys.stderr), **kwargs)
+
+     if R.use_r():
+         entrypoint = R.entrypoint()
+     else:
+         executable = ctx.obj.environment.executable(step_name, executable)
+         entrypoint = "%s -u %s" % (executable, os.path.basename(sys.argv[0]))
+
+     top_args = " ".join(util.dict_to_cli_options(ctx.parent.parent.params))
+
+     input_paths = kwargs.get("input_paths")
+     split_vars = None
+     if input_paths:
+         max_size = 30 * 1024
+         split_vars = {
+             "METAFLOW_INPUT_PATHS_%d" % (i // max_size): input_paths[i : i + max_size]
+             for i in range(0, len(input_paths), max_size)
+         }
+         kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())
+
+     step_args = " ".join(util.dict_to_cli_options(kwargs))
+
+     step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
+         entrypoint=entrypoint,
+         top_args=top_args,
+         step=step_name,
+         step_args=step_args,
+     )
+     node = ctx.obj.graph[step_name]
+
+     # Get retry information
+     retry_count = kwargs.get("retry_count", 0)
+     retry_deco = [deco for deco in node.decorators if deco.name == "retry"]
+     minutes_between_retries = None
+     if retry_deco:
+         minutes_between_retries = int(
+             retry_deco[0].attributes.get("minutes_between_retries", 1)
+         )
+     if retry_count:
+         ctx.obj.echo_always(
+             "Sleeping %d minutes before the next retry" % minutes_between_retries
+         )
+         time.sleep(minutes_between_retries * 60)
+
+     # Set batch attributes
+     task_spec = {
+         "flow_name": ctx.obj.flow.name,
+         "step_name": step_name,
+         "run_id": kwargs["run_id"],
+         "task_id": kwargs["task_id"],
+         "retry_count": str(retry_count),
+     }
+     attrs = {"metaflow.%s" % k: v for k, v in task_spec.items()}
+     attrs["metaflow.user"] = util.get_username()
+     attrs["metaflow.version"] = ctx.obj.environment.get_environment_info()[
+         "metaflow_version"
+     ]
+
+     # Set environment
+     env = {}
+     env_deco = [deco for deco in node.decorators if deco.name == "environment"]
+     if env_deco:
+         env = env_deco[0].attributes["vars"]
+
+     # Add the environment variables related to the input-paths argument
+     if split_vars:
+         env.update(split_vars)
+
+     # Set log tailing.
+     ds = ctx.obj.flow_datastore.get_task_datastore(
+         mode="w",
+         run_id=kwargs["run_id"],
+         step_name=step_name,
+         task_id=kwargs["task_id"],
+         attempt=int(retry_count),
+     )
+     stdout_location = ds.get_log_location(TASK_LOG_SOURCE, "stdout")
+     stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr")
+
+     def _sync_metadata():
+         if ctx.obj.metadata.TYPE == "local":
+             sync_local_metadata_from_datastore(
+                 DATASTORE_LOCAL_DIR,
+                 ctx.obj.flow_datastore.get_task_datastore(
+                     kwargs["run_id"], step_name, kwargs["task_id"]
+                 ),
+             )
+
+     try:
+         snowpark = Snowpark(
+             datastore=ctx.obj.flow_datastore,
+             metadata=ctx.obj.metadata,
+             environment=ctx.obj.environment,
+             client_credentials={
+                 "account": account,
+                 "user": user,
+                 "password": password,
+                 "role": role,
+                 "database": database,
+                 "warehouse": warehouse,
+                 "schema": schema,
+             },
+         )
+         with ctx.obj.monitor.measure("metaflow.snowpark.launch_job"):
+             snowpark.launch_job(
+                 step_name=step_name,
+                 step_cli=step_cli,
+                 task_spec=task_spec,
+                 code_package_sha=code_package_sha,
+                 code_package_url=code_package_url,
+                 code_package_ds=ctx.obj.flow_datastore.TYPE,
+                 image=image,
+                 stage=stage,
+                 compute_pool=compute_pool,
+                 volume_mounts=volume_mounts,
+                 external_integration=external_integration,
+                 cpu=cpu,
+                 gpu=gpu,
+                 memory=memory,
+                 run_time_limit=run_time_limit,
+                 env=env,
+                 attrs=attrs,
+             )
+     except Exception:
+         traceback.print_exc(chain=False)
+         _sync_metadata()
+         sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
+     try:
+         snowpark.wait(stdout_location, stderr_location, echo=echo)
+     except SnowparkKilledException:
+         # don't retry killed tasks
+         traceback.print_exc()
+         sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
+     finally:
+         _sync_metadata()
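
One detail in the step command above worth calling out: long --input-paths values are split into METAFLOW_INPUT_PATHS_<n> environment variables in 30 KB chunks, and the option itself is rewritten as a concatenation of ${...} references that the shell re-expands inside the container (the chunks are later added to env via split_vars). Below is a standalone sketch of the same chunking with a tiny max_size so the split is visible; the sample path string is made up.

    # Illustration of the --input-paths chunking in snowpark_cli.py above,
    # using a small max_size so the split is visible (the real code uses 30 * 1024).
    input_paths = "step-a/1/123,step-b/1/456,step-c/1/789"
    max_size = 10  # illustration only

    split_vars = {
        "METAFLOW_INPUT_PATHS_%d" % (i // max_size): input_paths[i : i + max_size]
        for i in range(0, len(input_paths), max_size)
    }
    rewritten = "".join("${%s}" % s for s in split_vars.keys())

    print(split_vars)
    # {'METAFLOW_INPUT_PATHS_0': 'step-a/1/1', 'METAFLOW_INPUT_PATHS_1': '23,step-b/',
    #  'METAFLOW_INPUT_PATHS_2': '1/456,step', 'METAFLOW_INPUT_PATHS_3': '-c/1/789'}
    print(rewritten)
    # ${METAFLOW_INPUT_PATHS_0}${METAFLOW_INPUT_PATHS_1}${METAFLOW_INPUT_PATHS_2}${METAFLOW_INPUT_PATHS_3}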
metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py
@@ -0,0 +1,123 @@
+ import sys
+ import tempfile
+
+ from .snowpark_exceptions import SnowflakeException
+ from .snowpark_service_spec import generate_spec_file
+
+ from metaflow.exception import MetaflowException
+
+
+ class SnowparkClient(object):
+     def __init__(
+         self,
+         account: str,
+         user: str,
+         password: str,
+         role: str,
+         database: str,
+         warehouse: str,
+         schema: str,
+         autocommit: bool = True,
+     ):
+         try:
+             from snowflake.core import Root
+             from snowflake.snowpark import Session
+
+             from snowflake.connector.errors import DatabaseError
+         except (NameError, ImportError, ModuleNotFoundError):
+             raise SnowflakeException(
+                 "Could not import module 'snowflake'.\n\nInstall Snowflake "
+                 "Python package (https://pypi.org/project/snowflake/) first.\n"
+                 "You can install the module by executing - "
+                 "%s -m pip install snowflake\n"
+                 "or equivalent through your favorite Python package manager."
+                 % sys.executable
+             )
+
+         self.connection_parameters = {
+             "account": account,
+             "user": user,
+             "password": password,
+             "role": role,
+             "warehouse": warehouse,
+             "database": database,
+             "schema": schema,
+             "autocommit": autocommit,
+         }
+
+         try:
+             self.session = Session.builder.configs(self.connection_parameters).create()
+             self.root = Root(self.session)
+         except DatabaseError as e:
+             raise ConnectionError(e)
+
+     def __del__(self):
+         if hasattr(self, "session"):
+             self.session.close()
+
+     def __check_existence_of_stage_and_compute_pool(
+         self, db, schema, stage, compute_pool
+     ):
+         from snowflake.core.exceptions import NotFoundError
+
+         # check if stage exists, will raise an error otherwise
+         try:
+             self.root.databases[db].schemas[schema].stages[stage].fetch()
+         except NotFoundError:
+             raise MetaflowException(
+                 "Stage *%s* does not exist or not authorized." % stage
+             )
+
+         # check if compute_pool exists, will raise an error otherwise
+         try:
+             self.root.compute_pools[compute_pool].fetch()
+         except NotFoundError:
+             raise MetaflowException(
+                 "Compute pool *%s* does not exist or not authorized." % compute_pool
+             )
+
+     def submit(self, name: str, spec, stage, compute_pool, external_integration):
+         db = self.session.get_current_database()
+         schema = self.session.get_current_schema()
+
+         with tempfile.TemporaryDirectory() as temp_dir:
+             snowpark_spec_file = tempfile.NamedTemporaryFile(
+                 dir=temp_dir, delete=False, suffix=".yaml"
+             )
+             generate_spec_file(spec, snowpark_spec_file.name, format="yaml")
+
+             self.__check_existence_of_stage_and_compute_pool(
+                 db, schema, stage, compute_pool
+             )
+
+             # upload the spec file to stage
+             result = self.session.file.put(snowpark_spec_file.name, "@%s" % stage)
+
+             service_name = name.replace("-", "_")
+             external_access = (
+                 "EXTERNAL_ACCESS_INTEGRATIONS=(%s) " % ",".join(external_integration)
+                 if external_integration
+                 else ""
+             )
+
+             # cannot pass 'is_job' parameter using the API, thus we need to use SQL directly..
+             query = """
+             EXECUTE JOB SERVICE IN COMPUTE POOL {compute_pool}
+             NAME = {db}.{schema}.{service_name}
+             {external_access}
+             FROM @{stage} SPECIFICATION_FILE={specification_file}
+             """.format(
+                 compute_pool=compute_pool,
+                 db=db,
+                 schema=schema,
+                 service_name=service_name,
+                 external_access=external_access,
+                 stage=stage,
+                 specification_file=result[0].target,
+             )
+
+             async_job = self.session.sql(query).collect(block=False)
+             return async_job.query_id, service_name
+
+     def terminate_job(self, service):
+         service.delete()
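
Because the Snowflake Python API does not expose the 'is_job' behaviour (per the comment above), SnowparkClient.submit() renders an EXECUTE JOB SERVICE statement by hand and runs it through session.sql(). Below is a sketch of what the rendered statement looks like, with made-up identifiers; in the real code the specification_file value comes from the session.file.put() result, and the staged name may end in .gz if the upload is compressed.

    # Placeholder values only; MY_POOL, MY_DB, etc. do not come from the diff.
    query_template = """
    EXECUTE JOB SERVICE IN COMPUTE POOL {compute_pool}
    NAME = {db}.{schema}.{service_name}
    {external_access}
    FROM @{stage} SPECIFICATION_FILE={specification_file}
    """

    print(
        query_template.format(
            compute_pool="MY_POOL",
            db="MY_DB",
            schema="PUBLIC",
            service_name="user_myflow_123_start_456_0",  # job name with '-' replaced by '_'
            external_access="EXTERNAL_ACCESS_INTEGRATIONS=(MY_INTEGRATION) ",
            stage="MY_STAGE",
            specification_file="spec_abc123.yaml.gz",  # assumed staged file name
        )
    )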