ob-metaflow-extensions 1.1.45rc3__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ob-metaflow-extensions might be problematic. Click here for more details.
- metaflow_extensions/outerbounds/__init__.py +1 -7
- metaflow_extensions/outerbounds/config/__init__.py +35 -0
- metaflow_extensions/outerbounds/plugins/__init__.py +186 -57
- metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
- metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
- metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
- metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
- metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
- metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
- metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
- metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
- metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
- metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
- metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
- metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
- metaflow_extensions/outerbounds/plugins/auth_server.py +28 -8
- metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/async_cards.py +142 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/extra_components.py +545 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +70 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +391 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +188 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +79 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
- metaflow_extensions/outerbounds/plugins/nim/card.py +140 -0
- metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py +101 -0
- metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +379 -0
- metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
- metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvcf/constants.py +3 -0
- metaflow_extensions/outerbounds/plugins/nvcf/exceptions.py +94 -0
- metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +178 -0
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +417 -0
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +280 -0
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +242 -0
- metaflow_extensions/outerbounds/plugins/nvcf/utils.py +6 -0
- metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
- metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
- metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
- metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
- metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
- metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
- metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
- metaflow_extensions/outerbounds/plugins/perimeters.py +19 -5
- metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py +70 -0
- metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py +88 -0
- metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
- metaflow_extensions/outerbounds/plugins/secrets/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/secrets/secrets.py +204 -0
- metaflow_extensions/outerbounds/plugins/snowflake/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +378 -0
- metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +309 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +277 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +150 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +273 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +241 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
- metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py +50 -0
- metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
- metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
- metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
- metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
- metaflow_extensions/outerbounds/profilers/gpu.py +131 -47
- metaflow_extensions/outerbounds/remote_config.py +53 -16
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +138 -2
- metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
- metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
- {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
- ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
- ob_metaflow_extensions-1.1.45rc3.dist-info/RECORD +0 -19
- {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from metaflow.exception import MetaflowException
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SnowflakeException(MetaflowException):
    """Metaflow exception shown to the user under the headline "Snowflake error"."""

    headline = "Snowflake error"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SnowparkException(MetaflowException):
    """Metaflow exception shown to the user under the headline "Snowpark error"."""

    headline = "Snowpark error"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SnowparkKilledException(MetaflowException):
    """Metaflow exception shown under the headline "Snowpark job killed"."""

    headline = "Snowpark job killed"
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
from .snowpark_client import SnowparkClient
|
|
4
|
+
from .snowpark_service_spec import (
|
|
5
|
+
Container,
|
|
6
|
+
Resources,
|
|
7
|
+
SnowparkServiceSpec,
|
|
8
|
+
VolumeMount,
|
|
9
|
+
)
|
|
10
|
+
from .snowpark_exceptions import SnowparkException
|
|
11
|
+
|
|
12
|
+
# Translation table (for str.translate) replacing every ASCII digit with a
# letter: 0->a, 1->b, ..., 9->j.
mapping = {ord(digit): ord(letter) for digit, letter in zip("0123456789", "abcdefghij")}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# keep only ASCII alpha numeric characters and dashes..
def sanitize_name(job_name: str):
    """Drop every character of *job_name* except ASCII letters, ASCII digits and "-".

    ``str.isalnum`` on its own also accepts non-ASCII letters/digits (e.g. "é",
    "²"); those would otherwise end up in the job and container names derived
    from this value, so they are filtered out as well.

    :param job_name: raw name to clean.
    :return: the filtered name (possibly empty).
    """
    return "".join(
        char
        for char in job_name
        if (char.isalnum() and ord(char) < 128) or char == "-"
    )
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# this is not a decorator since the exception imports need to be inside
# and not at the top level..
def retry_operation(
    exception_type, func, max_retries=3, retry_delay=2, *args, **kwargs
):
    """Call ``func(*args, **kwargs)``, retrying while it raises *exception_type*.

    :param exception_type: exception class (or tuple of classes) that triggers
        another attempt. Any other exception propagates immediately.
    :param func: the callable to invoke.
    :param max_retries: total number of attempts. Values below 1 are treated
        as 1, so *func* is always called at least once (previously a value
        <= 0 silently returned None without calling *func*).
    :param retry_delay: seconds to sleep between attempts.
    :return: the return value of the first successful call.
    :raises exception_type: the last error, once every attempt has failed.

    Note: ``*args`` follows ``max_retries``/``retry_delay`` in the signature,
    so extra positional arguments fill those two first. Pass *func*'s own
    arguments by keyword.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return func(*args, **kwargs)
        except exception_type:
            if attempt == attempts - 1:
                # Out of attempts: re-raise the current error unchanged.
                raise
            time.sleep(retry_delay)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SnowparkJob(object):
    """Fluent builder for a job on Snowpark Container Services.

    Settings are collected in ``self.kwargs``, either through constructor
    keyword arguments or through the chainable setters below. ``create()``
    validates them and builds ``self.spec``; ``execute()`` (call ``create()``
    first) submits the spec through the client and returns a ``RunningJob``.
    """

    def __init__(self, client: SnowparkClient, name, command, **kwargs):
        self.client = client
        self.name = sanitize_name(name)
        self.command = command
        self.kwargs = kwargs
        # Container name: the sanitized name with digits translated to letters
        # (see `mapping`), lower-cased.
        self.container_name = self.name.translate(mapping).lower()

    def create_job_spec(self):
        """Validate the required settings and build ``self.spec``.

        :raises SnowparkException: if 'image', 'stage' or 'compute_pool' is not
            set, or if a volume mount is not of the form "volume_name:mount_path".
        :return: self, for chaining.
        """
        if self.kwargs.get("image") is None:
            raise SnowparkException(
                "Unable to launch job on Snowpark Container Services. No docker 'image' specified."
            )

        if self.kwargs.get("stage") is None:
            raise SnowparkException(
                "Unable to launch job on Snowpark Container Services. No 'stage' specified."
            )

        if self.kwargs.get("compute_pool") is None:
            raise SnowparkException(
                "Unable to launch job on Snowpark Container Services. No 'compute_pool' specified."
            )

        # Only truthy values are sent; the GPU count is used as both the
        # request and the limit.
        resources = Resources(
            requests={
                k: v
                for k, v in [
                    ("cpu", self.kwargs.get("cpu")),
                    ("nvidia.com/gpu", self.kwargs.get("gpu")),
                    ("memory", self.kwargs.get("memory")),
                ]
                if v
            },
            limits={
                k: v
                for k, v in [
                    ("nvidia.com/gpu", self.kwargs.get("gpu")),
                ]
                if v
            },
        )

        volume_mounts = self.kwargs.get("volume_mounts")
        vm_objs = []
        if volume_mounts:
            # Accept a single "name:path" string as well as a list of them.
            if isinstance(volume_mounts, str):
                volume_mounts = [volume_mounts]
            for vm in volume_mounts:
                # Split on the first ":" only; the mount path may contain more.
                vol_name, sep, mount_path = vm.partition(":")
                if not sep:
                    # Previously this surfaced as an opaque unpacking ValueError.
                    raise SnowparkException(
                        "Invalid volume mount '%s': expected 'volume_name:mount_path'."
                        % vm
                    )
                vm_objs.append(VolumeMount(name=vol_name, mount_path=mount_path))

        container = (
            Container(name=self.container_name, image=self.kwargs.get("image"))
            .env(self.kwargs.get("environment_variables"))
            .resources(resources)
            .volume_mounts(vm_objs)
            .command(self.command)
        )

        self.spec = SnowparkServiceSpec().containers([container])
        return self

    def environment_variable(self, name, value):
        """Set one environment variable for the container; ``None`` values are skipped."""
        # Never set to None
        if value is None:
            return self
        # `or {}` also covers environment_variables=None passed explicitly to
        # the constructor, where dict(None, ...) would raise TypeError.
        env = dict(self.kwargs.get("environment_variables") or {})
        env[name] = value
        self.kwargs["environment_variables"] = env
        return self

    def create(self):
        """Alias for ``create_job_spec()``."""
        return self.create_job_spec()

    def execute(self):
        """Submit ``self.spec`` through the client and wrap the result in a RunningJob."""
        query_id, service_name = self.client.submit(
            self.name,
            self.spec,
            self.kwargs.get("stage"),
            self.kwargs.get("compute_pool"),
            self.kwargs.get("external_integration"),
        )
        return RunningJob(
            client=self.client,
            query_id=query_id,
            service_name=service_name,
            **self.kwargs
        )

    def image(self, image):
        self.kwargs["image"] = image
        return self

    def stage(self, stage):
        self.kwargs["stage"] = stage
        return self

    def compute_pool(self, compute_pool):
        self.kwargs["compute_pool"] = compute_pool
        return self

    def volume_mounts(self, volume_mounts):
        self.kwargs["volume_mounts"] = volume_mounts
        return self

    def external_integration(self, external_integration):
        self.kwargs["external_integration"] = external_integration
        return self

    def cpu(self, cpu):
        self.kwargs["cpu"] = cpu
        return self

    def gpu(self, gpu):
        self.kwargs["gpu"] = gpu
        return self

    def memory(self, memory):
        self.kwargs["memory"] = memory
        return self
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class RunningJob(object):
    """Handle on a submitted Snowpark job service: poll its status or kill it."""

    def __init__(self, client, query_id, service_name, **kwargs):
        self.client = client
        self.query_id = query_id
        self.service_name = service_name
        self.kwargs = kwargs

        # snowflake imports stay function-local, like elsewhere in this module.
        from snowflake.core.exceptions import NotFoundError

        # Service lookup is retried on NotFoundError (presumably the service may
        # not be visible right after submission — TODO confirm).
        self.service = retry_operation(NotFoundError, self.__get_service)

    def __get_service(self):
        db = self.client.session.get_current_database()
        schema = self.client.session.get_current_schema()
        return (
            self.client.root.databases[db].schemas[schema].services[self.service_name]
        )

    def __repr__(self):
        return "{}('{}')".format(self.__class__.__name__, self.query_id)

    @property
    def id(self):
        return self.query_id

    @property
    def job_name(self):
        return self.service_name

    def status_obj(self, timeout=0):
        """Return the service status list, retrying on APIError.

        :raises SnowparkException: if the service cannot be found.
        """
        from snowflake.core.exceptions import APIError, NotFoundError

        try:
            return retry_operation(
                APIError, self.service.get_service_status, timeout=timeout
            )
        except NotFoundError:
            raise SnowparkException(
                "The image *%s* most probably doesn't exist on Snowpark, or too many resources (CPU, GPU, memory) were requested."
                % self.kwargs.get("image")
            )

    @property
    def status(self):
        """Status string of the first status entry, or "UNKNOWN" if there is none."""
        status_list = self.status_obj()
        if not status_list:
            return "UNKNOWN"
        return status_list[0].get("status", "UNKNOWN")

    @property
    def message(self):
        """Message of the first status entry, or None if there is none."""
        status_list = self.status_obj()
        if not status_list:
            return None
        return status_list[0].get("message")

    @property
    def is_waiting(self):
        return self.status in ["PENDING", "UNKNOWN"]

    @property
    def is_running(self):
        return self.status in ["PENDING", "READY"]

    @property
    def has_failed(self):
        return self.status == "FAILED"

    @property
    def has_succeeded(self):
        return self.status == "DONE"

    @property
    def has_finished(self):
        # Read the status once. Evaluating has_succeeded and has_failed
        # separately issued two status calls that could observe different states.
        return self.status in ("DONE", "FAILED")

    def kill(self):
        """Terminate the job unless it has already finished (best effort)."""
        from snowflake.core.exceptions import NotFoundError

        try:
            if not self.has_finished:
                self.client.terminate_job(service=self.service)
        except (NotFoundError, TypeError):
            # Deliberately ignored, as in the original; the reason for
            # TypeError is not visible here.
            pass
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from typing import List, Dict, Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Container:
    """One container entry of a Snowpark service spec, built with chainable setters.

    ``to_dict`` always emits ``name`` and ``image``. Each optional field is
    emitted only when its value is truthy, in a fixed order.
    """

    def __init__(self, name: str, image: str):
        payload = defaultdict(lambda: defaultdict(dict))
        payload["name"] = name
        payload["image"] = image
        self.payload = payload

    def command(self, command: List[str]) -> "Container":
        self.payload.update(command=command)
        return self

    def args(self, args: List[str]) -> "Container":
        self.payload.update(args=args)
        return self

    def env(self, env: Dict[str, str]) -> "Container":
        self.payload.update(env=env)
        return self

    def readiness_probe(self, readiness_probe: "ReadinessProbe") -> "Container":
        if readiness_probe:
            probe = readiness_probe.to_dict()
        else:
            probe = None
        self.payload["readiness_probe"] = probe
        return self

    def volume_mounts(self, volume_mounts: List["VolumeMount"]) -> "Container":
        self.payload["volume_mounts"] = [mount.to_dict() for mount in volume_mounts]
        return self

    def resources(self, resources: "Resources") -> "Container":
        if resources:
            res = resources.to_dict()
        else:
            res = None
        self.payload["resources"] = res
        return self

    def secrets(self, secrets: List["Secrets"]) -> "Container":
        self.payload["secrets"] = [item.to_dict() for item in secrets]
        return self

    def to_dict(self) -> Dict:
        """Return the container as a plain dict, omitting empty optional fields."""
        out = {"name": self.payload["name"], "image": self.payload["image"]}
        optional_keys = (
            "command",
            "args",
            "env",
            "readiness_probe",
            "volume_mounts",
            "resources",
            "secrets",
        )
        for key in optional_keys:
            # .get() does not trigger the defaultdict factory.
            value = self.payload.get(key)
            if value:
                out[key] = value
        return out
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ReadinessProbe:
    """Readiness probe settings (port and path) for a container."""

    def __init__(self, port: int, path: str):
        self.payload = defaultdict(dict, port=port, path=path)

    def to_dict(self) -> Dict:
        """Return the probe settings as a plain dict."""
        return {**self.payload}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class VolumeMount:
    """Mount of a named volume at ``mount_path`` inside a container."""

    def __init__(self, name: str, mount_path: str):
        self.payload = defaultdict(dict, name=name, mount_path=mount_path)

    def to_dict(self) -> Dict:
        """Return the mount as a plain dict."""
        return {**self.payload}
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class Resources:
    """Container resource requests/limits; empty sections are left out."""

    def __init__(
        self,
        requests: Optional[Dict[str, str]] = None,
        limits: Optional[Dict[str, str]] = None,
    ):
        self.payload = defaultdict(dict)
        for section, values in (("requests", requests), ("limits", limits)):
            if values:
                self.payload[section] = values

    def to_dict(self) -> Dict:
        """Return only the non-empty sections, requests before limits."""
        return {
            section: self.payload[section]
            for section in ("requests", "limits")
            if self.payload.get(section)
        }
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Secrets:
    """Reference to a Snowflake secret plus optional key/env-var/directory settings."""

    def __init__(self, snowflake_secret: str):
        self.payload = dict(snowflake_secret=snowflake_secret)

    def secret_key_ref(self, secret_key_ref: str) -> "Secrets":
        self.payload.update(secret_key_ref=secret_key_ref)
        return self

    def env_var_name(self, env_var_name: str) -> "Secrets":
        self.payload.update(env_var_name=env_var_name)
        return self

    def directory_path(self, directory_path: str) -> "Secrets":
        self.payload.update(directory_path=directory_path)
        return self

    def to_dict(self) -> Dict:
        """Return the settings, dropping every entry whose value is falsy."""
        result = {}
        for key, value in self.payload.items():
            if value:
                result[key] = value
        return result
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class Endpoint:
    """Named service endpoint on a port, with optional public/protocol settings."""

    def __init__(self, name: str, port: int):
        self.payload = defaultdict(dict, name=name, port=port)

    def public(self, public: bool) -> "Endpoint":
        self.payload.update(public=public)
        return self

    def protocol(self, protocol: str) -> "Endpoint":
        self.payload.update(protocol=protocol)
        return self

    def to_dict(self) -> Dict:
        """Return every field set so far as a plain dict."""
        return {**self.payload}
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class Volume:
    """Service volume (name + source) with optional size/block_config/uid/gid.

    Optional fields appear in ``to_dict`` whenever they were set, even when the
    value is falsy (e.g. ``gid(0)``).
    """

    def __init__(self, name: str, source: str):
        self.payload = defaultdict(dict, name=name, source=source)

    def size(self, size: str) -> "Volume":
        self.payload.update(size=size)
        return self

    def block_config(self, block_config: Dict) -> "Volume":
        self.payload.update(block_config=block_config)
        return self

    def uid(self, uid: int) -> "Volume":
        self.payload.update(uid=uid)
        return self

    def gid(self, gid: int) -> "Volume":
        self.payload.update(gid=gid)
        return self

    def to_dict(self) -> Dict:
        """Return name/source plus whichever optional fields were set."""
        out = {"name": self.payload["name"], "source": self.payload["source"]}
        for field in ("size", "block_config", "uid", "gid"):
            if field in self.payload:
                out[field] = self.payload[field]
        return out
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class LogExporters:
    """Log-exporter section of a service spec (currently only the event table level)."""

    def __init__(self):
        self.payload = {}

    def event_table_config(self, log_level: str) -> "LogExporters":
        self.payload.update(eventTableConfig={"logLevel": log_level})
        return self

    def to_dict(self) -> Dict:
        """Return a shallow copy of the configured exporters."""
        return self.payload.copy()
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class ServiceRole:
    """Named service role, optionally restricted to a list of endpoint names."""

    def __init__(self, name: str):
        self.payload = dict(name=name)

    def endpoints(self, endpoints: List[str]) -> "ServiceRole":
        self.payload.update(endpoints=endpoints)
        return self

    def to_dict(self) -> Dict:
        """Return a shallow copy of the role definition."""
        return self.payload.copy()
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class SnowparkServiceSpec:
    """Top-level Snowpark service specification; ``to_dict`` wraps it under "spec"."""

    def __init__(self):
        self.payload = defaultdict(lambda: defaultdict(list))

    def containers(self, containers: List[Container]) -> "SnowparkServiceSpec":
        self.payload["containers"] = [item.to_dict() for item in containers]
        return self

    def endpoints(self, endpoints: List[Endpoint]) -> "SnowparkServiceSpec":
        self.payload["endpoints"] = [item.to_dict() for item in endpoints]
        return self

    def volumes(self, volumes: List[Volume]) -> "SnowparkServiceSpec":
        self.payload["volumes"] = [item.to_dict() for item in volumes]
        return self

    def log_exporters(self, log_exporters: LogExporters) -> "SnowparkServiceSpec":
        exporters = log_exporters.to_dict()
        self.payload["logExporters"] = exporters
        return self

    def service_roles(self, service_roles: List[ServiceRole]) -> "SnowparkServiceSpec":
        self.payload["serviceRoles"] = [item.to_dict() for item in service_roles]
        return self

    def to_dict(self) -> Dict:
        """Return ``{"spec": {...}}`` with every section set so far."""
        return {"spec": {**self.payload}}
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def generate_spec_file(
    spec: "SnowparkServiceSpec", filename: str, format: str = "yaml"
):
    """Write ``spec.to_dict()`` to *filename* as YAML (default) or JSON.

    :param spec: object exposing ``to_dict()``, normally a SnowparkServiceSpec.
    :param filename: destination path; an existing file is overwritten.
    :param format: "yaml" or "json".
    :raises ValueError: for any other format. The check runs before the file
        is opened, so an unsupported format no longer creates or truncates
        *filename* to an empty file.
    """
    if format not in ("json", "yaml"):
        raise ValueError(
            "Unsupported spec format %r; expected 'json' or 'yaml'." % (format,)
        )

    spec_dict = spec.to_dict()
    if format == "json":
        with open(filename, "w") as file:
            json.dump(spec_dict, file, indent=2)
    else:
        # PyYAML is only required for YAML output, so it is imported lazily
        # (and before the file is opened, as before).
        import yaml

        with open(filename, "w") as file:
            yaml.dump(spec_dict, file, default_flow_style=False)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
if __name__ == "__main__":
    # Example: build a one-container service spec with a public endpoint and a
    # volume, then write it to service_spec.yaml in the working directory.
    probe = ReadinessProbe(port=8080, path="/health")
    resources = Resources(
        requests={"memory": "2G", "cpu": "1"}, limits={"memory": "4G"}
    )
    container = Container(name="example-container", image="example-image")
    container.command(["python3", "app.py"])
    container.env({"ENV_VARIABLE": "value"})
    container.readiness_probe(probe)
    container.resources(resources)

    endpoint = Endpoint(name="example-endpoint", port=8080)
    endpoint.public(True)
    volume = Volume(name="example-volume", source="local")

    spec = SnowparkServiceSpec()
    spec.containers([container])
    spec.endpoints([endpoint])
    spec.volumes([volume])

    generate_spec_file(spec, "service_spec.yaml", format="yaml")
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from metaflow.decorators import StepDecorator
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TensorboardDecorator(StepDecorator):
    """Step decorator that exposes a TensorBoard ``SummaryWriter`` as ``self.obtb``.

    Event files go under ``<DATATOOLS_S3ROOT>/tb/[<project_flow_name>/]<pathspec>``.
    If ``torch.utils.tensorboard`` cannot be imported, the step runs without
    logging and a notice is printed to stderr.
    """

    name = "tensorboard"
    defaults = {}

    def task_decorate(
        self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
    ):
        @functools.wraps(step_func)
        def tb_wrapper():
            import sys, os
            from metaflow import metaflow_config, current

            try:
                from torch.utils.tensorboard import SummaryWriter
            except Exception:
                # Best effort: any import failure disables logging. This used
                # to be a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                print(
                    "[@tensorboard] Torch and tensorboard not found - logging disabled!",
                    file=sys.stderr,
                )
                step_func()
                return

            tb_root = os.path.join(metaflow_config.DATATOOLS_S3ROOT, "tb")
            pathspec = current.pathspec
            try:
                log_dir = os.path.join(tb_root, current.project_flow_name, pathspec)
            except Exception:
                # No usable project flow name: fall back to <tb_root>/<pathspec>.
                log_dir = os.path.join(tb_root, pathspec)
            # Path components below tb_root. Assuming pathspec is
            # <flow>/<run>/<step>/<task>, dropping the last 2 gives the run
            # level and dropping the last 3 gives the flow level.
            comps = log_dir[len(tb_root) + 1 :].split("/")
            run_level = "/".join(comps[:-2])
            flow_level = "/".join(comps[:-3])

            print("[@tensorboard] -- INSPECTING RESULTS")
            print(
                "[@tensorboard] -- Execute one of these commands on your workstation:"
            )
            print(f"[@tensorboard] Compare tasks of this run: obtb {run_level}")
            print(f"[@tensorboard] Compare across runs: obtb {flow_level}")
            writer = SummaryWriter(log_dir=log_dir)
            setattr(flow, "obtb", writer)
            try:
                step_func()
            finally:
                try:
                    # close() flushes pending events and releases the writer's
                    # files; the original only flushed and left it open.
                    writer.close()
                finally:
                    delattr(flow, "obtb")

        return tb_wrapper
|