ob-metaflow-extensions 1.1.128__py2.py3-none-any.whl → 1.1.129__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

@@ -0,0 +1,545 @@
1
+ import os
2
+ from metaflow.cards import (
3
+ Markdown,
4
+ Table,
5
+ VegaChart,
6
+ ProgressBar,
7
+ MetaflowCardComponent,
8
+ Artifact,
9
+ )
10
+ import math
11
+ from metaflow.plugins.cards.card_modules.components import (
12
+ with_default_component_id,
13
+ TaskToDict,
14
+ ArtifactsComponent,
15
+ render_safely,
16
+ )
17
+ import datetime
18
+ from metaflow.metaflow_current import current
19
+ import json
20
+ from functools import wraps
21
+ from collections import defaultdict
22
+ from threading import Thread, Event
23
+ import time
24
+
25
# Geometry defaults used as the "width" / "height" / "padding" of the chart
# specs built in this module.
DEFAULT_WIDTH, DEFAULT_HEIGHT = 500, 200
DEFAULT_PADDING = 10

# Surface colours.
BG_COLOR = "#f2eeea"  # sand-100
VIEW_FILL = "#faf7f4"  # sand-200
BLACK = "#31302f"
GREYS = ["#ebe8e5", "#b2afac", "#6a6867"]

# Five-entry palettes, one per hue.
GREENS = ["#dae8e2", "#3e8265", "#4c9878", "#428a6b", "#37795d"]
YELLOWS = ["#faf1db", "#f7e2b1", "#fbd784", "#e4b957", "#d7a530"]
PURPLES = ["#f5eff9", "#e7d4f3", "#976bac", "#8e53a9", "#77458f"]
REDS = ["#fce5e2", "#f3b6af", "#e6786c", "#e35f50", "#ce493a"]
BLUES = ["#dfe9f4", "#bdd8f2", "#88b7e3", "#6799c8", "#4e7ca7"]

# Categorical colour range: for each palette index 1..4 in turn, take that
# entry from GREENS, PURPLES, REDS, BLUES, YELLOWS (in that order).
# Index 0 of every palette is deliberately left out.
ALL_COLORS = [
    palette[shade]
    for shade in range(1, 5)
    for palette in (GREENS, PURPLES, REDS, BLUES, YELLOWS)
]
64
+
65
+
66
def update_spec_data(spec, data):
    """
    Append one record to the inline dataset of a chart spec, in place.

    Two layouts of ``spec["data"]`` are supported:

    - a single mapping with a ``"values"`` list (the shape produced by
      ``line_chart_spec``);
    - a sequence of dataset mappings (the shape used by ``BarPlot`` and
      ``ViolinPlot``), in which case the record goes to the first dataset
      that carries a ``"values"`` list. Derived datasets that only declare
      a ``"source"`` are skipped.

    Parameters
    ----------
    spec : dict
        Chart specification; mutated in place.
    data : Any
        The record to append.

    Returns
    -------
    dict
        The same ``spec`` object that was passed in.

    Raises
    ------
    KeyError
        If no dataset with a ``"values"`` entry is present.
    """
    datasets = spec["data"]
    if isinstance(datasets, dict):
        target = datasets
    else:
        # List-style data block: indexing it with "values" would raise
        # TypeError, so locate the dataset holding the inline values instead.
        target = next((entry for entry in datasets if "values" in entry), None)
        if target is None:
            raise KeyError("values")
    target["values"].append(data)
    return spec
69
+
70
+
71
def update_data_object(data_object, data):
    """
    Append ``data`` to ``data_object["values"]`` in place.

    Returns the same ``data_object`` that was passed in.
    """
    values = data_object["values"]
    values.append(data)
    return data_object
74
+
75
+
76
def line_chart_spec(
    title=None,
    category_name="u",
    y_name="v",
    xtitle=None,
    ytitle=None,
    width=DEFAULT_WIDTH,
    height=DEFAULT_HEIGHT,
    with_params=True,
    x_axis_temporal=False,
):
    """
    Build a Vega-Lite (v5 schema) line-chart spec with an empty inline dataset.

    Parameters
    ----------
    title : str, optional
        Chart title; falls back to "Line Chart" when falsy.
    category_name : str
        Field plotted on the x axis.
    y_name : str
        Field plotted on the y axis.
    xtitle, ytitle : str, optional
        Axis titles; each falls back to the corresponding field name.
    width, height : int
        Chart size written to the spec's "width" / "height".
    with_params : bool
        When True, the spec exposes bound parameters (interpolate, tension,
        strokeWidth, strokeCap, strokeDash) and wires them into the mark.
    x_axis_temporal : bool
        When True the x encoding gets ``"timeUnit": "seconds"``; otherwise it
        is typed as quantitative.

    Returns
    -------
    tuple
        ``(spec, data)`` where ``spec`` is the chart dict and ``data`` is a
        fresh ``{"values": []}`` object.
    """
    parameters = [
        {
            "name": "interpolate",
            "value": "linear",
            "bind": {
                "input": "select",
                "options": [
                    "basis",
                    "cardinal",
                    "catmull-rom",
                    "linear",
                    "monotone",
                    "natural",
                    "step",
                    "step-after",
                    "step-before",
                ],
            },
        },
        {
            "name": "tension",
            "value": 0,
            "bind": {"input": "range", "min": 0, "max": 1, "step": 0.05},
        },
        {
            "name": "strokeWidth",
            "value": 2,
            "bind": {"input": "range", "min": 0, "max": 10, "step": 0.5},
        },
        {
            "name": "strokeCap",
            "value": "butt",
            "bind": {"input": "select", "options": ["butt", "round", "square"]},
        },
        {
            "name": "strokeDash",
            "value": [1, 0],
            "bind": {
                "input": "select",
                "options": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]],
            },
        },
    ]
    # Mark properties that read their value from the parameters above.
    parameter_marks = {
        "interpolate": {"expr": "interpolate"},
        "tension": {"expr": "tension"},
        "strokeWidth": {"expr": "strokeWidth"},
        "strokeDash": {"expr": "strokeDash"},
        "strokeCap": {"expr": "strokeCap"},
    }
    spec = {
        "title": title if title else "Line Chart",
        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
        # Honour the caller-supplied size; previously these were pinned to
        # the module defaults and the arguments had no effect.
        "width": width,
        "height": height,
        "background": BG_COLOR,
        "padding": DEFAULT_PADDING,
        "view": {"fill": VIEW_FILL},
        "params": parameters if with_params else [],
        "data": {"name": "values", "values": []},
        "mark": {
            "type": "line",
            "tooltip": True,
            **(parameter_marks if with_params else {}),
        },
        # Interval selection bound to the scales.
        "selection": {"grid": {"type": "interval", "bind": "scales"}},
        "encoding": {
            "x": {
                "field": category_name,
                "title": xtitle if xtitle else category_name,
                **({"timeUnit": "seconds"} if x_axis_temporal else {}),
                **({"type": "quantitative"} if not x_axis_temporal else {}),
            },
            "y": {
                "field": y_name,
                "type": "quantitative",
                "title": ytitle if ytitle else y_name,
            },
        },
    }
    data = {"values": []}
    return spec, data
169
+
170
+
171
class LineChart(MetaflowCardComponent):
    """
    Card component that renders the spec produced by ``line_chart_spec``
    through a ``VegaChart``.

    Records passed to ``update`` are appended to the spec's inline data.
    """

    REALTIME_UPDATABLE = True

    def __init__(
        self,
        title,
        xtitle,
        ytitle,
        category_name,
        y_name,
        with_params=False,
        x_axis_temporal=False,
    ):
        super().__init__()

        chart_options = {
            "title": title,
            "xtitle": xtitle,
            "ytitle": ytitle,
            "category_name": category_name,
            "y_name": y_name,
            "with_params": with_params,
            "x_axis_temporal": x_axis_temporal,
        }
        # Only the spec is kept; the separate data object is discarded.
        chart_spec, _unused_data = line_chart_spec(**chart_options)
        self.spec = chart_spec

    def update(self, data):  # Can take a diff
        """Append ``data`` to the chart's inline values."""
        self.spec = update_spec_data(self.spec, data)

    @with_default_component_id
    def render(self):
        """Render the current spec via ``VegaChart``, reusing this component's id."""
        chart = VegaChart(self.spec, show_controls=True)
        chart.component_id = self.component_id
        return chart.render()
204
+
205
+
206
class ArtifactTable(Artifact):
    """
    Card component that renders a mapping of ``name -> value`` through an
    ``ArtifactsComponent``.
    """

    def __init__(self, data_dict):
        self._data = data_dict
        self._task_to_dict = TaskToDict(only_repr=True)

    @with_default_component_id
    @render_safely
    def render(self):
        """Describe every value with ``TaskToDict.infer_object`` and render the list."""

        def _describe(name, value):
            # Tag the inferred description with the key it was stored under.
            described = self._task_to_dict.infer_object(value)
            described["name"] = name
            return described

        rows = [_describe(name, value) for name, value in self._data.items()]

        component = ArtifactsComponent(data=rows)
        component.component_id = self.component_id
        return component.render()
223
+
224
+
225
+ # fmt: off
226
class BarPlot(MetaflowCardComponent):
    """
    Card component holding a Vega (v5 schema) bar-chart spec.

    The spec has one inline dataset named "table"; ``update`` appends records
    to it and ``render`` hands the spec to ``VegaChart``.

    Parameters
    ----------
    title : str
        Chart title.
    category_name : str
        Record field used for the band (categorical) scale.
    value_name : str
        Record field used for the value scale.
    orientation : str
        "vertical" (categories on x) or "horizontal" (categories on y).

    Raises
    ------
    ValueError
        If ``orientation`` is neither "vertical" nor "horizontal".
    """

    REALTIME_UPDATABLE = True

    def __init__(self, title, category_name, value_name, orientation="vertical"):

        if orientation not in ["vertical", "horizontal"]:
            raise ValueError("orientation must be either 'vertical' or 'horizontal'")

        super().__init__()

        # Everything that depends on the orientation is resolved once here.
        vertical = orientation == "vertical"
        band_scale, value_scale = ("xscale", "yscale") if vertical else ("yscale", "xscale")
        band_range, value_range = ("width", "height") if vertical else ("height", "width")
        x_field, y_field = (
            (category_name, value_name) if vertical else (value_name, category_name)
        )
        # The bar's second coordinate runs along the value axis.
        baseline_channel = "y2" if vertical else "x2"

        self.spec = {
            "title": title,
            "$schema": "https://vega.github.io/schema/vega/v5.json",
            "description": "A basic bar chart example to show a count of values grouped by a category.",
            "background": BG_COLOR,
            "view": {"fill": VIEW_FILL},
            "width": DEFAULT_WIDTH,
            "height": DEFAULT_HEIGHT,
            "padding": DEFAULT_PADDING,
            "data": [{"name": "table", "values": []}],
            "signals": [
                {
                    # Holds the datum under the pointer, or {} when none.
                    "name": "tooltip",
                    "value": {},
                    "on": [
                        {"events": "rect:pointerover", "update": "datum"},
                        {"events": "rect:pointerout", "update": "{}"},
                    ],
                }
            ],
            "scales": [
                {
                    "name": band_scale,
                    "type": "band",
                    "domain": {"data": "table", "field": category_name},
                    "range": band_range,
                    "padding": 0.25,
                    "round": True,
                },
                {
                    "name": value_scale,
                    "domain": {"data": "table", "field": value_name},
                    "nice": True,
                    "range": value_range,
                },
                {
                    # Declared here; the rect marks below use fixed fills.
                    "name": "color",
                    "type": "ordinal",
                    "domain": {"data": "table", "field": category_name},
                    "range": ALL_COLORS,
                },
            ],
            "axes": [
                {"orient": "bottom", "scale": "xscale", "zindex": 1},
                {"orient": "left", "scale": "yscale", "zindex": 1},
            ],
            "marks": [
                {
                    "type": "rect",
                    "from": {"data": "table"},
                    "encode": {
                        "enter": {
                            "x": {"scale": "xscale", "field": x_field},
                            "y": {"scale": "yscale", "field": y_field},
                            baseline_channel: {"scale": value_scale, "value": 0},
                            "width": {"scale": "xscale", "band": 1},
                            "height": {"scale": "yscale", "band": 1},
                        },
                        "update": {
                            "fill": {"value": GREENS[0]},
                        },
                        "hover": {"fill": {"value": GREENS[2]}},
                    },
                },
                {
                    # Value label positioned from the "tooltip" signal.
                    "type": "text",
                    "encode": {
                        "enter": {
                            "align": {"value": "center"},
                            "baseline": {"value": "bottom"},
                            "fill": {"value": BG_COLOR},
                        },
                        "update": {
                            "x": {
                                "scale": "xscale",
                                "signal": f"tooltip.{x_field}",
                                ("band" if vertical else "offset"): (
                                    0.5 if vertical else -10
                                ),
                            },
                            "y": {
                                "scale": "yscale",
                                "signal": f"tooltip.{y_field}",
                                ("offset" if vertical else "band"): (
                                    20 if vertical else 0.5
                                ),
                            },
                            "text": {"signal": f"tooltip.{value_name}"},
                            "fillOpacity": [
                                {"test": "datum === tooltip", "value": 0},
                                {"value": 1},
                            ],
                        },
                    },
                },
            ],
        }

    def update(self, data):  # Can take a diff
        """
        Append one record to the "table" dataset.

        ``self.spec["data"]`` is a list of datasets here, so the record is
        appended to its first entry directly; indexing the list with
        ``"values"`` raises ``TypeError``.
        """
        self.spec["data"][0]["values"].append(data)

    @with_default_component_id
    def render(self):
        """Render the current spec via ``VegaChart``, reusing this component's id."""
        vega_chart = VegaChart(self.spec, show_controls=True)
        vega_chart.component_id = self.component_id
        return vega_chart.render()
360
+
361
+
362
class ViolinPlot(MetaflowCardComponent):
    """
    Card component holding a Vega (v5 schema) violin-plot spec.

    The spec declares three datasets: "src" (inline records), "density"
    (a ``kde`` transform over "src", grouped by category) and "stats"
    (q1 / q3 / median aggregates over "src", grouped by category).
    ``update`` appends records to "src"; ``render`` hands the spec to
    ``VegaChart``.

    Parameters
    ----------
    title : str
        Chart title.
    category_col_name : str
        Record field that identifies the category of each record.
    value_col_name : str
        Record field holding the numeric value.
    """

    REALTIME_UPDATABLE = True

    def __init__(self, title, category_col_name, value_col_name):
        super().__init__()

        self.spec = {
            "title": title,
            "$schema": "https://vega.github.io/schema/vega/v5.json",
            "description": "A violin chart to show a distributional properties of each category.",
            "background": BG_COLOR,
            "view": {"fill": VIEW_FILL},
            "width": DEFAULT_WIDTH,
            "height": DEFAULT_HEIGHT,
            "padding": DEFAULT_PADDING,
            "config": {
                "axisBand": {"bandPosition": 1, "tickExtra": True, "tickOffset": 0}
            },
            "signals": [
                {"name": "plotWidth", "value": 75},
                # NOTE(review): the factor 3 looks like a fixed category
                # count -- confirm it suits charts with other counts.
                {"name": "height", "update": "(plotWidth + 10) * 3"},
                {
                    "name": "bandwidth",
                    "value": 0.1,
                    "bind": {"input": "range", "min": 0, "max": 0.2, "step": 0.01},
                },
            ],
            "data": [
                {"name": "src", "values": []},
                {
                    "name": "density",
                    "source": "src",
                    "transform": [
                        {
                            "type": "kde",
                            "groupby": [category_col_name],
                            "field": value_col_name,
                            "bandwidth": {"signal": "bandwidth"},
                            "extent": {"signal": "domain('xscale')"},
                        }
                    ],
                },
                {
                    "name": "stats",
                    "source": "src",
                    "transform": [
                        {
                            "type": "aggregate",
                            "groupby": [category_col_name],
                            "fields": [value_col_name, value_col_name, value_col_name],
                            "ops": ["q1", "q3", "median"],
                            "as": ["q1", "q3", "median"],
                        }
                    ],
                },
            ],
            "scales": [
                {
                    "name": "layout",
                    "type": "band",
                    "range": "height",
                    "domain": {"data": "src", "field": category_col_name},
                },
                {
                    "name": "xscale",
                    "type": "linear",
                    "range": "width",
                    "round": True,
                    "domain": {"data": "src", "field": value_col_name},
                    "zero": False,
                    "nice": True,
                },
                {
                    "name": "hscale",
                    "type": "linear",
                    "range": [0, {"signal": "plotWidth"}],
                    "domain": {"data": "density", "field": "density"},
                },
                {
                    "name": "color",
                    "type": "ordinal",
                    "domain": {"data": "src", "field": category_col_name},
                    "range": ALL_COLORS,
                },
            ],
            "axes": [
                {"orient": "bottom", "scale": "xscale", "zindex": 1},
                {"orient": "left", "scale": "layout", "zindex": 1},
            ],
            "marks": [
                {
                    # One group per category, faceted from "density".
                    "type": "group",
                    "from": {
                        "facet": {
                            "data": "density",
                            "name": "violin",
                            "groupby": category_col_name,
                        }
                    },
                    "encode": {
                        "enter": {
                            "yc": {
                                "scale": "layout",
                                "field": category_col_name,
                                "band": 0.5,
                            },
                            "height": {"signal": "plotWidth"},
                            "width": {"signal": "width"},
                        }
                    },
                    "data": [
                        {
                            # Rows of "stats" matching this group's category.
                            "name": "summary",
                            "source": "stats",
                            "transform": [
                                {
                                    "type": "filter",
                                    "expr": f"datum.{category_col_name} === parent.{category_col_name}",
                                }
                            ],
                        }
                    ],
                    "marks": [
                        {
                            # Density area.
                            "type": "area",
                            "from": {"data": "violin"},
                            "encode": {
                                "enter": {
                                    "fill": {
                                        "scale": "color",
                                        "field": {"parent": category_col_name},
                                    }
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "value"},
                                    "yc": {"signal": "plotWidth / 2"},
                                    "height": {"scale": "hscale", "field": "density"},
                                },
                            },
                        },
                        {
                            # Bar spanning q1..q3.
                            "type": "rect",
                            "from": {"data": "summary"},
                            "encode": {
                                "enter": {
                                    "fill": {"value": BLACK},
                                    "height": {"value": 2},
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "q1"},
                                    "x2": {"scale": "xscale", "field": "q3"},
                                    "yc": {"signal": "plotWidth / 2"},
                                },
                            },
                        },
                        {
                            # Tick at the median.
                            "type": "rect",
                            "from": {"data": "summary"},
                            "encode": {
                                "enter": {
                                    "fill": {"value": BLACK},
                                    "width": {"value": 2},
                                    "height": {"value": 8},
                                },
                                "update": {
                                    "x": {"scale": "xscale", "field": "median"},
                                    "yc": {"signal": "plotWidth / 2"},
                                },
                            },
                        },
                    ],
                }
            ],
        }

    def update(self, data):  # Can take a diff
        """
        Append one record to the "src" dataset.

        ``self.spec["data"]`` is a list of datasets here, so the record is
        appended to its first entry directly; indexing the list with
        ``"values"`` raises ``TypeError``.
        """
        self.spec["data"][0]["values"].append(data)

    @with_default_component_id
    def render(self):
        """Render the current spec via ``VegaChart``, reusing this component's id."""
        vega_chart = VegaChart(self.spec, show_controls=True)
        vega_chart.component_id = self.component_id
        return vega_chart.render()
545
+ # fmt: on
@@ -0,0 +1,70 @@
1
+ from metaflow.exception import MetaflowException
2
+ from collections import defaultdict
3
+
4
+
5
class CardDecoratorInjector:
    """
    Mixin useful for injecting @card decorators from other first class
    Metaflow decorators.
    """

    # step name -> {card id -> bool}. True means this mixin attached the card,
    # False means a matching card decorator was already present on the step.
    # Stored on the class, so the record is shared by all users of the mixin.
    _first_time_init = defaultdict(dict)

    @classmethod
    def _get_first_time_init_cached_value(cls, step_name, card_id):
        """Return the recorded decision for (step_name, card_id), or None if unset."""
        return cls._first_time_init.get(step_name, {}).get(card_id, None)

    @classmethod
    def _set_first_time_init_cached_value(cls, step_name, card_id, value):
        """Record the decision for (step_name, card_id)."""
        cls._first_time_init[step_name][card_id] = value

    def _card_deco_already_attached(self, step, card_id):
        """
        Return True when `step` already carries a "card" decorator whose
        `id` attribute equals `card_id`.

        The ids are compared for equality. A substring test would treat
        "gpu" as already attached when only "gpu_profile" exists.
        """
        return any(
            decorator.name == "card" and decorator.attributes["id"] == card_id
            for decorator in step.decorators
        )

    def _get_step(self, flow, step_name):
        """Return the first step of `flow` named `step_name`, or None."""
        return next((step for step in flow if step.name == step_name), None)

    def _first_time_init_check(self, step_dag_node, card_id):
        """Return True when no card with `card_id` is attached to the step yet."""
        return not self._card_deco_already_attached(step_dag_node, card_id)

    def attach_card_decorator(
        self,
        flow,
        step_name,
        card_id,
        card_type,
        refresh_interval=5,
    ):
        """
        Attach a `card` decorator (type `card_type`, id `card_id`) to the step
        named `step_name`, unless one with that id is already attached.

        This method is meant to be called from `step_init` in your
        StepDecorator code, since this class is used as a Mixin. The decision
        is recorded per (step_name, card_id), so later calls for the same
        pair do nothing.

        Raises
        ------
        MetaflowException
            If `card_id` or `card_type` is falsy, or if `flow` has no step
            named `step_name` when the decision is first made.
        """
        from metaflow import decorators as _decorators

        if not all([card_id, card_type]):
            raise MetaflowException(
                "`INJECTED_CARD_ID` and `INJECTED_CARD_TYPE` must be set in the `CardDecoratorInjector` Mixin"
            )

        step_dag_node = self._get_step(flow, step_name)

        # First check class level setting.
        if self._get_first_time_init_cached_value(step_name, card_id) is not None:
            return

        if step_dag_node is None:
            # Fail with a clear message rather than an AttributeError on None.
            raise MetaflowException(
                "Cannot attach card `%s`: step `%s` was not found in the flow"
                % (card_id, step_name)
            )

        if self._first_time_init_check(step_dag_node, card_id):
            self._set_first_time_init_cached_value(step_name, card_id, True)
            _decorators._attach_decorators_to_step(
                step_dag_node,
                [
                    "card:type=%s,id=%s,refresh_interval=%s"
                    % (card_type, card_id, str(refresh_interval))
                ],
            )
        else:
            self._set_first_time_init_cached_value(step_name, card_id, False)