ob-metaflow-extensions 1.1.45rc3__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

Files changed (128)
  1. metaflow_extensions/outerbounds/__init__.py +1 -7
  2. metaflow_extensions/outerbounds/config/__init__.py +35 -0
  3. metaflow_extensions/outerbounds/plugins/__init__.py +186 -57
  4. metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
  5. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  6. metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
  7. metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  33. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  34. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  35. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  36. metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
  37. metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
  38. metaflow_extensions/outerbounds/plugins/auth_server.py +28 -8
  39. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  40. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  41. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
  42. metaflow_extensions/outerbounds/plugins/card_utilities/__init__.py +0 -0
  43. metaflow_extensions/outerbounds/plugins/card_utilities/async_cards.py +142 -0
  44. metaflow_extensions/outerbounds/plugins/card_utilities/extra_components.py +545 -0
  45. metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +70 -0
  46. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
  47. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
  48. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  49. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
  50. metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
  51. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  52. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +391 -0
  53. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +188 -0
  54. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
  55. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
  56. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +79 -0
  57. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  58. metaflow_extensions/outerbounds/plugins/nim/card.py +140 -0
  59. metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py +101 -0
  60. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +379 -0
  61. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  62. metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
  63. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +3 -0
  64. metaflow_extensions/outerbounds/plugins/nvcf/exceptions.py +94 -0
  65. metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +178 -0
  66. metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +417 -0
  67. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +280 -0
  68. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +242 -0
  69. metaflow_extensions/outerbounds/plugins/nvcf/utils.py +6 -0
  70. metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
  71. metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
  72. metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
  73. metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
  74. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
  75. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
  76. metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
  77. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
  78. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  79. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  80. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
  81. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  82. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  83. metaflow_extensions/outerbounds/plugins/perimeters.py +19 -5
  84. metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py +70 -0
  85. metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py +88 -0
  86. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  87. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  88. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  89. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  90. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  91. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  92. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  93. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  94. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  95. metaflow_extensions/outerbounds/plugins/secrets/__init__.py +0 -0
  96. metaflow_extensions/outerbounds/plugins/secrets/secrets.py +204 -0
  97. metaflow_extensions/outerbounds/plugins/snowflake/__init__.py +3 -0
  98. metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +378 -0
  99. metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
  100. metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +309 -0
  101. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +277 -0
  102. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +150 -0
  103. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +273 -0
  104. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
  105. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +241 -0
  106. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
  107. metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py +50 -0
  108. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  109. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  110. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  111. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  112. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  113. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  114. metaflow_extensions/outerbounds/profilers/gpu.py +131 -47
  115. metaflow_extensions/outerbounds/remote_config.py +53 -16
  116. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +138 -2
  117. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  118. metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
  119. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  120. metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py +1 -0
  121. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  122. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  123. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  124. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
  125. ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
  126. ob_metaflow_extensions-1.1.45rc3.dist-info/RECORD +0 -19
  127. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
  128. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
1
+ from metaflow.metaflow_current import current
2
+ import sqlite3
3
+ from threading import Thread, Event
4
+ import time
5
+
6
+
7
class InfoCollectorThread(Thread):
    """Daemon thread that periodically polls a SQLite database.

    Every ``interval`` seconds it opens ``file_name``, passes the connection to
    ``sqlite_fetch_func`` and publishes the result as ``{"metrics": <result>}``
    through ``read()``. When a poll fails, the previously published data is
    kept and the failure is exposed through ``has_errored()`` / ``get_error()``.

    Args:
        interval: seconds to wait between polls.
        file_name: path of the SQLite database to read (required).
        sqlite_fetch_func: callable taking an open ``sqlite3.Connection`` and
            returning the data to publish.
    """

    def __init__(
        self,
        interval=1,
        file_name=None,
        sqlite_fetch_func=None,  # Callable
    ):
        super().__init__()
        self._exit_event = Event()
        self._interval = interval
        assert file_name is not None, "file_name must be provided"
        self._file_name = file_name
        # Daemon thread: a forgotten stop() does not keep the interpreter alive.
        self.daemon = True
        self._data = {}
        self._has_errored = False
        self._current_error = None
        self.sqlite_fetch_func = sqlite_fetch_func

    def read(self):
        """Return the most recently published data (``{}`` before the first success)."""
        return self._data

    def has_errored(self):
        """Return True if the most recent poll failed."""
        return self._has_errored

    def get_error(self):
        """Return the exception from the most recent poll, or None if it succeeded."""
        return self._current_error

    def _safely_load(self):
        """Run one poll and return ``(data, error)``; never raises.

        On success ``data`` is ``{"metrics": ...}`` and ``error`` is None; on
        failure ``data`` is ``{}`` and ``error`` is the raised exception.
        """
        # Initialised up front: if sqlite3.connect itself raises, the finally
        # clause must not touch an unbound name.
        conn = None
        try:
            conn = sqlite3.connect(self._file_name)
            data = self.sqlite_fetch_func(conn)
            return {"metrics": data}, None
        except Exception as e:
            # Any failure (sqlite, missing path, errors inside the user fetch
            # function) is reported via get_error() instead of killing the
            # background thread and silently freezing the published data.
            return {}, e
        finally:
            if conn is not None:
                conn.close()

    def run(self):
        """Poll until stop() is called, keeping the last good result on failure."""
        while not self._exit_event.is_set():
            data, self._current_error = self._safely_load()
            if self._current_error is None:
                self._data = data
            self._has_errored = self._current_error is not None
            # Event.wait (unlike time.sleep) returns as soon as stop() sets the
            # event, so shutdown does not stall for a whole interval.
            self._exit_event.wait(self._interval)

    def stop(self):
        """Signal the polling loop to exit and wait for the thread to finish."""
        self._exit_event.set()
        self.join()
57
+
58
+
59
class CardRefresher:
    """Hook interface driven by CardUpdaterThread / AsyncPeriodicRefresher.

    Subclasses set ``CARD_ID`` to the key of a card in ``current.card`` and
    implement every method below; the base versions raise NotImplementedError.
    """

    # Key used to look the card up in ``current.card``; must be overridden.
    CARD_ID = None

    def on_startup(self, current_card):
        """Called once with the card object before the refresh loop begins."""
        raise NotImplementedError("on_startup method must be implemented")

    def on_error(self, current_card, error_message):
        """Called on a refresh tick when the collector's last poll failed.

        NOTE(review): callers in this module pass the collector's stored
        exception object here, not a formatted string.
        """
        raise NotImplementedError("on_error method must be implemented")

    def on_update(self, current_card, data_object):
        """Called on each refresh tick with the collector's latest data."""
        raise NotImplementedError("on_update method must be implemented")

    def sqlite_fetch_func(self, conn):
        """Read the data to publish from an open ``sqlite3`` connection."""
        raise NotImplementedError("sqlite_fetch_func must be implemented")
74
+
75
+
76
class CardUpdaterThread(Thread):
    """Daemon thread that forwards collector data into a Metaflow card.

    When started it resolves ``current.card[card_refresher.CARD_ID]`` and calls
    ``on_startup`` once. Then every ``interval`` seconds it calls ``on_update``
    with the collector's latest data, preceded by ``on_error`` when the
    collector reports that its last poll failed.

    Args:
        card_refresher: object implementing the ``CardRefresher`` hooks.
        interval: seconds between refreshes.
        file_name: stored on the instance; not read by this class.
        collector_thread: the ``InfoCollectorThread`` supplying the data.
    """

    def __init__(
        self,
        card_refresher: CardRefresher,
        interval=1,
        file_name=None,
        collector_thread: InfoCollectorThread = None,
    ):
        super().__init__()
        self._exit_event = Event()
        self._interval = interval
        self._refresher = card_refresher
        self._file_name = file_name
        self._collector_thread = collector_thread
        self.daemon = True

    def run(self):
        """Refresh the card until stop() is called."""
        refresher = self._refresher
        if refresher.CARD_ID is None:
            raise ValueError("CARD_ID must be defined")
        current_card = current.card[refresher.CARD_ID]
        refresher.on_startup(current_card)
        while not self._exit_event.is_set():
            data = self._collector_thread.read()
            if self._collector_thread.has_errored():
                refresher.on_error(current_card, self._collector_thread.get_error())
            refresher.on_update(current_card, data)
            # Event.wait returns immediately once stop() sets the event,
            # instead of sleeping out the full interval like time.sleep.
            self._exit_event.wait(self._interval)

    def stop(self):
        """Stop this thread and its collector thread, waiting for this one to exit."""
        self._exit_event.set()
        self._collector_thread.stop()
        # join() raises RuntimeError on a thread that was never started;
        # ident stays None until start() has run.
        if self.ident is not None:
            self.join()
110
+
111
+
112
class AsyncPeriodicRefresher:
    """Pair an InfoCollectorThread with a CardUpdaterThread for one refresher.

    The collector thread is started as soon as the object is constructed; the
    card updater only starts when ``start()`` is called.

    Args:
        card_referesher: the ``CardRefresher`` implementation. The misspelt
            parameter name is kept for compatibility with keyword callers.
        updater_interval: seconds between card refreshes.
        collector_interval: seconds between SQLite polls.
        file_name: path of the SQLite database to poll.
    """

    def __init__(
        self,
        card_referesher: CardRefresher,
        updater_interval=1,
        collector_interval=1,
        file_name=None,
    ):
        assert card_referesher.CARD_ID is not None, "CARD_ID must be defined"
        self._collector_thread = InfoCollectorThread(
            interval=collector_interval,
            file_name=file_name,
            sqlite_fetch_func=card_referesher.sqlite_fetch_func,
        )
        self._collector_thread.start()
        self._updater_thread = CardUpdaterThread(
            card_refresher=card_referesher,
            interval=updater_interval,
            file_name=file_name,
            collector_thread=self._collector_thread,
        )

    def start(self):
        """Start the card-updater thread."""
        self._updater_thread.start()

    def stop(self):
        """Stop both threads, then push one final update with the latest data."""
        updater = self._updater_thread
        # Stop the threads before the final refresh, so it cannot run
        # concurrently with an on_update issued from the updater thread.
        # Only a started thread can be joined, hence the ident guard.
        if updater.ident is not None:
            updater.stop()
        self._collector_thread.stop()
        refresher = updater._refresher
        current_card = current.card[refresher.CARD_ID]
        refresher.on_update(current_card, self._collector_thread.read())
@@ -0,0 +1,545 @@
1
+ import os
2
+ from metaflow.cards import (
3
+ Markdown,
4
+ Table,
5
+ VegaChart,
6
+ ProgressBar,
7
+ MetaflowCardComponent,
8
+ Artifact,
9
+ )
10
+ import math
11
+ from metaflow.plugins.cards.card_modules.components import (
12
+ with_default_component_id,
13
+ TaskToDict,
14
+ ArtifactsComponent,
15
+ render_safely,
16
+ )
17
+ import datetime
18
+ from metaflow.metaflow_current import current
19
+ import json
20
+ from functools import wraps
21
+ from collections import defaultdict
22
+ from threading import Thread, Event
23
+ import time
24
+
25
# Default chart geometry, in pixels, used by the specs in this module.
DEFAULT_WIDTH = 500
DEFAULT_HEIGHT = 200
DEFAULT_PADDING = 10

# Background colours.
BG_COLOR = "#f2eeea"  # sand-100
VIEW_FILL = "#faf7f4"  # sand-200

# Neutral tones.
GREYS = ["#ebe8e5", "#b2afac", "#6a6867"]
BLACK = "#31302f"

# Five-step hue palettes; index 0 is the pale tint of each hue.
GREENS = ["#dae8e2", "#3e8265", "#4c9878", "#428a6b", "#37795d"]
YELLOWS = ["#faf1db", "#f7e2b1", "#fbd784", "#e4b957", "#d7a530"]
PURPLES = ["#f5eff9", "#e7d4f3", "#976bac", "#8e53a9", "#77458f"]
REDS = ["#fce5e2", "#f3b6af", "#e6786c", "#e35f50", "#ce493a"]
BLUES = ["#dfe9f4", "#bdd8f2", "#88b7e3", "#6799c8", "#4e7ca7"]

# Categorical colour cycle: shades 1..4, interleaving the hues within each
# shade (green, purple, red, blue, yellow). Shade 0 is intentionally skipped.
ALL_COLORS = [
    palette[shade]
    for shade in range(1, 5)
    for palette in (GREENS, PURPLES, REDS, BLUES, YELLOWS)
]
64
+
65
+
66
def update_spec_data(spec, data):
    """Append one data row to a chart spec's inline dataset, in place.

    Vega-Lite specs (as built by ``line_chart_spec``) hold a single dataset
    dict under ``spec["data"]``. Raw Vega specs (``BarPlot``, ``ViolinPlot``)
    hold a *list* of datasets. In the list case the row goes into the first
    dataset, which is the one that later datasets use as their ``source``.

    Args:
        spec: chart spec dict containing a ``"data"`` entry.
        data: the row (typically a dict of field -> value) to append.

    Returns:
        The same ``spec`` object, for call-site convenience.
    """
    datasets = spec["data"]
    # Indexing a list with "values" would raise TypeError, so pick the
    # target dataset explicitly.
    target = datasets[0] if isinstance(datasets, list) else datasets
    target["values"].append(data)
    return spec
69
+
70
+
71
def update_data_object(data_object, data):
    """Append ``data`` to ``data_object["values"]`` in place and return ``data_object``."""
    rows = data_object["values"]
    rows.append(data)
    return data_object
74
+
75
+
76
def line_chart_spec(
    title=None,
    category_name="u",
    y_name="v",
    xtitle=None,
    ytitle=None,
    width=DEFAULT_WIDTH,
    height=DEFAULT_HEIGHT,
    with_params=True,
    x_axis_temporal=False,
):
    """Build a Vega-Lite v5 line-chart spec with an empty inline dataset.

    Args:
        title: chart title; "Line Chart" when falsy.
        category_name: data field plotted on the x axis.
        y_name: data field plotted on the y axis.
        xtitle: x-axis title; defaults to ``category_name``.
        ytitle: y-axis title; defaults to ``y_name``.
        width: chart width in pixels (previously ignored; now honoured).
        height: chart height in pixels (previously ignored; now honoured).
        with_params: attach interactive widgets (interpolation, tension,
            stroke width/cap/dash) bound to the line mark.
        x_axis_temporal: give x a ``seconds`` time unit instead of treating
            it as a quantitative field.

    Returns:
        ``(spec, data)``: ``spec["data"]["values"]`` is the initially empty
        list that rows are appended to; ``data`` is a separate
        ``{"values": []}`` object.
    """
    parameters = [
        {
            "name": "interpolate",
            "value": "linear",
            "bind": {
                "input": "select",
                "options": [
                    "basis",
                    "cardinal",
                    "catmull-rom",
                    "linear",
                    "monotone",
                    "natural",
                    "step",
                    "step-after",
                    "step-before",
                ],
            },
        },
        {
            "name": "tension",
            "value": 0,
            "bind": {"input": "range", "min": 0, "max": 1, "step": 0.05},
        },
        {
            "name": "strokeWidth",
            "value": 2,
            "bind": {"input": "range", "min": 0, "max": 10, "step": 0.5},
        },
        {
            "name": "strokeCap",
            "value": "butt",
            "bind": {"input": "select", "options": ["butt", "round", "square"]},
        },
        {
            "name": "strokeDash",
            "value": [1, 0],
            "bind": {
                "input": "select",
                "options": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]],
            },
        },
    ]
    # Each mark property reads its value from the parameter of the same name.
    parameter_marks = {
        "interpolate": {"expr": "interpolate"},
        "tension": {"expr": "tension"},
        "strokeWidth": {"expr": "strokeWidth"},
        "strokeDash": {"expr": "strokeDash"},
        "strokeCap": {"expr": "strokeCap"},
    }
    spec = {
        "title": title if title else "Line Chart",
        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
        "width": width,
        "height": height,
        "background": BG_COLOR,
        "padding": DEFAULT_PADDING,
        "view": {"fill": VIEW_FILL},
        "params": parameters if with_params else [],
        "data": {"name": "values", "values": []},
        "mark": {
            "type": "line",
            "tooltip": True,
            **(parameter_marks if with_params else {}),
        },
        # NOTE(review): top-level "selection" is Vega-Lite v4 syntax (v5
        # expresses it under "params"); confirm the renderer still honours it.
        "selection": {"grid": {"type": "interval", "bind": "scales"}},
        "encoding": {
            "x": {
                "field": category_name,
                "title": xtitle if xtitle else category_name,
                **({"timeUnit": "seconds"} if x_axis_temporal else {}),
                **({"type": "quantitative"} if not x_axis_temporal else {}),
            },
            "y": {
                "field": y_name,
                "type": "quantitative",
                "title": ytitle if ytitle else y_name,
            },
        },
    }
    data = {"values": []}
    return spec, data
169
+
170
+
171
class LineChart(MetaflowCardComponent):
    """Realtime-updatable Vega-Lite line chart card component.

    Rows passed to ``update`` are appended to the chart's inline dataset. The
    x encoding reads ``category_name`` and the y encoding reads ``y_name``
    from each row.
    """

    REALTIME_UPDATABLE = True

    def __init__(
        self,
        title,
        xtitle,
        ytitle,
        category_name,
        y_name,
        with_params=False,
        x_axis_temporal=False,
    ):
        super().__init__()
        built = line_chart_spec(
            title=title,
            xtitle=xtitle,
            ytitle=ytitle,
            category_name=category_name,
            y_name=y_name,
            with_params=with_params,
            x_axis_temporal=x_axis_temporal,
        )
        # line_chart_spec returns (spec, data); only the spec is kept.
        self.spec = built[0]

    def update(self, data):
        """Append one row (or whatever ``update_spec_data`` accepts) to the chart."""
        self.spec = update_spec_data(self.spec, data)

    @with_default_component_id
    def render(self):
        """Render via a VegaChart that shares this component's id."""
        chart = VegaChart(self.spec, show_controls=True)
        chart.component_id = self.component_id
        return chart.render()
204
+
205
+
206
class ArtifactTable(Artifact):
    """Card component that displays a ``{name: value}`` mapping as an artifacts table.

    Each value is converted with ``TaskToDict(only_repr=True).infer_object``,
    tagged with its key under ``"name"``, and rendered by ``ArtifactsComponent``.
    """

    def __init__(self, data_dict):
        # NOTE(review): Artifact.__init__ is not called; render() below reads
        # only these two attributes plus component_id.
        self._data = data_dict
        self._task_to_dict = TaskToDict(only_repr=True)

    @with_default_component_id
    @render_safely
    def render(self):
        """Render the mapping as an ArtifactsComponent sharing this component's id."""
        converter = self._task_to_dict
        rows = []
        for key, value in self._data.items():
            entry = converter.infer_object(value)
            entry["name"] = key
            rows.append(entry)
        table = ArtifactsComponent(data=rows)
        table.component_id = self.component_id
        return table.render()
223
+
224
+
225
+ # fmt: off
226
class BarPlot(MetaflowCardComponent):
    """Realtime-updatable Vega bar chart card component.

    Rows appended via ``update`` go into the ``table`` dataset. Each row
    should carry ``category_name`` (one bar per value) and ``value_name``
    (the bar length).

    Args:
        title: chart title.
        category_name: field used for the band (category) axis.
        value_name: numeric field giving each bar's length.
        orientation: "vertical" (categories along x) or "horizontal"
            (categories along y).

    Raises:
        ValueError: if ``orientation`` is not one of the supported values.
    """

    REALTIME_UPDATABLE = True

    def __init__(self, title, category_name, value_name, orientation="vertical"):

        if orientation not in ["vertical", "horizontal"]:
            raise ValueError("orientation must be either 'vertical' or 'horizontal'")

        super().__init__()
        vertical = orientation == "vertical"
        # Vega scales are named by screen axis, so the band (category) scale
        # and the linear (value) scale swap names with orientation.
        band_scale = "xscale" if vertical else "yscale"
        value_scale = "yscale" if vertical else "xscale"
        x_field = category_name if vertical else value_name
        y_field = value_name if vertical else category_name
        self.spec = {
            "title": title,
            "$schema": "https://vega.github.io/schema/vega/v5.json",
            "description": "A basic bar chart example to show a count of values grouped by a category.",
            "background": BG_COLOR,
            "view": {"fill": VIEW_FILL},
            "width": DEFAULT_WIDTH,
            "height": DEFAULT_HEIGHT,
            "padding": DEFAULT_PADDING,
            "data": [{"name": "table", "values": []}],
            "signals": [
                {
                    # Holds the hovered datum so the text mark can label it.
                    "name": "tooltip",
                    "value": {},
                    "on": [
                        {"events": "rect:pointerover", "update": "datum"},
                        {"events": "rect:pointerout", "update": "{}"},
                    ],
                }
            ],
            "scales": [
                {
                    "name": band_scale,
                    "type": "band",
                    "domain": {"data": "table", "field": category_name},
                    "range": "width" if vertical else "height",
                    "padding": 0.25,
                    "round": True,
                },
                {
                    "name": value_scale,
                    "domain": {"data": "table", "field": value_name},
                    "nice": True,
                    "range": "height" if vertical else "width",
                },
                {
                    # NOTE(review): no mark below references this scale; bars
                    # use fixed fills.
                    "name": "color",
                    "type": "ordinal",
                    "domain": {"data": "table", "field": category_name},
                    "range": ALL_COLORS,
                },
            ],
            "axes": [
                {"orient": "bottom", "scale": "xscale", "zindex": 1},
                {"orient": "left", "scale": "yscale", "zindex": 1},
            ],
            "marks": [
                {
                    "type": "rect",
                    "from": {"data": "table"},
                    "encode": {
                        "enter": {
                            "x": {"scale": "xscale", "field": x_field},
                            "y": {"scale": "yscale", "field": y_field},
                            # Bars grow from the value-axis origin.
                            ("y2" if vertical else "x2"): {
                                "scale": value_scale,
                                "value": 0,
                            },
                            "width": {"scale": "xscale", "band": 1},
                            "height": {"scale": "yscale", "band": 1},
                        },
                        "update": {
                            "fill": {"value": GREENS[0]},
                        },
                        "hover": {"fill": {"value": GREENS[2]}},
                    },
                },
                {
                    # Value label for the currently hovered bar.
                    "type": "text",
                    "encode": {
                        "enter": {
                            "align": {"value": "center"},
                            "baseline": {"value": "bottom"},
                            "fill": {"value": BG_COLOR},
                        },
                        "update": {
                            "x": {
                                "scale": "xscale",
                                "signal": f"tooltip.{x_field}",
                                **({"band": 0.5} if vertical else {"offset": -10}),
                            },
                            "y": {
                                "scale": "yscale",
                                "signal": f"tooltip.{y_field}",
                                **({"offset": 20} if vertical else {"band": 0.5}),
                            },
                            "text": {"signal": f"tooltip.{value_name}"},
                            "fillOpacity": [
                                {"test": "datum === tooltip", "value": 0},
                                {"value": 1},
                            ],
                        },
                    },
                },
            ],
        }

    def update(self, data):
        """Append one row to the ``table`` dataset."""
        # spec["data"] is a *list* of Vega datasets here, so the dict-style
        # lookup spec["data"]["values"] would raise TypeError.
        for dataset in self.spec["data"]:
            if dataset["name"] == "table":
                dataset["values"].append(data)
                break

    @with_default_component_id
    def render(self):
        """Render via a VegaChart that shares this component's id."""
        vega_chart = VegaChart(self.spec, show_controls=True)
        vega_chart.component_id = self.component_id
        return vega_chart.render()
360
+
361
+
362
class ViolinPlot(MetaflowCardComponent):
    """Realtime-updatable Vega violin plot card component.

    Raw rows appended via ``update`` go into the ``src`` dataset. Each row
    should carry ``category_col_name`` (one violin per distinct value) and
    ``value_col_name`` (the sample). Two derived datasets are built from
    ``src``: ``density``, a KDE per category, and ``stats``, with q1/q3/median
    per category.

    Args:
        title: chart title.
        category_col_name: field used to group rows into violins.
        value_col_name: numeric field whose distribution is drawn.
    """

    REALTIME_UPDATABLE = True

    def __init__(self, title, category_col_name, value_col_name):
        super().__init__()

        self.spec = {
            "title": title,
            "$schema": "https://vega.github.io/schema/vega/v5.json",
            "description": "A violin chart to show a distributional properties of each category.",
            "background": BG_COLOR,
            "view": {"fill": VIEW_FILL},
            "width": DEFAULT_WIDTH,
            "height": DEFAULT_HEIGHT,
            "padding": DEFAULT_PADDING,
            "config": {
                "axisBand": {"bandPosition": 1, "tickExtra": True, "tickOffset": 0}
            },
            "signals": [
                {"name": "plotWidth", "value": 75},
                # NOTE(review): height is sized for exactly 3 categories;
                # other counts will be squeezed or padded.
                {"name": "height", "update": "(plotWidth + 10) * 3"},
                {
                    "name": "bandwidth",
                    "value": 0.1,
                    "bind": {"input": "range", "min": 0, "max": 0.2, "step": 0.01},
                },
            ],
            "data": [
                {"name": "src", "values": []},
                {
                    "name": "density",
                    "source": "src",
                    "transform": [
                        {
                            "type": "kde",
                            "groupby": [category_col_name],
                            "field": value_col_name,
                            "bandwidth": {"signal": "bandwidth"},
                            "extent": {"signal": "domain('xscale')"},
                        }
                    ],
                },
                {
                    "name": "stats",
                    "source": "src",
                    "transform": [
                        {
                            "type": "aggregate",
                            "groupby": [category_col_name],
                            "fields": [value_col_name, value_col_name, value_col_name],
                            "ops": ["q1", "q3", "median"],
                            "as": ["q1", "q3", "median"],
                        }
                    ],
                },
            ],
            "scales": [
                {
                    "name": "layout",
                    "type": "band",
                    "range": "height",
                    "domain": {"data": "src", "field": category_col_name},
                },
                {
                    "name": "xscale",
                    "type": "linear",
                    "range": "width",
                    "round": True,
                    "domain": {"data": "src", "field": value_col_name},
                    "zero": False,
                    "nice": True,
                },
                {
                    "name": "hscale",
                    "type": "linear",
                    "range": [0, {"signal": "plotWidth"}],
                    "domain": {"data": "density", "field": "density"},
                },
                {
                    "name": "color",
                    "type": "ordinal",
                    "domain": {"data": "src", "field": category_col_name},
                    "range": ALL_COLORS,
                },
            ],
            "axes": [
                {"orient": "bottom", "scale": "xscale", "zindex": 1},
                {"orient": "left", "scale": "layout", "zindex": 1},
            ],
            "marks": [
                {
                    # One group per category, faceted over the KDE output.
                    "type": "group",
                    "from": {
                        "facet": {
                            "data": "density",
                            "name": "violin",
                            "groupby": category_col_name,
                        }
                    },
                    "encode": {
                        "enter": {
                            "yc": {
                                "scale": "layout",
                                "field": category_col_name,
                                "band": 0.5,
                            },
                            "height": {"signal": "plotWidth"},
                            "width": {"signal": "width"},
                        }
                    },
                    "data": [
                        {
                            # Stats row belonging to this group's category.
                            "name": "summary",
                            "source": "stats",
                            "transform": [
                                {
                                    "type": "filter",
                                    "expr": f"datum.{category_col_name} === parent.{category_col_name}",
                                }
                            ],
                        }
                    ],
                    "marks": [
                        {
                            # Violin body.
                            "type": "area",
                            "from": {"data": "violin"},
                            "encode": {
                                "enter": {
                                    "fill": {
                                        "scale": "color",
                                        "field": {"parent": category_col_name},
                                    }
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "value"},
                                    "yc": {"signal": "plotWidth / 2"},
                                    "height": {"scale": "hscale", "field": "density"},
                                },
                            },
                        },
                        {
                            # Interquartile bar from q1 to q3.
                            "type": "rect",
                            "from": {"data": "summary"},
                            "encode": {
                                "enter": {
                                    "fill": {"value": BLACK},
                                    "height": {"value": 2},
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "q1"},
                                    "x2": {"scale": "xscale", "field": "q3"},
                                    "yc": {"signal": "plotWidth / 2"},
                                },
                            },
                        },
                        {
                            # Median tick.
                            "type": "rect",
                            "from": {"data": "summary"},
                            "encode": {
                                "enter": {
                                    "fill": {"value": BLACK},
                                    "width": {"value": 2},
                                    "height": {"value": 8},
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "median"},
                                    "yc": {"signal": "plotWidth / 2"},
                                },
                            },
                        },
                    ],
                }
            ],
        }

    def update(self, data):
        """Append one raw row to the ``src`` dataset (``density``/``stats`` derive from it)."""
        # spec["data"] is a *list* of Vega datasets here, so the dict-style
        # lookup spec["data"]["values"] would raise TypeError.
        for dataset in self.spec["data"]:
            if dataset["name"] == "src":
                dataset["values"].append(data)
                break

    @with_default_component_id
    def render(self):
        """Render via a VegaChart that shares this component's id."""
        vega_chart = VegaChart(self.spec, show_controls=True)
        vega_chart.component_id = self.component_id
        return vega_chart.render()
545
+ # fmt: on