ob-metaflow-extensions 1.1.79__py2.py3-none-any.whl → 1.1.81__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ob-metaflow-extensions might be problematic.
- metaflow_extensions/outerbounds/config/__init__.py +28 -0
- metaflow_extensions/outerbounds/plugins/__init__.py +17 -3
- metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +268 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +160 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
- metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +299 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +271 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +123 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +264 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +235 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
- {ob_metaflow_extensions-1.1.79.dist-info → ob_metaflow_extensions-1.1.81.dist-info}/METADATA +2 -2
- {ob_metaflow_extensions-1.1.79.dist-info → ob_metaflow_extensions-1.1.81.dist-info}/RECORD +19 -6
- {ob_metaflow_extensions-1.1.79.dist-info → ob_metaflow_extensions-1.1.81.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.79.dist-info → ob_metaflow_extensions-1.1.81.dist-info}/top_level.txt +0 -0
metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py
@@ -0,0 +1,299 @@
import os
import shlex
import atexit
import json
import math
import time

from metaflow import util

from metaflow.metaflow_config import (
    SERVICE_INTERNAL_URL,
    SERVICE_HEADERS,
    DEFAULT_METADATA,
    DATASTORE_SYSROOT_S3,
    DATATOOLS_S3ROOT,
    KUBERNETES_SANDBOX_INIT_SCRIPT,
    OTEL_ENDPOINT,
    DEFAULT_SECRETS_BACKEND_TYPE,
    AWS_SECRETS_MANAGER_DEFAULT_REGION,
    S3_SERVER_SIDE_ENCRYPTION,
    S3_ENDPOINT_URL,
)
from metaflow.metaflow_config_funcs import config_values

from metaflow.mflog import (
    export_mflog_env_vars,
    bash_capture_logs,
    BASH_SAVE_LOGS,
    get_log_tailer,
    tail_logs,
)

from .snowpark_client import SnowparkClient
from .snowpark_exceptions import SnowparkException
from .snowpark_job import SnowparkJob

# Redirect structured logs to $PWD/.logs/
LOGS_DIR = "$PWD/.logs"
STDOUT_FILE = "mflog_stdout"
STDERR_FILE = "mflog_stderr"
STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)


class Snowpark(object):
    def __init__(
        self,
        datastore,
        metadata,
        environment,
        client_credentials,
    ):
        self.datastore = datastore
        self.metadata = metadata
        self.environment = environment
        self.snowpark_client = SnowparkClient(**client_credentials)
        atexit.register(lambda: self.job.kill() if hasattr(self, "job") else None)

    def _job_name(self, user, flow_name, run_id, step_name, task_id, retry_count):
        return "{user}-{flow_name}-{run_id}-{step_name}-{task_id}-{retry_count}".format(
            user=user,
            flow_name=flow_name,
            run_id=str(run_id) if run_id is not None else "",
            step_name=step_name,
            task_id=str(task_id) if task_id is not None else "",
            retry_count=str(retry_count) if retry_count is not None else "",
        )

    def _command(self, environment, code_package_url, step_name, step_cmds, task_spec):
        mflog_expr = export_mflog_env_vars(
            datastore_type=self.datastore.TYPE,
            stdout_path=STDOUT_PATH,
            stderr_path=STDERR_PATH,
            **task_spec
        )
        init_cmds = environment.get_package_commands(
            code_package_url, self.datastore.TYPE
        )
        init_expr = " && ".join(init_cmds)
        step_expr = bash_capture_logs(
            " && ".join(
                environment.bootstrap_commands(step_name, self.datastore.TYPE)
                + step_cmds
            )
        )

        # construct an entry point that
        # 1) initializes the mflog environment (mflog_expr)
        # 2) bootstraps a metaflow environment (init_expr)
        # 3) executes a task (step_expr)

        # the `true` command is to make sure that the generated command
        # plays well with docker containers which have entrypoint set as
        # eval $@
        cmd_str = "true && mkdir -p %s && %s && %s && %s; " % (
            LOGS_DIR,
            mflog_expr,
            init_expr,
            step_expr,
        )
        # after the task has finished, we save its exit code (fail/success)
        # and persist the final logs. The whole entrypoint should exit
        # with the exit code (c) of the task.
        #
        # Note that if step_expr OOMs, this tail expression is never executed.
        # We lose the last logs in this scenario.
        cmd_str += "c=$?; %s; exit $c" % BASH_SAVE_LOGS
        # For supporting sandboxes, ensure that a custom script is executed before
        # anything else is executed. The script is passed in as an env var.
        cmd_str = (
            '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
            % cmd_str
        )
        return shlex.split('bash -c "%s"' % cmd_str)

    def create_job(
        self,
        step_name,
        step_cli,
        task_spec,
        code_package_sha,
        code_package_url,
        code_package_ds,
        image=None,
        stage=None,
        compute_pool=None,
        volume_mounts=None,
        external_integration=None,
        cpu=None,
        gpu=None,
        memory=None,
        run_time_limit=None,
        env=None,
        attrs=None,
    ) -> SnowparkJob:
        if env is None:
            env = {}
        if attrs is None:
            attrs = {}

        job_name = self._job_name(
            attrs.get("metaflow.user"),
            attrs.get("metaflow.flow_name"),
            attrs.get("metaflow.run_id"),
            attrs.get("metaflow.step_name"),
            attrs.get("metaflow.task_id"),
            attrs.get("metaflow.retry_count"),
        )

        snowpark_job = (
            SnowparkJob(
                client=self.snowpark_client,
                name=job_name,
                command=self._command(
                    self.environment, code_package_url, step_name, [step_cli], task_spec
                ),
                step_name=step_name,
                step_cli=step_cli,
                task_spec=task_spec,
                code_package_sha=code_package_sha,
                code_package_url=code_package_url,
                code_package_ds=code_package_ds,
                image=image,
                stage=stage,
                compute_pool=compute_pool,
                volume_mounts=volume_mounts,
                external_integration=external_integration,
                cpu=cpu,
                gpu=gpu,
                memory=memory,
                run_time_limit=run_time_limit,
                env=env,
                attrs=attrs,
            )
            .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
            .environment_variable("METAFLOW_CODE_URL", code_package_url)
            .environment_variable("METAFLOW_CODE_DS", code_package_ds)
            .environment_variable("METAFLOW_USER", attrs["metaflow.user"])
            .environment_variable("METAFLOW_SERVICE_URL", SERVICE_INTERNAL_URL)
            .environment_variable(
                "METAFLOW_SERVICE_HEADERS", json.dumps(SERVICE_HEADERS)
            )
            .environment_variable("METAFLOW_DATASTORE_SYSROOT_S3", DATASTORE_SYSROOT_S3)
            .environment_variable("METAFLOW_DATATOOLS_S3ROOT", DATATOOLS_S3ROOT)
            .environment_variable("METAFLOW_DEFAULT_DATASTORE", self.datastore.TYPE)
            .environment_variable("METAFLOW_DEFAULT_METADATA", DEFAULT_METADATA)
            .environment_variable("METAFLOW_SNOWPARK_WORKLOAD", 1)
            .environment_variable("METAFLOW_RUNTIME_ENVIRONMENT", "snowpark")
            .environment_variable(
                "METAFLOW_INIT_SCRIPT", KUBERNETES_SANDBOX_INIT_SCRIPT
            )
            .environment_variable("METAFLOW_OTEL_ENDPOINT", OTEL_ENDPOINT)
            .environment_variable(
                "SNOWFLAKE_WAREHOUSE",
                self.snowpark_client.connection_parameters.get("warehouse"),
            )
        )

        for k, v in config_values():
            if k.startswith("METAFLOW_CONDA_") or k.startswith("METAFLOW_DEBUG_"):
                snowpark_job.environment_variable(k, v)

        if DEFAULT_SECRETS_BACKEND_TYPE is not None:
            snowpark_job.environment_variable(
                "METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE", DEFAULT_SECRETS_BACKEND_TYPE
            )
        if AWS_SECRETS_MANAGER_DEFAULT_REGION is not None:
            snowpark_job.environment_variable(
                "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
                AWS_SECRETS_MANAGER_DEFAULT_REGION,
            )
        if S3_SERVER_SIDE_ENCRYPTION is not None:
            snowpark_job.environment_variable(
                "METAFLOW_S3_SERVER_SIDE_ENCRYPTION", S3_SERVER_SIDE_ENCRYPTION
            )
        if S3_ENDPOINT_URL is not None:
            snowpark_job.environment_variable(
                "METAFLOW_S3_ENDPOINT_URL", S3_ENDPOINT_URL
            )

        for name, value in env.items():
            snowpark_job.environment_variable(name, value)

        return snowpark_job

    def launch_job(self, **kwargs):
        self.job = self.create_job(**kwargs).create().execute()

    def wait(self, stdout_location, stderr_location, echo=None):
        def update_delay(secs_since_start):
            # this sigmoid function reaches
            # - 0.1 after 11 minutes
            # - 0.5 after 15 minutes
            # - 1.0 after 23 minutes
            # in other words, the user will see very frequent updates
            # during the first 10 minutes
            sigmoid = 1.0 / (1.0 + math.exp(-0.01 * secs_since_start + 9.0))
            return 0.5 + sigmoid * 30.0

        def wait_for_launch(job):
            status = job.status

            echo(
                "Task is starting (%s)..." % status,
                "stderr",
                job_id=job.id,
            )
            t = time.time()
            start_time = time.time()
            while job.is_waiting:
                new_status = job.status
                if status != new_status or (time.time() - t) > 30:
                    status = new_status
                    echo(
                        "Task is starting (%s)..." % status,
                        "stderr",
                        job_id=job.id,
                    )
                    t = time.time()
                time.sleep(update_delay(time.time() - start_time))

        _make_prefix = lambda: b"[%s] " % util.to_bytes(self.job.id)

        stdout_tail = get_log_tailer(stdout_location, self.datastore.TYPE)
        stderr_tail = get_log_tailer(stderr_location, self.datastore.TYPE)

        # 1) Loop until the job has started
        wait_for_launch(self.job)

        # 2) Tail logs until the job has finished
        tail_logs(
            prefix=_make_prefix(),
            stdout_tail=stdout_tail,
            stderr_tail=stderr_tail,
            echo=echo,
            has_log_updates=lambda: self.job.is_running,
        )

        if self.job.has_failed:
            msg = next(
                msg
                for msg in [
                    self.job.message,
                    "Task crashed.",
                ]
                if msg is not None
            )
            raise SnowparkException(
                "%s " "This could be a transient error. " "Use @retry to retry." % msg
            )
        else:
            if self.job.is_running:
                # Kill the job if it is still running by throwing an exception.
                raise SnowparkException("Task failed!")
            echo(
                "Task finished with message '%s'." % self.job.message,
                "stderr",
                job_id=self.job.id,
            )
metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py
@@ -0,0 +1,271 @@
import os
import sys
import time
import traceback

from metaflow._vendor import click
from metaflow import util
from metaflow import R
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.metadata.util import sync_local_metadata_from_datastore
from metaflow.mflog import TASK_LOG_SOURCE

from .snowpark_exceptions import SnowparkKilledException
from .snowpark import Snowpark


@click.group()
def cli():
    pass


@cli.group(help="Commands related to Snowpark.")
def snowpark():
    pass


@snowpark.command(
    help="Execute a single task using Snowpark. This command calls the "
    "top-level step command inside a Snowpark job with the given options. "
    "Typically you do not call this command directly; it is used internally by "
    "Metaflow."
)
@click.argument("step-name")
@click.argument("code-package-sha")
@click.argument("code-package-url")
@click.option(
    "--executable", help="Executable requirement for Snowpark Container Services."
)
# below are snowpark specific options
@click.option(
    "--account",
    help="Account ID for Snowpark Container Services.",
)
@click.option(
    "--user",
    help="Username for Snowpark Container Services.",
)
@click.option(
    "--password",
    help="Password for Snowpark Container Services.",
)
@click.option(
    "--role",
    help="Role for Snowpark Container Services.",
)
@click.option(
    "--database",
    help="Database for Snowpark Container Services.",
)
@click.option(
    "--warehouse",
    help="Warehouse for Snowpark Container Services.",
)
@click.option(
    "--schema",
    help="Schema for Snowpark Container Services.",
)
@click.option(
    "--image",
    help="Docker image requirement for Snowpark Container Services. In name:version format.",
)
@click.option("--stage", help="Stage requirement for Snowpark Container Services.")
@click.option(
    "--compute-pool", help="Compute Pool requirement for Snowpark Container Services."
)
@click.option("--volume-mounts", multiple=True)
@click.option(
    "--external-integration",
    multiple=True,
    help="External Integration requirement for Snowpark Container Services.",
)
@click.option("--cpu", help="CPU requirement for Snowpark Container Services.")
@click.option("--gpu", help="GPU requirement for Snowpark Container Services.")
@click.option("--memory", help="Memory requirement for Snowpark Container Services.")
# TODO: secrets, volumes for snowpark
# TODO: others to consider: ubf-context, num-parallel
@click.option("--run-id", help="Passed to the top-level 'step'.")
@click.option("--task-id", help="Passed to the top-level 'step'.")
@click.option("--input-paths", help="Passed to the top-level 'step'.")
@click.option("--split-index", help="Passed to the top-level 'step'.")
@click.option("--clone-path", help="Passed to the top-level 'step'.")
@click.option("--clone-run-id", help="Passed to the top-level 'step'.")
@click.option(
    "--tag", multiple=True, default=None, help="Passed to the top-level 'step'."
)
@click.option("--namespace", default=None, help="Passed to the top-level 'step'.")
@click.option("--retry-count", default=0, help="Passed to the top-level 'step'.")
@click.option(
    "--max-user-code-retries", default=0, help="Passed to the top-level 'step'."
)
# TODO: this is not used anywhere as of now...
@click.option(
    "--run-time-limit",
    default=5 * 24 * 60 * 60,  # Default is set to 5 days
    help="Run time limit in seconds for Snowpark container.",
)
@click.pass_context
def step(
    ctx,
    step_name,
    code_package_sha,
    code_package_url,
    executable=None,
    account=None,
    user=None,
    password=None,
    role=None,
    database=None,
    warehouse=None,
    schema=None,
    image=None,
    stage=None,
    compute_pool=None,
    volume_mounts=None,
    external_integration=None,
    cpu=None,
    gpu=None,
    memory=None,
    run_time_limit=None,
    **kwargs
):
    def echo(msg, stream="stderr", job_id=None, **kwargs):
        msg = util.to_unicode(msg)
        if job_id:
            msg = "[%s] %s" % (job_id, msg)
        ctx.obj.echo_always(msg, err=(stream == sys.stderr), **kwargs)

    if R.use_r():
        entrypoint = R.entrypoint()
    else:
        executable = ctx.obj.environment.executable(step_name, executable)
        entrypoint = "%s -u %s" % (executable, os.path.basename(sys.argv[0]))

    top_args = " ".join(util.dict_to_cli_options(ctx.parent.parent.params))

    input_paths = kwargs.get("input_paths")
    split_vars = None
    if input_paths:
        max_size = 30 * 1024
        split_vars = {
            "METAFLOW_INPUT_PATHS_%d" % (i // max_size): input_paths[i : i + max_size]
            for i in range(0, len(input_paths), max_size)
        }
        kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())

    step_args = " ".join(util.dict_to_cli_options(kwargs))

    step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
        entrypoint=entrypoint,
        top_args=top_args,
        step=step_name,
        step_args=step_args,
    )
    node = ctx.obj.graph[step_name]

    # Get retry information
    retry_count = kwargs.get("retry_count", 0)
    retry_deco = [deco for deco in node.decorators if deco.name == "retry"]
    minutes_between_retries = None
    if retry_deco:
        minutes_between_retries = int(
            retry_deco[0].attributes.get("minutes_between_retries", 1)
        )
    if retry_count:
        ctx.obj.echo_always(
            "Sleeping %d minutes before the next retry" % minutes_between_retries
        )
        time.sleep(minutes_between_retries * 60)

    # Set batch attributes
    task_spec = {
        "flow_name": ctx.obj.flow.name,
        "step_name": step_name,
        "run_id": kwargs["run_id"],
        "task_id": kwargs["task_id"],
        "retry_count": str(retry_count),
    }
    attrs = {"metaflow.%s" % k: v for k, v in task_spec.items()}
    attrs["metaflow.user"] = util.get_username()
    attrs["metaflow.version"] = ctx.obj.environment.get_environment_info()[
        "metaflow_version"
    ]

    # Set environment
    env = {}
    env_deco = [deco for deco in node.decorators if deco.name == "environment"]
    if env_deco:
        env = env_deco[0].attributes["vars"]

    # Add the environment variables related to the input-paths argument
    if split_vars:
        env.update(split_vars)

    # Set log tailing.
    ds = ctx.obj.flow_datastore.get_task_datastore(
        mode="w",
        run_id=kwargs["run_id"],
        step_name=step_name,
        task_id=kwargs["task_id"],
        attempt=int(retry_count),
    )
    stdout_location = ds.get_log_location(TASK_LOG_SOURCE, "stdout")
    stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr")

    def _sync_metadata():
        if ctx.obj.metadata.TYPE == "local":
            sync_local_metadata_from_datastore(
                DATASTORE_LOCAL_DIR,
                ctx.obj.flow_datastore.get_task_datastore(
                    kwargs["run_id"], step_name, kwargs["task_id"]
                ),
            )

    try:
        snowpark = Snowpark(
            datastore=ctx.obj.flow_datastore,
            metadata=ctx.obj.metadata,
            environment=ctx.obj.environment,
            client_credentials={
                "account": account,
                "user": user,
                "password": password,
                "role": role,
                "database": database,
                "warehouse": warehouse,
                "schema": schema,
            },
        )
        with ctx.obj.monitor.measure("metaflow.snowpark.launch_job"):
            snowpark.launch_job(
                step_name=step_name,
                step_cli=step_cli,
                task_spec=task_spec,
                code_package_sha=code_package_sha,
                code_package_url=code_package_url,
                code_package_ds=ctx.obj.flow_datastore.TYPE,
                image=image,
                stage=stage,
                compute_pool=compute_pool,
                volume_mounts=volume_mounts,
                external_integration=external_integration,
                cpu=cpu,
                gpu=gpu,
                memory=memory,
                run_time_limit=run_time_limit,
                env=env,
                attrs=attrs,
            )
    except Exception:
        traceback.print_exc(chain=False)
        _sync_metadata()
        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
    try:
        snowpark.wait(stdout_location, stderr_location, echo=echo)
    except SnowparkKilledException:
        # don't retry killed tasks
        traceback.print_exc()
        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
    finally:
        _sync_metadata()
metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py
@@ -0,0 +1,123 @@
import sys
import tempfile

from .snowpark_exceptions import SnowflakeException
from .snowpark_service_spec import generate_spec_file

from metaflow.exception import MetaflowException


class SnowparkClient(object):
    def __init__(
        self,
        account: str,
        user: str,
        password: str,
        role: str,
        database: str,
        warehouse: str,
        schema: str,
        autocommit: bool = True,
    ):
        try:
            from snowflake.core import Root
            from snowflake.snowpark import Session

            from snowflake.connector.errors import DatabaseError
        except (NameError, ImportError, ModuleNotFoundError):
            raise SnowflakeException(
                "Could not import module 'snowflake'.\n\nInstall Snowflake "
                "Python package (https://pypi.org/project/snowflake/) first.\n"
                "You can install the module by executing - "
                "%s -m pip install snowflake\n"
                "or equivalent through your favorite Python package manager."
                % sys.executable
            )

        self.connection_parameters = {
            "account": account,
            "user": user,
            "password": password,
            "role": role,
            "warehouse": warehouse,
            "database": database,
            "schema": schema,
            "autocommit": autocommit,
        }

        try:
            self.session = Session.builder.configs(self.connection_parameters).create()
            self.root = Root(self.session)
        except DatabaseError as e:
            raise ConnectionError(e)

    def __del__(self):
        if hasattr(self, "session"):
            self.session.close()

    def __check_existence_of_stage_and_compute_pool(
        self, db, schema, stage, compute_pool
    ):
        from snowflake.core.exceptions import NotFoundError

        # check if stage exists, will raise an error otherwise
        try:
            self.root.databases[db].schemas[schema].stages[stage].fetch()
        except NotFoundError:
            raise MetaflowException(
                "Stage *%s* does not exist or not authorized." % stage
            )

        # check if compute_pool exists, will raise an error otherwise
        try:
            self.root.compute_pools[compute_pool].fetch()
        except NotFoundError:
            raise MetaflowException(
                "Compute pool *%s* does not exist or not authorized." % compute_pool
            )

    def submit(self, name: str, spec, stage, compute_pool, external_integration):
        db = self.session.get_current_database()
        schema = self.session.get_current_schema()

        with tempfile.TemporaryDirectory() as temp_dir:
            snowpark_spec_file = tempfile.NamedTemporaryFile(
                dir=temp_dir, delete=False, suffix=".yaml"
            )
            generate_spec_file(spec, snowpark_spec_file.name, format="yaml")

            self.__check_existence_of_stage_and_compute_pool(
                db, schema, stage, compute_pool
            )

            # upload the spec file to stage
            result = self.session.file.put(snowpark_spec_file.name, "@%s" % stage)

            service_name = name.replace("-", "_")
            external_access = (
                "EXTERNAL_ACCESS_INTEGRATIONS=(%s) " % ",".join(external_integration)
                if external_integration
                else ""
            )

            # cannot pass 'is_job' parameter using the API, thus we need to use SQL directly..
            query = """
            EXECUTE JOB SERVICE IN COMPUTE POOL {compute_pool}
            NAME = {db}.{schema}.{service_name}
            {external_access}
            FROM @{stage} SPECIFICATION_FILE={specification_file}
            """.format(
                compute_pool=compute_pool,
                db=db,
                schema=schema,
                service_name=service_name,
                external_access=external_access,
                stage=stage,
                specification_file=result[0].target,
            )

            async_job = self.session.sql(query).collect(block=False)
            return async_job.query_id, service_name

    def terminate_job(self, service):
        service.delete()