ob-metaflow-extensions 1.1.45rc3__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ob-metaflow-extensions might be problematic. Click here for more details.
- metaflow_extensions/outerbounds/__init__.py +1 -7
- metaflow_extensions/outerbounds/config/__init__.py +35 -0
- metaflow_extensions/outerbounds/plugins/__init__.py +186 -57
- metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
- metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
- metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
- metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
- metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
- metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
- metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
- metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
- metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
- metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
- metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
- metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
- metaflow_extensions/outerbounds/plugins/auth_server.py +28 -8
- metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/async_cards.py +142 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/extra_components.py +545 -0
- metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +70 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +391 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +188 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +79 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
- metaflow_extensions/outerbounds/plugins/nim/card.py +140 -0
- metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py +101 -0
- metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +379 -0
- metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
- metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvcf/constants.py +3 -0
- metaflow_extensions/outerbounds/plugins/nvcf/exceptions.py +94 -0
- metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +178 -0
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +417 -0
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +280 -0
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +242 -0
- metaflow_extensions/outerbounds/plugins/nvcf/utils.py +6 -0
- metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
- metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
- metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
- metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
- metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
- metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
- metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
- metaflow_extensions/outerbounds/plugins/perimeters.py +19 -5
- metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py +70 -0
- metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py +88 -0
- metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
- metaflow_extensions/outerbounds/plugins/secrets/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/secrets/secrets.py +204 -0
- metaflow_extensions/outerbounds/plugins/snowflake/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +378 -0
- metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +309 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +277 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +150 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +273 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +241 -0
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
- metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py +50 -0
- metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
- metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
- metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
- metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
- metaflow_extensions/outerbounds/profilers/gpu.py +131 -47
- metaflow_extensions/outerbounds/remote_config.py +53 -16
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +138 -2
- metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
- metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
- {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
- ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
- ob_metaflow_extensions-1.1.45rc3.dist-info/RECORD +0 -19
- {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
from metaflow.metaflow_current import current
|
|
2
|
+
import sqlite3
|
|
3
|
+
from threading import Thread, Event
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class InfoCollectorThread(Thread):
    """Daemon thread that periodically polls a SQLite file for metrics.

    Every ``interval`` seconds the thread opens ``file_name``, passes the
    open connection to ``sqlite_fetch_func`` and stores the result as
    ``{"metrics": <result>}``. The payload of the last *successful* poll is
    returned by :meth:`read`; the outcome of the most recent poll is exposed
    through :meth:`has_errored` / :meth:`get_error`.

    Args:
        interval: Seconds to wait between two polls.
        file_name: Path of the SQLite database to poll (required).
        sqlite_fetch_func: Callable taking an open ``sqlite3.Connection``
            and returning the metrics payload.
    """

    def __init__(
        self,
        interval=1,
        file_name=None,
        sqlite_fetch_func=None,  # Callable
    ):
        super().__init__()
        self._exit_event = Event()
        self._interval = interval
        assert file_name is not None, "file_name must be provided"
        self._file_name = file_name
        self.daemon = True
        self._data = {}
        self._has_errored = False
        self._current_error = None
        self.sqlite_fetch_func = sqlite_fetch_func

    def read(self):
        """Return the payload of the last successful poll (``{}`` until one succeeds)."""
        return self._data

    def has_errored(self):
        """Return True if the most recent poll failed."""
        return self._has_errored

    def get_error(self):
        """Return the exception from the most recent poll, or None if it succeeded."""
        return self._current_error

    def _safely_load(self):
        """Run a single poll and return ``(payload, error)``.

        ``FileNotFoundError`` and ``sqlite3.Error`` are caught and returned
        as ``({}, error)``; any other exception raised by
        ``sqlite_fetch_func`` propagates to the caller.
        """
        # Bind before ``try`` so the ``finally`` clause is safe even when
        # ``sqlite3.connect`` itself raises (e.g. the parent directory does
        # not exist). Previously that path raised UnboundLocalError from the
        # ``finally`` block and hid the real error.
        conn = None
        try:
            conn = sqlite3.connect(self._file_name)
            data = self.sqlite_fetch_func(conn)
            return {"metrics": data}, None
        except (FileNotFoundError, sqlite3.Error) as e:
            return {}, e
        finally:
            if conn is not None:
                conn.close()

    def run(self):
        """Poll until :meth:`stop` is called, keeping the last good payload."""
        while not self._exit_event.is_set():
            data, self._current_error = self._safely_load()
            if self._current_error is None:
                self._data = data
            self._has_errored = self._current_error is not None
            # Event.wait (rather than time.sleep) lets stop() interrupt the
            # pause immediately instead of blocking for up to ``interval``.
            self._exit_event.wait(self._interval)

    def stop(self):
        """Signal the polling loop to exit and wait for the thread to finish."""
        self._exit_event.set()
        self.join()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class CardRefresher:
    """Interface for objects that draw and refresh a Metaflow card.

    Subclasses must set ``CARD_ID`` (the key used to look the card up in
    ``current.card``) and implement every hook below; each base
    implementation raises ``NotImplementedError`` naming the missing hook.
    """

    CARD_ID = None

    def on_startup(self, current_card):
        """Build the initial card layout; invoked once before the first update."""
        raise NotImplementedError("on_startup method must be implemented")

    def on_error(self, current_card, error_message):
        """Render an error state.

        ``error_message`` receives the exception object reported by the
        collector thread (despite the parameter name).
        """
        raise NotImplementedError("on_error method must be implemented")

    def on_update(self, current_card, data_object):
        """Refresh the card with the latest collected payload."""
        raise NotImplementedError("on_update method must be implemented")

    def sqlite_fetch_func(self, conn):
        """Read and return the metrics payload from an open sqlite3 connection."""
        raise NotImplementedError("sqlite_fetch_func must be implemented")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class CardUpdaterThread(Thread):
    """Daemon thread that periodically pushes collected data into a card.

    ``run`` first checks that ``card_refresher.CARD_ID`` is set, then
    resolves ``current.card[CARD_ID]`` and calls ``on_startup`` once. It
    then loops until :meth:`stop`. Each iteration reads the collector's last
    good payload, calls ``on_error`` when the collector's latest poll failed,
    and then always calls ``on_update`` with that payload.

    Args:
        card_refresher: CardRefresher implementation supplying the hooks.
        interval: Seconds to wait between two refreshes.
        file_name: Stored on the instance; not otherwise used by this class.
        collector_thread: InfoCollectorThread providing the data; it is
            stopped together with this thread.
    """

    def __init__(
        self,
        card_refresher: CardRefresher,
        interval=1,
        file_name=None,
        collector_thread: InfoCollectorThread = None,
    ):
        super().__init__()
        self._exit_event = Event()
        self._interval = interval
        self._refresher = card_refresher
        self._file_name = file_name
        self._collector_thread = collector_thread
        self.daemon = True

    def run(self):
        if self._refresher.CARD_ID is None:
            raise ValueError("CARD_ID must be defined")
        current_card = current.card[self._refresher.CARD_ID]
        self._refresher.on_startup(current_card)
        while not self._exit_event.is_set():
            data = self._collector_thread.read()
            # Snapshot the error once. Calling has_errored() and then
            # get_error() separately can race with the collector thread,
            # which may clear the error in between and hand on_error() a None.
            error = self._collector_thread.get_error()
            if error is not None:
                self._refresher.on_error(current_card, error)
            self._refresher.on_update(current_card, data)
            # Event.wait (rather than time.sleep) lets stop() interrupt the
            # pause immediately instead of blocking for up to ``interval``.
            self._exit_event.wait(self._interval)

    def stop(self):
        """Signal exit, stop the collector thread, and wait for this thread to end."""
        self._exit_event.set()
        self._collector_thread.stop()
        self.join()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class AsyncPeriodicRefresher:
    """Wire an InfoCollectorThread and a CardUpdaterThread together.

    The collector thread starts polling as soon as the object is created.
    Card updates begin when :meth:`start` is called. :meth:`stop` shuts both
    threads down and then pushes one final ``on_update`` with the latest
    collected data.

    Args:
        card_referesher: CardRefresher implementation. The parameter name,
            including its spelling, is kept for backward compatibility.
        updater_interval: Seconds between two card refreshes.
        collector_interval: Seconds between two SQLite polls.
        file_name: Path to the SQLite file polled by the collector.
    """

    def __init__(
        self,
        card_referesher: CardRefresher,
        updater_interval=1,
        collector_interval=1,
        file_name=None,
    ):
        assert card_referesher.CARD_ID is not None, "CARD_ID must be defined"
        self._refresher = card_referesher
        self._collector_thread = InfoCollectorThread(
            interval=collector_interval,
            file_name=file_name,
            sqlite_fetch_func=card_referesher.sqlite_fetch_func,
        )
        self._collector_thread.start()
        self._updater_thread = CardUpdaterThread(
            card_refresher=card_referesher,
            interval=updater_interval,
            file_name=file_name,
            collector_thread=self._collector_thread,
        )

    def start(self):
        """Start the card-updater thread."""
        self._updater_thread.start()

    def stop(self):
        """Stop both threads, then render a final update with the latest data."""
        # Shut the background threads down *before* the final update. That
        # way it cannot run concurrently with an in-flight on_update() on the
        # updater thread, and it sees the collector's final payload.
        if self._updater_thread.ident is not None:
            # Only a started thread can be joined. CardUpdaterThread.stop()
            # also stops the collector thread.
            self._updater_thread.stop()
        # Joining an already finished thread is harmless, so this is safe
        # whether or not start() was called.
        self._collector_thread.stop()
        data = self._collector_thread.read()
        current_card = current.card[self._refresher.CARD_ID]
        self._refresher.on_update(current_card, data)
|
|
@@ -0,0 +1,545 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from metaflow.cards import (
|
|
3
|
+
Markdown,
|
|
4
|
+
Table,
|
|
5
|
+
VegaChart,
|
|
6
|
+
ProgressBar,
|
|
7
|
+
MetaflowCardComponent,
|
|
8
|
+
Artifact,
|
|
9
|
+
)
|
|
10
|
+
import math
|
|
11
|
+
from metaflow.plugins.cards.card_modules.components import (
|
|
12
|
+
with_default_component_id,
|
|
13
|
+
TaskToDict,
|
|
14
|
+
ArtifactsComponent,
|
|
15
|
+
render_safely,
|
|
16
|
+
)
|
|
17
|
+
import datetime
|
|
18
|
+
from metaflow.metaflow_current import current
|
|
19
|
+
import json
|
|
20
|
+
from functools import wraps
|
|
21
|
+
from collections import defaultdict
|
|
22
|
+
from threading import Thread, Event
|
|
23
|
+
import time
|
|
24
|
+
|
|
25
|
+
# Default chart geometry (pixels) shared by every chart spec in this module.
DEFAULT_WIDTH = 500
DEFAULT_HEIGHT = 200
DEFAULT_PADDING = 10

# Background colours.
BG_COLOR = "#f2eeea"  # sand-100
VIEW_FILL = "#faf7f4"  # sand-200

GREYS = ["#ebe8e5", "#b2afac", "#6a6867"]
BLACK = "#31302f"

# Five-shade hue palettes. Index 0 is the palest tint in each list.
GREENS = ["#dae8e2", "#3e8265", "#4c9878", "#428a6b", "#37795d"]
YELLOWS = ["#faf1db", "#f7e2b1", "#fbd784", "#e4b957", "#d7a530"]
PURPLES = ["#f5eff9", "#e7d4f3", "#976bac", "#8e53a9", "#77458f"]
REDS = ["#fce5e2", "#f3b6af", "#e6786c", "#e35f50", "#ce493a"]
BLUES = ["#dfe9f4", "#bdd8f2", "#88b7e3", "#6799c8", "#4e7ca7"]

# Categorical colour range. Shades 1..4 of every hue are interleaved hue by
# hue (green, purple, red, blue, yellow) so adjacent categories get
# different hues. Shade 0 (the pale tint) is deliberately skipped.
ALL_COLORS = [
    palette[shade]
    for shade in range(1, 5)
    for palette in (GREENS, PURPLES, REDS, BLUES, YELLOWS)
]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def update_spec_data(spec, data):
    """Append one row to a chart spec's inline dataset in place and return ``spec``.

    ``spec["data"]`` may take either of two forms:

    * a single dataset dict with a ``"values"`` list (Vega-Lite style, as
      built by ``line_chart_spec``);
    * a list of named datasets (Vega style, as used by ``BarPlot`` and
      ``ViolinPlot``). Derived datasets there use ``"source"``/``"transform"``,
      so the row goes into the first dataset carrying an inline ``"values"``
      list.

    Raises:
        ValueError: if a list-form spec has no dataset with inline values.
    """
    datasets = spec["data"]
    if isinstance(datasets, dict):
        target = datasets
    else:
        target = next((d for d in datasets if "values" in d), None)
        if target is None:
            raise ValueError("spec has no dataset with inline 'values' to update")
    target["values"].append(data)
    return spec
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def update_data_object(data_object, data):
    """Append ``data`` to ``data_object["values"]`` in place and return ``data_object``."""
    rows = data_object["values"]
    rows.append(data)
    return data_object
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def line_chart_spec(
    title=None,
    category_name="u",
    y_name="v",
    xtitle=None,
    ytitle=None,
    width=DEFAULT_WIDTH,
    height=DEFAULT_HEIGHT,
    with_params=True,
    x_axis_temporal=False,
):
    """Build a Vega-Lite v5 line-chart spec with an empty inline dataset.

    Args:
        title: Chart title; defaults to ``"Line Chart"``.
        category_name: Field plotted on the x axis.
        y_name: Field plotted (quantitatively) on the y axis.
        xtitle: x-axis title; defaults to ``category_name``.
        ytitle: y-axis title; defaults to ``y_name``.
        width: Chart width in pixels. It was previously ignored in favour of
            ``DEFAULT_WIDTH``.
        height: Chart height in pixels. It was previously ignored in favour
            of ``DEFAULT_HEIGHT``.
        with_params: If True, add interactive bindings (interpolation,
            tension, stroke width/cap/dash) and wire them into the mark.
        x_axis_temporal: If True, encode x with ``timeUnit: "seconds"``;
            otherwise encode it as a quantitative field.

    Returns:
        ``(spec, data)``. Rows are appended to ``spec["data"]["values"]``.
        ``data`` is a separate ``{"values": []}`` dict that ``spec`` does
        not reference.
    """
    # Interactive controls bound to UI widgets when ``with_params`` is set.
    parameters = [
        {
            "name": "interpolate",
            "value": "linear",
            "bind": {
                "input": "select",
                "options": [
                    "basis",
                    "cardinal",
                    "catmull-rom",
                    "linear",
                    "monotone",
                    "natural",
                    "step",
                    "step-after",
                    "step-before",
                ],
            },
        },
        {
            "name": "tension",
            "value": 0,
            "bind": {"input": "range", "min": 0, "max": 1, "step": 0.05},
        },
        {
            "name": "strokeWidth",
            "value": 2,
            "bind": {"input": "range", "min": 0, "max": 10, "step": 0.5},
        },
        {
            "name": "strokeCap",
            "value": "butt",
            "bind": {"input": "select", "options": ["butt", "round", "square"]},
        },
        {
            "name": "strokeDash",
            "value": [1, 0],
            "bind": {
                "input": "select",
                "options": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]],
            },
        },
    ]
    # Mark properties driven by the parameters above.
    parameter_marks = {
        "interpolate": {"expr": "interpolate"},
        "tension": {"expr": "tension"},
        "strokeWidth": {"expr": "strokeWidth"},
        "strokeDash": {"expr": "strokeDash"},
        "strokeCap": {"expr": "strokeCap"},
    }

    x_encoding = {
        "field": category_name,
        "title": xtitle if xtitle else category_name,
    }
    if x_axis_temporal:
        x_encoding["timeUnit"] = "seconds"
    else:
        x_encoding["type"] = "quantitative"

    spec = {
        "title": title if title else "Line Chart",
        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
        "width": width,
        "height": height,
        "background": BG_COLOR,
        "padding": DEFAULT_PADDING,
        "view": {"fill": VIEW_FILL},
        "params": parameters if with_params else [],
        "data": {"name": "values", "values": []},
        "mark": {
            "type": "line",
            "tooltip": True,
            **(parameter_marks if with_params else {}),
        },
        "selection": {"grid": {"type": "interval", "bind": "scales"}},
        "encoding": {
            "x": x_encoding,
            "y": {
                "field": y_name,
                "type": "quantitative",
                "title": ytitle if ytitle else y_name,
            },
        },
    }
    data = {"values": []}
    return spec, data
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class LineChart(MetaflowCardComponent):
    """Realtime-updatable line chart backed by a Vega-Lite spec.

    Rows passed to :meth:`update` are appended to the spec's inline dataset.
    :meth:`render` wraps the spec in a ``VegaChart`` carrying this
    component's id.
    """

    REALTIME_UPDATABLE = True

    def __init__(
        self,
        title,
        xtitle,
        ytitle,
        category_name,
        y_name,
        with_params=False,
        x_axis_temporal=False,
    ):
        super().__init__()

        spec, _unused_data = line_chart_spec(
            title=title,
            category_name=category_name,
            y_name=y_name,
            xtitle=xtitle,
            ytitle=ytitle,
            with_params=with_params,
            x_axis_temporal=x_axis_temporal,
        )
        # Only the spec is retained; rows are stored inside it.
        self.spec = spec

    def update(self, data):  # Can take a diff
        """Append one data row to the chart's inline dataset."""
        self.spec = update_spec_data(self.spec, data)

    @with_default_component_id
    def render(self):
        """Render the chart through VegaChart with this component's id."""
        chart = VegaChart(self.spec, show_controls=True)
        chart.component_id = self.component_id
        return chart.render()
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class ArtifactTable(Artifact):
    """Render a ``{name: value}`` mapping as a Metaflow artifacts table.

    Each value is converted with ``TaskToDict(only_repr=True).infer_object``
    and labelled with its key before being handed to ``ArtifactsComponent``.
    """

    def __init__(self, data_dict):
        # NOTE(review): Artifact.__init__ is not invoked here and render() is
        # fully overridden; confirm no base-class attributes are required.
        self._data = data_dict
        self._task_to_dict = TaskToDict(only_repr=True)

    def _build_rows(self):
        """Return one inferred-object dict per entry, tagged with ``"name"``."""
        rows = []
        for name, value in self._data.items():
            row = self._task_to_dict.infer_object(value)
            row["name"] = name
            rows.append(row)
        return rows

    @with_default_component_id
    @render_safely
    def render(self):
        component = ArtifactsComponent(data=self._build_rows())
        component.component_id = self.component_id
        return component.render()
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
# fmt: off
|
|
226
|
+
class BarPlot(MetaflowCardComponent):
    """Realtime-updatable bar chart built as a Vega v5 spec.

    Rows appended through :meth:`update` land in the ``table`` dataset. One
    bar is drawn per ``category_name`` value, sized by ``value_name``. A
    text label with the hovered bar's value is driven by the ``tooltip``
    signal.

    Args:
        title: Chart title.
        category_name: Field holding the category of each row.
        value_name: Field holding the numeric value of each row.
        orientation: ``"vertical"`` (categories on x) or ``"horizontal"``
            (categories on y).

    Raises:
        ValueError: if ``orientation`` is not one of the two values above.
    """

    REALTIME_UPDATABLE = True

    def __init__(self, title, category_name, value_name, orientation="vertical"):

        if orientation not in ["vertical", "horizontal"]:
            raise ValueError("orientation must be either 'vertical' or 'horizontal'")

        super().__init__()
        vertical = orientation == "vertical"
        self.spec = {
            "title": title,
            "$schema": "https://vega.github.io/schema/vega/v5.json",
            "description": "A basic bar chart example to show a count of values grouped by a category.",
            "background": BG_COLOR,
            "view": {"fill": VIEW_FILL},
            "width": DEFAULT_WIDTH,
            "height": DEFAULT_HEIGHT,
            "padding": DEFAULT_PADDING,
            "data": [{"name": "table", "values": []}],
            "signals": [
                {
                    "name": "tooltip",
                    "value": {},
                    "on": [
                        {"events": "rect:pointerover", "update": "datum"},
                        {"events": "rect:pointerout", "update": "{}"},
                    ],
                }
            ],
            "scales": [
                {
                    # Band scale for the categories.
                    "name": "xscale" if vertical else "yscale",
                    "type": "band",
                    "domain": {"data": "table", "field": category_name},
                    "range": "width" if vertical else "height",
                    "padding": 0.25,
                    "round": True,
                },
                {
                    # Linear scale for the values.
                    "name": "yscale" if vertical else "xscale",
                    "domain": {"data": "table", "field": value_name},
                    "nice": True,
                    "range": "height" if vertical else "width",
                },
                {
                    "name": "color",
                    "type": "ordinal",
                    "domain": {"data": "table", "field": category_name},
                    "range": ALL_COLORS,
                },
            ],
            "axes": [
                {"orient": "bottom", "scale": "xscale", "zindex": 1},
                {"orient": "left", "scale": "yscale", "zindex": 1},
            ],
            "marks": [
                {
                    "type": "rect",
                    "from": {"data": "table"},
                    "encode": {
                        "enter": {
                            "x": {
                                "scale": "xscale",
                                "field": category_name if vertical else value_name,
                            },
                            "y": {
                                "scale": "yscale",
                                "field": value_name if vertical else category_name,
                            },
                            # Bars grow from the value axis' zero.
                            ("y2" if vertical else "x2"): {
                                "scale": "yscale" if vertical else "xscale",
                                "value": 0,
                            },
                            "width": {"scale": "xscale", "band": 1},
                            "height": {"scale": "yscale", "band": 1},
                        },
                        "update": {
                            "fill": {"value": GREENS[0]},
                        },
                        "hover": {"fill": {"value": GREENS[2]}},
                    },
                },
                {
                    # Value label positioned over the hovered bar.
                    "type": "text",
                    "encode": {
                        "enter": {
                            "align": {"value": "center"},
                            "baseline": {"value": "bottom"},
                            "fill": {"value": BG_COLOR},
                        },
                        "update": {
                            "x": {
                                "scale": "xscale",
                                "signal": f"tooltip.{category_name if vertical else value_name}",
                                ("band" if vertical else "offset"): (
                                    0.5 if vertical else -10
                                ),
                            },
                            "y": {
                                "scale": "yscale",
                                "signal": f"tooltip.{value_name if vertical else category_name}",
                                ("offset" if vertical else "band"): (
                                    20 if vertical else 0.5
                                ),
                            },
                            "text": {"signal": f"tooltip.{value_name}"},
                            "fillOpacity": [
                                {"test": "datum === tooltip", "value": 0},
                                {"value": 1},
                            ],
                        },
                    },
                },
            ],
        }

    def update(self, data):  # Can take a diff
        """Append one row to the chart's ``table`` dataset.

        ``spec["data"]`` is a list of named datasets here (Vega form), so the
        row is appended to the ``table`` entry directly instead of indexing
        ``spec["data"]`` as a dict.
        """
        table = next(d for d in self.spec["data"] if d["name"] == "table")
        table["values"].append(data)

    @with_default_component_id
    def render(self):
        vega_chart = VegaChart(self.spec, show_controls=True)
        vega_chart.component_id = self.component_id
        return vega_chart.render()
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
class ViolinPlot(MetaflowCardComponent):
    """Realtime-updatable violin plot built as a Vega v5 spec.

    Rows appended through :meth:`update` land in the ``src`` dataset. For
    every ``category_col_name`` value the spec draws:

    * a KDE density area of ``value_col_name``;
    * a bar spanning its q1..q3 range;
    * a tick at its median.

    NOTE(review): the ``height`` signal is ``(plotWidth + 10) * 3``, which
    appears to size the layout for exactly three categories; confirm.

    Args:
        title: Chart title.
        category_col_name: Field holding each row's category.
        value_col_name: Field holding each row's numeric value.
    """

    REALTIME_UPDATABLE = True

    def __init__(self, title, category_col_name, value_col_name):
        super().__init__()

        self.spec = {
            "title": title,
            "$schema": "https://vega.github.io/schema/vega/v5.json",
            "description": "A violin chart to show a distributional properties of each category.",
            "background": BG_COLOR,
            "view": {"fill": VIEW_FILL},
            "width": DEFAULT_WIDTH,
            "height": DEFAULT_HEIGHT,
            "padding": DEFAULT_PADDING,
            "config": {
                "axisBand": {"bandPosition": 1, "tickExtra": True, "tickOffset": 0}
            },
            "signals": [
                {"name": "plotWidth", "value": 75},
                {"name": "height", "update": "(plotWidth + 10) * 3"},
                {
                    "name": "bandwidth",
                    "value": 0.1,
                    "bind": {"input": "range", "min": 0, "max": 0.2, "step": 0.01},
                },
            ],
            "data": [
                # Raw rows; the other datasets derive from this one.
                {"name": "src", "values": []},
                {
                    "name": "density",
                    "source": "src",
                    "transform": [
                        {
                            "type": "kde",
                            "groupby": [category_col_name],
                            "field": value_col_name,
                            "bandwidth": {"signal": "bandwidth"},
                            "extent": {"signal": "domain('xscale')"},
                        }
                    ],
                },
                {
                    "name": "stats",
                    "source": "src",
                    "transform": [
                        {
                            "type": "aggregate",
                            "groupby": [category_col_name],
                            "fields": [value_col_name, value_col_name, value_col_name],
                            "ops": ["q1", "q3", "median"],
                            "as": ["q1", "q3", "median"],
                        }
                    ],
                },
            ],
            "scales": [
                {
                    "name": "layout",
                    "type": "band",
                    "range": "height",
                    "domain": {"data": "src", "field": category_col_name},
                },
                {
                    "name": "xscale",
                    "type": "linear",
                    "range": "width",
                    "round": True,
                    "domain": {"data": "src", "field": value_col_name},
                    "zero": False,
                    "nice": True,
                },
                {
                    "name": "hscale",
                    "type": "linear",
                    "range": [0, {"signal": "plotWidth"}],
                    "domain": {"data": "density", "field": "density"},
                },
                {
                    "name": "color",
                    "type": "ordinal",
                    "domain": {"data": "src", "field": category_col_name},
                    "range": ALL_COLORS,
                },
            ],
            "axes": [
                {"orient": "bottom", "scale": "xscale", "zindex": 1},
                {"orient": "left", "scale": "layout", "zindex": 1},
            ],
            "marks": [
                {
                    # One group per category, faceted over the density data.
                    "type": "group",
                    "from": {
                        "facet": {
                            "data": "density",
                            "name": "violin",
                            "groupby": category_col_name,
                        }
                    },
                    "encode": {
                        "enter": {
                            "yc": {
                                "scale": "layout",
                                "field": category_col_name,
                                "band": 0.5,
                            },
                            "height": {"signal": "plotWidth"},
                            "width": {"signal": "width"},
                        }
                    },
                    "data": [
                        {
                            "name": "summary",
                            "source": "stats",
                            "transform": [
                                {
                                    "type": "filter",
                                    "expr": f"datum.{category_col_name} === parent.{category_col_name}",
                                }
                            ],
                        }
                    ],
                    "marks": [
                        {
                            # Density area.
                            "type": "area",
                            "from": {"data": "violin"},
                            "encode": {
                                "enter": {
                                    "fill": {
                                        "scale": "color",
                                        "field": {"parent": category_col_name},
                                    }
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "value"},
                                    "yc": {"signal": "plotWidth / 2"},
                                    "height": {"scale": "hscale", "field": "density"},
                                },
                            },
                        },
                        {
                            # q1..q3 bar.
                            "type": "rect",
                            "from": {"data": "summary"},
                            "encode": {
                                "enter": {
                                    "fill": {"value": BLACK},
                                    "height": {"value": 2},
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "q1"},
                                    "x2": {"scale": "xscale", "field": "q3"},
                                    "yc": {"signal": "plotWidth / 2"},
                                },
                            },
                        },
                        {
                            # Median tick.
                            "type": "rect",
                            "from": {"data": "summary"},
                            "encode": {
                                "enter": {
                                    "fill": {"value": BLACK},
                                    "width": {"value": 2},
                                    "height": {"value": 8},
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "median"},
                                    "yc": {"signal": "plotWidth / 2"},
                                },
                            },
                        },
                    ],
                }
            ],
        }

    def update(self, data):  # Can take a diff
        """Append one row to the chart's ``src`` dataset.

        ``spec["data"]`` is a list of named datasets here (Vega form), so the
        row is appended to the ``src`` entry directly instead of indexing
        ``spec["data"]`` as a dict.
        """
        src = next(d for d in self.spec["data"] if d["name"] == "src")
        src["values"].append(data)

    @with_default_component_id
    def render(self):
        vega_chart = VegaChart(self.spec, show_controls=True)
        vega_chart.component_id = self.component_id
        return vega_chart.render()
|
|
545
|
+
# fmt: on
|