ob-metaflow-extensions 1.1.45rc3__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

Files changed (128)
  1. metaflow_extensions/outerbounds/__init__.py +1 -7
  2. metaflow_extensions/outerbounds/config/__init__.py +35 -0
  3. metaflow_extensions/outerbounds/plugins/__init__.py +186 -57
  4. metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
  5. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  6. metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
  7. metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  33. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  34. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  35. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  36. metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
  37. metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
  38. metaflow_extensions/outerbounds/plugins/auth_server.py +28 -8
  39. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  40. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  41. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
  42. metaflow_extensions/outerbounds/plugins/card_utilities/__init__.py +0 -0
  43. metaflow_extensions/outerbounds/plugins/card_utilities/async_cards.py +142 -0
  44. metaflow_extensions/outerbounds/plugins/card_utilities/extra_components.py +545 -0
  45. metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +70 -0
  46. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
  47. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
  48. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  49. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
  50. metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
  51. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  52. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +391 -0
  53. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +188 -0
  54. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
  55. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
  56. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +79 -0
  57. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  58. metaflow_extensions/outerbounds/plugins/nim/card.py +140 -0
  59. metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py +101 -0
  60. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +379 -0
  61. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  62. metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
  63. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +3 -0
  64. metaflow_extensions/outerbounds/plugins/nvcf/exceptions.py +94 -0
  65. metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +178 -0
  66. metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +417 -0
  67. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +280 -0
  68. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +242 -0
  69. metaflow_extensions/outerbounds/plugins/nvcf/utils.py +6 -0
  70. metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
  71. metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
  72. metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
  73. metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
  74. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
  75. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
  76. metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
  77. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
  78. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  79. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  80. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
  81. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  82. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  83. metaflow_extensions/outerbounds/plugins/perimeters.py +19 -5
  84. metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py +70 -0
  85. metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py +88 -0
  86. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  87. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  88. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  89. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  90. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  91. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  92. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  93. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  94. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  95. metaflow_extensions/outerbounds/plugins/secrets/__init__.py +0 -0
  96. metaflow_extensions/outerbounds/plugins/secrets/secrets.py +204 -0
  97. metaflow_extensions/outerbounds/plugins/snowflake/__init__.py +3 -0
  98. metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +378 -0
  99. metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
  100. metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +309 -0
  101. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +277 -0
  102. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +150 -0
  103. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +273 -0
  104. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
  105. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +241 -0
  106. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
  107. metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py +50 -0
  108. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  109. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  110. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  111. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  112. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  113. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  114. metaflow_extensions/outerbounds/profilers/gpu.py +131 -47
  115. metaflow_extensions/outerbounds/remote_config.py +53 -16
  116. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +138 -2
  117. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  118. metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
  119. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  120. metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py +1 -0
  121. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  122. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  123. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  124. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
  125. ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
  126. ob_metaflow_extensions-1.1.45rc3.dist-info/RECORD +0 -19
  127. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
  128. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,277 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import traceback
5
+
6
+ from metaflow._vendor import click
7
+ from metaflow import util
8
+ from metaflow import R
9
+ from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY
10
+ from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
11
+ from metaflow.metadata_provider.util import sync_local_metadata_from_datastore
12
+ from metaflow.mflog import TASK_LOG_SOURCE
13
+
14
+ from .snowpark_exceptions import SnowparkKilledException
15
+ from .snowpark import Snowpark
16
+
17
+
18
@click.group()
def cli():
    """Top-level command group under which the Snowpark CLI is mounted."""
    pass
21
+
22
+
23
@cli.group(help="Commands related to Snowpark.")
def snowpark():
    """Container group for the `snowpark` subcommands (e.g. `step`)."""
    pass
26
+
27
+
28
@snowpark.command(
    help="Execute a single task using Snowpark. This command calls the "
    "top-level step command inside a Snowpark job with the given options. "
    "Typically you do not call this command directly; it is used internally by "
    "Metaflow."
)
@click.argument("step-name")
@click.argument("code-package-sha")
@click.argument("code-package-url")
@click.option(
    "--executable", help="Executable requirement for Snowpark Container Services."
)
# below are snowpark specific options
@click.option(
    "--account",
    help="Account ID for Snowpark Container Services.",
)
@click.option(
    "--user",
    help="Username for Snowpark Container Services.",
)
@click.option(
    "--password",
    help="Password for Snowpark Container Services.",
)
@click.option(
    "--role",
    help="Role for Snowpark Container Services.",
)
@click.option(
    "--database",
    help="Database for Snowpark Container Services.",
)
@click.option(
    "--warehouse",
    help="Warehouse for Snowpark Container Services.",
)
@click.option(
    "--schema",
    help="Schema for Snowpark Container Services.",
)
@click.option(
    "--integration",
    help="Outerbounds OAuth integration name for Snowpark Container Services. When set, uses OAuth authentication instead of password.",
)
@click.option(
    "--image",
    help="Docker image requirement for Snowpark Container Services. In name:version format.",
)
@click.option("--stage", help="Stage requirement for Snowpark Container Services.")
@click.option(
    "--compute-pool", help="Compute Pool requirement for Snowpark Container Services."
)
@click.option("--volume-mounts", multiple=True)
@click.option(
    "--external-integration",
    multiple=True,
    help="External Integration requirement for Snowpark Container Services.",
)
@click.option("--cpu", help="CPU requirement for Snowpark Container Services.")
@click.option("--gpu", help="GPU requirement for Snowpark Container Services.")
@click.option("--memory", help="Memory requirement for Snowpark Container Services.")
# TODO: secrets, volumes for snowpark
# TODO: others to consider: ubf-context, num-parallel
@click.option("--run-id", help="Passed to the top-level 'step'.")
@click.option("--task-id", help="Passed to the top-level 'step'.")
@click.option("--input-paths", help="Passed to the top-level 'step'.")
@click.option("--split-index", help="Passed to the top-level 'step'.")
@click.option("--clone-path", help="Passed to the top-level 'step'.")
@click.option("--clone-run-id", help="Passed to the top-level 'step'.")
@click.option(
    "--tag", multiple=True, default=None, help="Passed to the top-level 'step'."
)
@click.option("--namespace", default=None, help="Passed to the top-level 'step'.")
@click.option("--retry-count", default=0, help="Passed to the top-level 'step'.")
@click.option(
    "--max-user-code-retries", default=0, help="Passed to the top-level 'step'."
)
# TODO: this is not used anywhere as of now...
@click.option(
    "--run-time-limit",
    default=5 * 24 * 60 * 60,  # Default is set to 5 days
    help="Run time limit in seconds for Snowpark container.",
)
@click.pass_context
def step(
    ctx,
    step_name,
    code_package_sha,
    code_package_url,
    executable=None,
    account=None,
    user=None,
    password=None,
    role=None,
    database=None,
    warehouse=None,
    schema=None,
    integration=None,
    image=None,
    stage=None,
    compute_pool=None,
    volume_mounts=None,
    external_integration=None,
    cpu=None,
    gpu=None,
    memory=None,
    run_time_limit=None,
    **kwargs
):
    """Launch a single Metaflow task as a Snowpark Container Services job,
    then tail its logs until the job finishes.

    Exits with METAFLOW_EXIT_DISALLOW_RETRY on launch failure or when the
    remote task is killed, so the runtime does not retry those cases.
    """

    def echo(msg, stream="stderr", job_id=None, **kwargs):
        # Prefix messages with the job id so interleaved logs stay readable.
        text = util.to_unicode(msg)
        if job_id:
            text = "[%s] %s" % (job_id, text)
        # NOTE(review): `stream` is a string ("stderr"/"stdout") while
        # sys.stderr is a file object, so this comparison is always False;
        # matches upstream Metaflow behavior — kept as-is on purpose.
        ctx.obj.echo_always(text, err=(stream == sys.stderr), **kwargs)

    # Build the entrypoint that re-invokes this flow inside the container.
    if R.use_r():
        entrypoint = R.entrypoint()
    else:
        executable = ctx.obj.environment.executable(step_name, executable)
        entrypoint = "%s -u %s" % (executable, os.path.basename(sys.argv[0]))

    top_args = " ".join(util.dict_to_cli_options(ctx.parent.parent.params))

    # Very long --input-paths values are shipped through environment
    # variables (METAFLOW_INPUT_PATHS_<i>) and reassembled via shell
    # expansion inside the container to dodge command-line length limits.
    input_paths = kwargs.get("input_paths")
    split_vars = None
    if input_paths:
        max_size = 30 * 1024
        split_vars = {
            "METAFLOW_INPUT_PATHS_%d" % (i // max_size): input_paths[i : i + max_size]
            for i in range(0, len(input_paths), max_size)
        }
        kwargs["input_paths"] = "".join("${%s}" % name for name in split_vars)

    step_args = " ".join(util.dict_to_cli_options(kwargs))
    step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
        entrypoint=entrypoint,
        top_args=top_args,
        step=step_name,
        step_args=step_args,
    )
    node = ctx.obj.graph[step_name]

    # Honor @retry's minutes_between_retries before re-launching.
    retry_count = kwargs.get("retry_count", 0)
    retry_deco = [deco for deco in node.decorators if deco.name == "retry"]
    minutes_between_retries = None
    if retry_deco:
        minutes_between_retries = int(
            retry_deco[0].attributes.get("minutes_between_retries", 1)
        )
    if retry_count:
        ctx.obj.echo_always(
            "Sleeping %d minutes before the next retry" % minutes_between_retries
        )
        time.sleep(minutes_between_retries * 60)

    # Book-keeping attributes attached to the launched job.
    task_spec = {
        "flow_name": ctx.obj.flow.name,
        "step_name": step_name,
        "run_id": kwargs["run_id"],
        "task_id": kwargs["task_id"],
        "retry_count": str(retry_count),
    }
    attrs = {"metaflow.%s" % key: value for key, value in task_spec.items()}
    attrs["metaflow.user"] = util.get_username()
    attrs["metaflow.version"] = ctx.obj.environment.get_environment_info()[
        "metaflow_version"
    ]

    # Environment variables from @environment, plus the split input paths.
    env_deco = [deco for deco in node.decorators if deco.name == "environment"]
    env = env_deco[0].attributes["vars"] if env_deco else {}
    if split_vars:
        env.update(split_vars)

    # Resolve datastore locations for log tailing.
    task_ds = ctx.obj.flow_datastore.get_task_datastore(
        mode="w",
        run_id=kwargs["run_id"],
        step_name=step_name,
        task_id=kwargs["task_id"],
        attempt=int(retry_count),
    )
    stdout_location = task_ds.get_log_location(TASK_LOG_SOURCE, "stdout")
    stderr_location = task_ds.get_log_location(TASK_LOG_SOURCE, "stderr")

    def _sync_metadata():
        # Pull task metadata back from the datastore when running with a
        # local metadata provider.
        if ctx.obj.metadata.TYPE == "local":
            sync_local_metadata_from_datastore(
                DATASTORE_LOCAL_DIR,
                ctx.obj.flow_datastore.get_task_datastore(
                    kwargs["run_id"], step_name, kwargs["task_id"]
                ),
            )

    try:
        snowpark = Snowpark(
            datastore=ctx.obj.flow_datastore,
            metadata=ctx.obj.metadata,
            environment=ctx.obj.environment,
            client_credentials={
                "account": account,
                "user": user,
                "password": password,
                "role": role,
                "database": database,
                "warehouse": warehouse,
                "schema": schema,
                "integration": integration,
            },
        )
        with ctx.obj.monitor.measure("metaflow.snowpark.launch_job"):
            snowpark.launch_job(
                step_name=step_name,
                step_cli=step_cli,
                task_spec=task_spec,
                code_package_sha=code_package_sha,
                code_package_url=code_package_url,
                code_package_ds=ctx.obj.flow_datastore.TYPE,
                image=image,
                stage=stage,
                compute_pool=compute_pool,
                volume_mounts=volume_mounts,
                external_integration=external_integration,
                cpu=cpu,
                gpu=gpu,
                memory=memory,
                run_time_limit=run_time_limit,
                env=env,
                attrs=attrs,
            )
    except Exception:
        # Launch failures are not retriable: surface the trace and bail.
        traceback.print_exc(chain=False)
        _sync_metadata()
        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
    try:
        snowpark.wait(stdout_location, stderr_location, echo=echo)
    except SnowparkKilledException:
        # don't retry killed tasks
        traceback.print_exc()
        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
    finally:
        _sync_metadata()
@@ -0,0 +1,150 @@
1
+ import sys
2
+ import tempfile
3
+
4
+ from .snowpark_exceptions import SnowflakeException
5
+ from .snowpark_service_spec import generate_spec_file
6
+
7
+ from metaflow.exception import MetaflowException
8
+
9
+
10
class SnowparkClient(object):
    """Thin wrapper over the Snowflake Python APIs used to submit Metaflow
    tasks as Snowpark Container Services job services.

    Supports either password-based authentication or OAuth through an
    Outerbounds integration (``integration`` argument).
    """

    def __init__(
        self,
        account: str = None,
        user: str = None,
        password: str = None,
        role: str = None,
        database: str = None,
        warehouse: str = None,
        schema: str = None,
        autocommit: bool = True,
        integration: str = None,
    ):
        # The Snowflake packages are a soft dependency; fail with actionable
        # install instructions when they are missing.
        try:
            from snowflake.core import Root
            from snowflake.snowpark import Session

            from snowflake.connector.errors import DatabaseError
        except (NameError, ImportError, ModuleNotFoundError):
            raise SnowflakeException(
                "Could not import module 'snowflake'.\n\nInstall Snowflake "
                "Python packages first:\n"
                " snowflake==1.8.0\n"
                " snowflake-connector-python==3.18.0\n"
                " snowflake-snowpark-python==1.40.0\n\n"
                "You can install them by executing:\n"
                "%s -m pip install snowflake==1.8.0 snowflake-connector-python==3.18.0 snowflake-snowpark-python==1.40.0\n"
                "or equivalent through your favorite Python package manager."
                % sys.executable
            )

        if integration:
            # Use OAuth authentication via Outerbounds integration
            from metaflow_extensions.outerbounds.plugins.snowflake.snowflake import (
                get_oauth_connection_params,
            )

            params = get_oauth_connection_params(
                user=user or "",
                role=role or "",
                integration=integration,
                schema=schema or "",
                account=account,
                warehouse=warehouse,
                database=database,
            )
            params["autocommit"] = autocommit
        else:
            # Password-based authentication
            params = {
                "account": account,
                "user": user,
                "password": password,
                "role": role,
                "warehouse": warehouse,
                "database": database,
                "schema": schema,
                "autocommit": autocommit,
            }

        # Drop unset values so Snowflake's own defaults apply.
        self.connection_parameters = {
            key: value for key, value in params.items() if value is not None
        }

        try:
            self.session = Session.builder.configs(self.connection_parameters).create()
            self.root = Root(self.session)
        except DatabaseError as e:
            raise ConnectionError(e)

    def __del__(self):
        # Close the session at garbage collection; __init__ may have failed
        # before the session attribute was ever assigned.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()

    def __check_existence_of_stage_and_compute_pool(
        self, db, schema, stage, compute_pool
    ):
        """Fail fast with a readable error when the stage or the compute
        pool is missing (or not visible to the current role)."""
        from snowflake.core.exceptions import NotFoundError

        # check if stage exists, will raise an error otherwise
        try:
            self.root.databases[db].schemas[schema].stages[stage].fetch()
        except NotFoundError:
            raise MetaflowException(
                "Stage *%s* does not exist or not authorized." % stage
            )

        # check if compute_pool exists, will raise an error otherwise
        try:
            self.root.compute_pools[compute_pool].fetch()
        except NotFoundError:
            raise MetaflowException(
                "Compute pool *%s* does not exist or not authorized." % compute_pool
            )

    def submit(self, name: str, spec, stage, compute_pool, external_integration):
        """Upload the job-service *spec* to *stage* and launch it on
        *compute_pool* asynchronously.

        Returns a ``(query_id, service_name)`` pair identifying the
        submitted job service.
        """
        db = self.session.get_current_database()
        schema = self.session.get_current_schema()

        with tempfile.TemporaryDirectory() as scratch_dir:
            spec_file = tempfile.NamedTemporaryFile(
                dir=scratch_dir, delete=False, suffix=".yaml"
            )
            generate_spec_file(spec, spec_file.name, format="yaml")

            self.__check_existence_of_stage_and_compute_pool(
                db, schema, stage, compute_pool
            )

            # upload the spec file to stage
            upload_result = self.session.file.put(spec_file.name, "@%s" % stage)

            # Service names cannot contain dashes.
            service_name = name.replace("-", "_")
            external_access = (
                "EXTERNAL_ACCESS_INTEGRATIONS=(%s) " % ",".join(external_integration)
                if external_integration
                else ""
            )

            # cannot pass 'is_job' parameter using the API, thus we need to use SQL directly..
            query = """
            EXECUTE JOB SERVICE IN COMPUTE POOL {compute_pool}
            NAME = {db}.{schema}.{service_name}
            {external_access}
            FROM @{stage} SPECIFICATION_FILE={specification_file}
            """.format(
                compute_pool=compute_pool,
                db=db,
                schema=schema,
                service_name=service_name,
                external_access=external_access,
                stage=stage,
                specification_file=upload_result[0].target,
            )

            # Run asynchronously so Metaflow can poll/tail the job itself.
            async_job = self.session.sql(query).collect(block=False)
            return async_job.query_id, service_name

    def terminate_job(self, service):
        """Terminate a running job service by deleting it."""
        service.delete()
@@ -0,0 +1,273 @@
1
+ import os
2
+ import sys
3
+ import platform
4
+
5
+ from metaflow import R, current
6
+ from metaflow.metadata_provider import MetaDatum
7
+ from metaflow.metadata_provider.util import sync_local_metadata_to_datastore
8
+ from metaflow.sidecar import Sidecar
9
+ from metaflow.decorators import StepDecorator
10
+ from metaflow.exception import MetaflowException
11
+ from metaflow.metaflow_config import (
12
+ DEFAULT_CONTAINER_IMAGE,
13
+ DEFAULT_CONTAINER_REGISTRY,
14
+ SNOWPARK_ACCOUNT,
15
+ SNOWPARK_USER,
16
+ SNOWPARK_PASSWORD,
17
+ SNOWPARK_ROLE,
18
+ SNOWPARK_DATABASE,
19
+ SNOWPARK_WAREHOUSE,
20
+ SNOWPARK_SCHEMA,
21
+ )
22
+
23
+ from metaflow.metaflow_config import (
24
+ DATASTORE_LOCAL_DIR,
25
+ )
26
+
27
+ from .snowpark_exceptions import SnowflakeException
28
+ from metaflow.plugins.aws.aws_utils import get_docker_registry
29
+
30
+
31
class Snowflake(object):
    """Lightweight handle over Snowpark connection parameters, exposed to
    user step code via ``current.snowflake``."""

    def __init__(self, connection_params):
        # Stored verbatim; a Session is only created lazily in session().
        self.connection_params = connection_params

    def session(self):
        """Create and return a new Snowpark ``Session`` from the stored
        connection parameters.

        Raises SnowflakeException when the snowflake packages are not
        importable in the task environment.
        """
        # if using the pypi/conda decorator with @snowpark of any step,
        # make sure to pass {'snowflake': '0.11.0'} as well to that step
        try:
            from snowflake.snowpark import Session

            return Session.builder.configs(self.connection_params).create()
        except (NameError, ImportError, ModuleNotFoundError):
            raise SnowflakeException(
                "Could not import module 'snowflake'.\n\n"
                "Install required Snowflake packages using the @pypi decorator:\n"
                "@pypi(packages={\n"
                " 'snowflake': '1.8.0',\n"
                " 'snowflake-connector-python': '3.18.0',\n"
                " 'snowflake-snowpark-python': '1.40.0'\n"
                "})\n"
            )
53
+
54
+
55
class SnowparkDecorator(StepDecorator):
    """@snowpark step decorator: runs the decorated step as a Snowpark
    Container Services job.

    Connection settings fall back to the SNOWPARK_* values from the
    Metaflow configuration; ``integration`` switches authentication from
    password to Outerbounds OAuth.
    """

    name = "snowpark"

    defaults = {
        "account": None,
        "user": None,
        "password": None,
        "role": None,
        "database": None,
        "warehouse": None,
        "schema": None,
        "image": None,
        "stage": None,
        "compute_pool": None,
        "volume_mounts": None,
        "external_integration": None,
        "cpu": None,
        "gpu": None,
        "memory": None,
        "integration": None,  # Outerbounds OAuth integration name
    }

    # Class-level so the code package is uploaded at most once per process.
    package_url = None
    package_sha = None
    run_time_limit = None

    def __init__(self, attributes=None, statically_defined=False):
        super(SnowparkDecorator, self).__init__(attributes, statically_defined)

        # Set defaults from config (user can override via decorator or integration)
        if not self.attributes["account"]:
            self.attributes["account"] = SNOWPARK_ACCOUNT
        if not self.attributes["user"]:
            self.attributes["user"] = SNOWPARK_USER
        if not self.attributes["role"]:
            self.attributes["role"] = SNOWPARK_ROLE
        if not self.attributes["database"]:
            self.attributes["database"] = SNOWPARK_DATABASE
        if not self.attributes["warehouse"]:
            self.attributes["warehouse"] = SNOWPARK_WAREHOUSE
        if not self.attributes["schema"]:
            self.attributes["schema"] = SNOWPARK_SCHEMA
        # Only use password from config if not using integration (OAuth)
        if not self.attributes["integration"] and not self.attributes["password"]:
            self.attributes["password"] = SNOWPARK_PASSWORD

        # If no docker image is explicitly specified, impute a default image.
        if not self.attributes["image"]:
            # If metaflow-config specifies a docker image, just use that.
            if DEFAULT_CONTAINER_IMAGE:
                self.attributes["image"] = DEFAULT_CONTAINER_IMAGE
            # If metaflow-config doesn't specify a docker image, assign a
            # default docker image.
            else:
                # Metaflow-R has its own default docker image (rocker family)
                if R.use_r():
                    self.attributes["image"] = R.container_image()
                # Default to vanilla Python image corresponding to major.minor
                # version of the Python interpreter launching the flow.
                # BUGFIX: this assignment previously ran unconditionally and
                # clobbered the R image chosen above; it now only applies
                # when not running Metaflow-R (matches other Metaflow
                # compute decorators).
                else:
                    self.attributes["image"] = "python:%s.%s" % (
                        platform.python_version_tuple()[0],
                        platform.python_version_tuple()[1],
                    )

        # Assign docker registry URL for the image.
        if not get_docker_registry(self.attributes["image"]):
            if DEFAULT_CONTAINER_REGISTRY:
                self.attributes["image"] = "%s/%s" % (
                    DEFAULT_CONTAINER_REGISTRY.rstrip("/"),
                    self.attributes["image"],
                )

    # Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
    # to understand where these functions are invoked in the lifecycle of a
    # Metaflow flow.
    def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger):
        """Validate decorator combinations and stash flow-level state."""
        # Set internal state.
        self.logger = logger
        self.environment = environment
        self.step = step
        self.flow_datastore = flow_datastore

        if any([deco.name == "parallel" for deco in decos]):
            raise MetaflowException(
                "Step *{step}* contains a @parallel decorator "
                "with the @snowpark decorator. @parallel is not supported with @snowpark.".format(
                    step=step
                )
            )

    def package_init(self, flow, step_name, environment):
        """Verify the soft snowflake dependency is importable before
        packaging the flow."""
        try:
            # Snowflake is a soft dependency.
            from snowflake.snowpark import Session
        except (NameError, ImportError, ModuleNotFoundError):
            raise SnowflakeException(
                "Could not import module 'snowflake'.\n\nInstall Snowflake "
                "Python packages first:\n"
                " snowflake==1.8.0\n"
                " snowflake-connector-python==3.18.0\n"
                " snowflake-snowpark-python==1.40.0\n\n"
                "You can install them by executing:\n"
                "%s -m pip install snowflake==1.8.0 snowflake-connector-python==3.18.0 snowflake-snowpark-python==1.40.0\n"
                "or equivalent through your favorite Python package manager."
                % sys.executable
            )

    def runtime_init(self, flow, graph, package, run_id):
        """Record runtime state needed later when launching tasks."""
        # Set some more internal state.
        self.flow = flow
        self.graph = graph
        self.package = package
        self.run_id = run_id

    def runtime_task_created(
        self, task_datastore, task_id, split_index, input_paths, is_cloned, ubf_context
    ):
        """Upload the code package (once per process) for non-cloned tasks."""
        if not is_cloned:
            self._save_package_once(self.flow_datastore, self.package)

    def runtime_step_cli(
        self, cli_args, retry_count, max_user_code_retries, ubf_context
    ):
        """Rewrite the step CLI invocation to go through `snowpark step`."""
        if retry_count <= max_user_code_retries:
            cli_args.commands = ["snowpark", "step"]
            cli_args.command_args.append(self.package_sha)
            cli_args.command_args.append(self.package_url)
            cli_args.command_options.update(self.attributes)
            cli_args.command_options["run-time-limit"] = self.run_time_limit
            if not R.use_r():
                cli_args.entrypoint[0] = sys.executable

    def task_pre_step(
        self,
        step_name,
        task_datastore,
        metadata,
        run_id,
        task_id,
        flow,
        graph,
        retry_count,
        max_retries,
        ubf_context,
        inputs,
    ):
        """Inside the Snowpark container: expose `current.snowflake`, start
        the log-saving sidecar and register debugging metadata."""
        self.metadata = metadata
        self.task_datastore = task_datastore

        # this path will exist within snowpark container services
        # (BUGFIX: previously the file handle was leaked)
        with open("/snowflake/session/token", "r") as token_file:
            login_token = token_file.read()
        connection_params = {
            "account": os.environ.get("SNOWFLAKE_ACCOUNT"),
            "host": os.environ.get("SNOWFLAKE_HOST"),
            "authenticator": "oauth",
            "token": login_token,
            "database": os.environ.get("SNOWFLAKE_DATABASE"),
            "schema": os.environ.get("SNOWFLAKE_SCHEMA"),
            "autocommit": True,
            "client_session_keep_alive": True,
        }

        # SNOWFLAKE_WAREHOUSE is injected explicitly by us
        # but is not really required. So if it exists, we use it in
        # connection parameters
        if os.environ.get("SNOWFLAKE_WAREHOUSE"):
            connection_params["warehouse"] = os.environ.get("SNOWFLAKE_WAREHOUSE")

        snowflake_obj = Snowflake(connection_params)
        current._update_env({"snowflake": snowflake_obj})

        meta = {}
        if "METAFLOW_SNOWPARK_WORKLOAD" in os.environ:
            meta["snowflake-account"] = os.environ.get("SNOWFLAKE_ACCOUNT")
            meta["snowflake-database"] = os.environ.get("SNOWFLAKE_DATABASE")
            meta["snowflake-schema"] = os.environ.get("SNOWFLAKE_SCHEMA")
            meta["snowflake-host"] = os.environ.get("SNOWFLAKE_HOST")
            meta["snowflake-service-name"] = os.environ.get("SNOWFLAKE_SERVICE_NAME")

        self._save_logs_sidecar = Sidecar("save_logs_periodically")
        self._save_logs_sidecar.start()

        if len(meta) > 0:
            entries = [
                MetaDatum(
                    field=k,
                    value=v,
                    type=k,
                    tags=["attempt_id:{0}".format(retry_count)],
                )
                for k, v in meta.items()
                if v is not None
            ]
            # Register book-keeping metadata for debugging.
            metadata.register_metadata(run_id, step_name, task_id, entries)

    def task_finished(
        self, step_name, flow, graph, is_task_ok, retry_count, max_retries
    ):
        """Sync local metadata back to the datastore and stop the sidecar."""
        if "METAFLOW_SNOWPARK_WORKLOAD" in os.environ:
            if hasattr(self, "metadata") and self.metadata.TYPE == "local":
                # Note that the datastore is *always* Amazon S3 (see
                # runtime_task_created function).
                sync_local_metadata_to_datastore(
                    DATASTORE_LOCAL_DIR, self.task_datastore
                )

        try:
            self._save_logs_sidecar.terminate()
        except Exception:
            # Best effort kill
            pass

    @classmethod
    def _save_package_once(cls, flow_datastore, package):
        # Upload the code package only on the first call; subsequent tasks
        # reuse the cached URL/SHA stored on the class.
        if cls.package_url is None:
            cls.package_url, cls.package_sha = flow_datastore.save_data(
                [package.blob], len_hint=1
            )[0]