ob-metaflow-extensions 1.1.128__py2.py3-none-any.whl → 1.1.129__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

@@ -319,8 +319,9 @@ STEP_DECORATORS_DESC = [
319
319
  ("snowpark", ".snowpark.snowpark_decorator.SnowparkDecorator"),
320
320
  ("tensorboard", ".tensorboard.TensorboardDecorator"),
321
321
  ("gpu_profile", ".profilers.gpu_profile_decorator.GPUProfileDecorator"),
322
+ ("nim", ".nim.NimDecorator"),
322
323
  ]
323
- FLOW_DECORATORS_DESC = [("nim", ".nim.NimDecorator")]
324
+
324
325
  TOGGLE_STEP_DECORATOR = [
325
326
  "-batch",
326
327
  "-step_functions_internal",
@@ -0,0 +1,142 @@
1
+ from metaflow.metaflow_current import current
2
+ import sqlite3
3
+ from threading import Thread, Event
4
+ import time
5
+
6
+
7
class InfoCollectorThread(Thread):
    """Background daemon thread that periodically polls a SQLite database.

    Every ``interval`` seconds the thread opens ``file_name``, passes the
    connection to ``sqlite_fetch_func`` and caches the result as
    ``{"metrics": <result>}``. The cache is only replaced on a successful
    load. After a failed poll, ``read()`` keeps returning the last good
    snapshot while ``has_errored()`` / ``get_error()`` report the failure.
    """

    def __init__(
        self,
        interval=1,
        file_name=None,
        sqlite_fetch_func=None,  # Callable[[sqlite3.Connection], Any]
    ):
        """
        :param interval: seconds to wait between two polls.
        :param file_name: path of the SQLite database to read (required).
        :param sqlite_fetch_func: callable that receives an open
            ``sqlite3.Connection`` and returns the data to cache.
        :raises ValueError: if ``file_name`` is not provided.
        """
        super().__init__()
        # An explicit check (not ``assert``) so validation survives ``python -O``.
        if file_name is None:
            raise ValueError("file_name must be provided")
        self._exit_event = Event()
        self._interval = interval
        self._file_name = file_name
        self.daemon = True  # do not keep the interpreter alive on exit
        self._data = {}
        self._has_errored = False
        self._current_error = None
        self.sqlite_fetch_func = sqlite_fetch_func

    def read(self):
        """Return the last successfully loaded snapshot (``{}`` before any)."""
        return self._data

    def has_errored(self):
        """Return True if the most recent poll failed."""
        return self._has_errored

    def get_error(self):
        """Return the exception from the most recent poll, or None."""
        return self._current_error

    def _safely_load(self):
        """Run one fetch against the database.

        NOTE: per the stdlib docs, ``sqlite3.connect`` creates an empty
        database file when ``file_name`` does not exist yet.

        :returns: ``({"metrics": data}, None)`` on success, or ``({}, error)``
            when opening or querying the database raised ``FileNotFoundError``
            or ``sqlite3.Error``. Any other exception propagates.
        """
        conn = None
        try:
            conn = sqlite3.connect(self._file_name)
            data = self.sqlite_fetch_func(conn)
            return {"metrics": data}, None
        except (FileNotFoundError, sqlite3.Error) as e:
            return {}, e
        finally:
            # ``connect`` itself may fail and leave no connection behind.
            # Closing unconditionally would raise UnboundLocalError, mask the
            # handled error and kill the polling thread.
            if conn is not None:
                conn.close()

    def run(self):
        """Poll the database until :meth:`stop` is called."""
        while not self._exit_event.is_set():
            data, error = self._safely_load()
            if error is None:
                self._data = data
            # Publish the error before the flag, so a reader that observes
            # has_errored() == True also observes the matching get_error().
            self._current_error = error
            self._has_errored = error is not None
            # Waiting on the event (rather than time.sleep) lets stop() wake
            # the thread immediately instead of after a full interval.
            self._exit_event.wait(self._interval)

    def stop(self):
        """Signal the polling loop to exit and wait for the thread to finish."""
        self._exit_event.set()
        self.join()
57
+
58
+
59
class CardRefresher:
    """Hooks used to render a Metaflow card from periodically collected data.

    Subclasses must set ``CARD_ID``, the key used to look the card up in
    ``current.card``. They must also implement every method below; each
    default raises ``NotImplementedError``.
    """

    # Identifier of the card to refresh; subclasses must override it.
    CARD_ID = None

    def on_startup(self, current_card):
        """Called once with the card before the periodic refresh loop starts."""
        raise NotImplementedError("on_startup method must be implemented")

    def on_error(self, current_card, error_message):
        """Called when the collector's last poll failed.

        :param error_message: the error reported by the collector. Within this
            module, that is the exception instance raised while loading.
        """
        raise NotImplementedError("on_error method must be implemented")

    def on_update(self, current_card, data_object):
        """Called periodically with the latest collected snapshot.

        :param data_object: ``{"metrics": <fetched data>}``, or ``{}`` if
            nothing has been loaded successfully yet.
        """
        raise NotImplementedError("on_update method must be implemented")

    def sqlite_fetch_func(self, conn):
        """Read the data to display from an open ``sqlite3.Connection``.

        The return value is cached by the collector under the ``"metrics"`` key.
        """
        raise NotImplementedError("sqlite_fetch_func must be implemented")
74
+
75
+
76
class CardUpdaterThread(Thread):
    """Background daemon thread that pushes collected data to a Metaflow card.

    When started, it resolves ``current.card[card_refresher.CARD_ID]`` and
    calls ``on_startup`` once. Then, every ``interval`` seconds, it forwards
    the collector's latest snapshot to ``on_update``. When the collector's
    last poll failed, ``on_error`` is called first.
    """

    def __init__(
        self,
        card_refresher: CardRefresher,
        interval=1,
        file_name=None,
        collector_thread: InfoCollectorThread = None,
    ):
        """
        :param card_refresher: hooks used to render the card.
        :param interval: seconds between two card refreshes.
        :param file_name: stored on the instance; not read by this class.
        :param collector_thread: data source. ``stop()`` stops it as well.
        """
        super().__init__()
        self._exit_event = Event()
        self._interval = interval
        self._refresher = card_refresher
        self._file_name = file_name
        self._collector_thread = collector_thread
        self.daemon = True  # do not keep the interpreter alive on exit

    def run(self):
        """Refresh the card until :meth:`stop` is called.

        :raises ValueError: if the refresher has no ``CARD_ID``.
        """
        if self._refresher.CARD_ID is None:
            raise ValueError("CARD_ID must be defined")
        current_card = current.card[self._refresher.CARD_ID]
        self._refresher.on_startup(current_card)
        while not self._exit_event.is_set():
            data = self._collector_thread.read()
            if self._collector_thread.has_errored():
                self._refresher.on_error(
                    current_card, self._collector_thread.get_error()
                )
            # NOTE(review): on_update also runs after on_error, with the last
            # successfully collected snapshot. Confirm this is intended.
            self._refresher.on_update(current_card, data)
            # Event.wait (not time.sleep) lets stop() interrupt the pause.
            self._exit_event.wait(self._interval)

    def stop(self):
        """Stop the refresh loop and the collector, then wait for this thread."""
        self._exit_event.set()
        self._collector_thread.stop()
        # Thread.join raises RuntimeError on a thread that was never started.
        if self.ident is not None:
            self.join()
110
+
111
+
112
class AsyncPeriodicRefresher:
    """Wire an InfoCollectorThread and a CardUpdaterThread together.

    The collector starts polling ``file_name`` as soon as this object is
    constructed. The card updater starts on :meth:`start`. :meth:`stop` shuts
    both down and renders the card one last time with the final data.
    """

    def __init__(
        self,
        card_referesher: CardRefresher,  # (sic) keyword name kept for callers
        updater_interval=1,
        collector_interval=1,
        file_name=None,
    ):
        """
        :param card_referesher: refresher providing ``CARD_ID`` and the hooks.
        :param updater_interval: seconds between card refreshes.
        :param collector_interval: seconds between database polls.
        :param file_name: SQLite database polled by the collector.
        :raises ValueError: if the refresher has no ``CARD_ID``.
        """
        # Explicit check (not ``assert``) so validation survives ``python -O``.
        if card_referesher.CARD_ID is None:
            raise ValueError("CARD_ID must be defined")
        self._refresher = card_referesher
        self._collector_thread = InfoCollectorThread(
            interval=collector_interval,
            file_name=file_name,
            sqlite_fetch_func=card_referesher.sqlite_fetch_func,
        )
        self._collector_thread.start()
        self._updater_thread = CardUpdaterThread(
            card_refresher=card_referesher,
            interval=updater_interval,
            file_name=file_name,
            collector_thread=self._collector_thread,
        )

    def start(self):
        """Start the card updater thread."""
        self._updater_thread.start()

    def stop(self):
        """Stop both threads, then render the card with the final data.

        The threads are stopped *before* the final render. Otherwise the
        updater thread could call ``on_update`` concurrently with this one,
        and the collector could still load newer data that never reaches
        the card.
        """
        if self._updater_thread.ident is not None:
            # Only a started thread can be joined. This also stops the collector.
            self._updater_thread.stop()
        # Stopping an already-finished collector is harmless (join returns).
        self._collector_thread.stop()
        data = self._collector_thread.read()
        current_card = current.card[self._refresher.CARD_ID]
        self._refresher.on_update(current_card, data)