elastic-kernel 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- elastic_kernel/__init__.py +0 -0
- elastic_kernel/command.py +43 -0
- elastic_kernel/kernel.json +5 -0
- elastic_kernel/kernel.py +258 -0
- elastic_kernel-0.0.2.dist-info/METADATA +291 -0
- elastic_kernel-0.0.2.dist-info/RECORD +42 -0
- elastic_kernel-0.0.2.dist-info/WHEEL +5 -0
- elastic_kernel-0.0.2.dist-info/entry_points.txt +2 -0
- elastic_kernel-0.0.2.dist-info/licenses/LICENSE +201 -0
- elastic_kernel-0.0.2.dist-info/top_level.txt +2 -0
- elastic_notebook/__init__.py +0 -0
- elastic_notebook/algorithm/__init__.py +0 -0
- elastic_notebook/algorithm/baseline.py +31 -0
- elastic_notebook/algorithm/optimizer_exact.py +121 -0
- elastic_notebook/algorithm/selector.py +41 -0
- elastic_notebook/core/__init__.py +0 -0
- elastic_notebook/core/common/__init__.py +0 -0
- elastic_notebook/core/common/checkpoint_file.py +129 -0
- elastic_notebook/core/common/profile_graph_size.py +39 -0
- elastic_notebook/core/common/profile_migration_speed.py +69 -0
- elastic_notebook/core/common/profile_variable_size.py +66 -0
- elastic_notebook/core/graph/__init__.py +0 -0
- elastic_notebook/core/graph/cell_execution.py +39 -0
- elastic_notebook/core/graph/graph.py +75 -0
- elastic_notebook/core/graph/variable_snapshot.py +31 -0
- elastic_notebook/core/io/__init__.py +0 -0
- elastic_notebook/core/io/adapter.py +18 -0
- elastic_notebook/core/io/filesystem_adapter.py +30 -0
- elastic_notebook/core/io/migrate.py +98 -0
- elastic_notebook/core/io/pickle.py +71 -0
- elastic_notebook/core/io/recover.py +51 -0
- elastic_notebook/core/mutation/__init__.py +0 -0
- elastic_notebook/core/mutation/fingerprint.py +184 -0
- elastic_notebook/core/mutation/id_graph.py +147 -0
- elastic_notebook/core/mutation/object_hash.py +204 -0
- elastic_notebook/core/notebook/__init__.py +0 -0
- elastic_notebook/core/notebook/checkpoint.py +222 -0
- elastic_notebook/core/notebook/find_input_vars.py +117 -0
- elastic_notebook/core/notebook/find_output_vars.py +18 -0
- elastic_notebook/core/notebook/restore_notebook.py +91 -0
- elastic_notebook/core/notebook/update_graph.py +46 -0
- elastic_notebook/elastic_notebook.py +336 -0
@@ -0,0 +1,147 @@
|
|
1
|
+
from inspect import isclass
|
2
|
+
from types import ModuleType
|
3
|
+
|
4
|
+
# Types treated as leaves of an ID graph: values of these types never get a node of their own.
PRIMITIVES = {type(None), int, float, bool, str}

# Container types whose members are traversed when building an ID graph.
ITERABLES = {tuple, set, list}

# Filter for ``vars(obj).items()`` pairs: keep public attributes (no leading underscore) whose value is not
# a primitive. ``x`` is a ``(name, value)`` pair.
OBJECT_FILTER_FUNC = lambda x: not x[0].startswith("_") and not is_primitive(type(x[1]))

# Filter for container members: keep members whose *type* is not primitive. ``is_primitive`` expects a type,
# so the member's type is passed (passing the member itself never matched and filtered nothing).
ITERABLE_FILTER_FUNC = lambda x: not is_primitive(type(x))
|
8
|
+
|
9
|
+
|
10
|
+
class IdGraphNode:
    """
    One node of an ID graph. An ID graph models the objects (their ids and types) that are reachable
    from a given root object.
    """

    def __init__(self, obj_id, obj_type, child_nodes):
        """
        Args:
            obj_id: id (memory address) of the object this node stands for.
            obj_type: type (or type name) of the object.
            child_nodes: nodes for objects reachable from this object (e.g. instance attributes,
                list/set members).
        """
        self.obj_id, self.obj_type, self.child_nodes = obj_id, obj_type, child_nodes
|
25
|
+
|
26
|
+
|
27
|
+
def is_primitive(obj_type):
    """Return True if ``obj_type`` is one of the PRIMITIVES leaf types."""
    found = obj_type in PRIMITIVES
    return found
|
29
|
+
|
30
|
+
|
31
|
+
def is_root_equals(node1, node2):
    """
    Compare only the root nodes of two ID graphs; child nodes are ignored.

    Two missing roots (both None) are equal and a single missing root is not. Otherwise the roots are
    equal when both their ids and their types match.
    """
    if node1 is None or node2 is None:
        # Equal only when both roots are missing.
        return node1 is node2
    return (node1.obj_id, node1.obj_type) == (node2.obj_id, node2.obj_type)
|
40
|
+
|
41
|
+
|
42
|
+
def is_structure_equals_helper(node1, node2, visited):
    """
    Recursive (depth-first) helper for comparing two ID graphs.

    Args:
        node1: current node of the first graph.
        node2: node of the second graph at the same position.
        visited (set): nodes of the first graph already reached. Used to stop on cycles and on shared
            references. Updated in place.

    Returns:
        bool: True if the graphs reachable from ``node1`` and ``node2`` agree on ids, types and child layout.
    """
    if (
        node1.obj_id != node2.obj_id
        or node1.obj_type != node2.obj_type
        or len(node1.child_nodes) != len(node2.child_nodes)
    ):
        return False
    visited.add(node1)

    for child1, child2 in zip(node1.child_nodes, node2.child_nodes):
        if child1 in visited:
            # child1's subtree is already compared (or being compared, for cycles). However, child2 at this
            # position has not been checked at all, and the second graph may reference a different object
            # here. So at least its id and type must match.
            if child1.obj_id != child2.obj_id or child1.obj_type != child2.obj_type:
                return False
            continue
        if not is_structure_equals_helper(child1, child2, visited):
            return False

    return True
|
63
|
+
|
64
|
+
|
65
|
+
def is_structure_equals(node1, node2):
    """
    Compare two ID graphs for structural equality.

    Returns True when both graphs are None and False when exactly one is None. Otherwise, defers to
    ``is_structure_equals_helper`` with a fresh visited set.
    """
    if node1 is None or node2 is None:
        return node1 is None and node2 is None
    return is_structure_equals_helper(node1, node2, set())
|
74
|
+
|
75
|
+
|
76
|
+
def construct_id_graph_node(obj, visited):
    """
    Build the ID graph node for ``obj`` and, recursively (depth-first), the nodes of the objects reachable
    from it.

    Reachable objects are:
        * members of tuples, sets and lists;
        * public, non-primitive instance attributes (``vars(obj)``) of objects that are neither modules
          nor classes;
        * keys and values of dicts.
    Primitive members, keys and values are skipped.

    Args:
        obj: object to build the node for.
        visited (dict): maps the ``id()`` of every object that already has a node to that node. It is shared
            across the whole traversal so that shared and cyclic references reuse a single node. Updated
            in place.

    Returns:
        IdGraphNode | None: the node for ``obj``, or None if ``obj`` is None.
    """
    if obj is None:
        return None

    # Construct ID graph node.
    obj_type = type(obj)
    child_list = []
    id_graph_node = IdGraphNode(id(obj), obj_type.__name__, child_list)

    # Register before recursing so that references back to obj (cycles) reuse this node.
    visited[id(obj)] = id_graph_node

    def add_child(child_obj):
        # Reuse the existing node for an already-seen object; otherwise build its subtree now.
        if id(child_obj) in visited:
            child_list.append(visited[id(child_obj)])
        else:
            child_list.append(construct_id_graph_node(child_obj, visited))

    # obj is a tuple/set/list: recurse into its non-primitive members.
    # Special note for sets: this works as the order of iterating through sets in the same session is
    # deterministic.
    if obj_type in ITERABLES:
        for child_obj in obj:
            if not is_primitive(type(child_obj)):
                add_child(child_obj)

    # obj is an object instance: recurse into its public, non-primitive attributes.
    elif (
        hasattr(obj, "__dict__")
        and not isinstance(obj, ModuleType)
        and not isclass(obj)
    ):
        for _, attr_value in filter(OBJECT_FILTER_FUNC, vars(obj).items()):
            add_child(attr_value)

    # obj is a dictionary: recurse into both keys and values.
    elif obj_type is dict:
        for key, value in obj.items():
            # Key and value are checked independently: a primitive key (e.g. a str) must not hide a
            # non-primitive value.
            if not is_primitive(type(key)):
                add_child(key)
            if not is_primitive(type(value)):
                add_child(value)

    return id_graph_node
|
136
|
+
|
137
|
+
|
138
|
+
def construct_id_graph(obj):
    """
    Construct the ID graph for an object.

    Returns:
        tuple: ``(graph, ids)``.
            * ``graph`` is the root IdGraphNode, or None if ``obj`` is None.
            * ``ids`` is the set of ``id()`` values of every object that received a node. It is used for
              checking object overlaps.
    """
    id_to_node = {}
    root = construct_id_graph_node(obj, id_to_node)
    return root, set(id_to_node)
|
@@ -0,0 +1,204 @@
|
|
1
|
+
import copy
|
2
|
+
import io
|
3
|
+
from inspect import isclass
|
4
|
+
from types import FunctionType, ModuleType
|
5
|
+
|
6
|
+
import dill
|
7
|
+
|
8
|
+
# import polars as pl
|
9
|
+
import lightgbm
|
10
|
+
import numpy as np
|
11
|
+
import pandas as pd
|
12
|
+
import scipy
|
13
|
+
import torch
|
14
|
+
import xxhash
|
15
|
+
|
16
|
+
# Types whose values construct_object_hash fingerprints as an ImmutableObj.
BASE_TYPES = [
    type(None),
    FunctionType,
]
|
17
|
+
|
18
|
+
|
19
|
+
class ImmutableObj:
    """Marker fingerprint for values treated as immutable; any two instances compare equal."""

    def __init__(self):
        """Carries no state."""

    def __eq__(self, other):
        return isinstance(other, ImmutableObj)
|
27
|
+
|
28
|
+
|
29
|
+
class NoneObj:
    """Marker object representing None; any two instances compare equal."""

    def __init__(self):
        """Carries no state."""

    def __eq__(self, other):
        return isinstance(other, NoneObj)
|
38
|
+
|
39
|
+
|
40
|
+
class DataframeObj:
    """
    Marker fingerprint representing a dataframe (returned for pandas DataFrames and Series).

    Any two instances compare equal, so the marker itself carries no information about the contents.
    """

    def __init__(self):
        """Carries no state."""

    def __eq__(self, other):
        return isinstance(other, DataframeObj)
|
49
|
+
|
50
|
+
|
51
|
+
class NxGraphObj:
    """Snapshot wrapper around a networkx graph, compared by graph equality."""

    def __init__(self, graph):
        self.graph = graph

    def __eq__(self, other):
        if not isinstance(other, NxGraphObj):
            return False
        g1, g2 = self.graph, other.graph
        # Inline equivalent of networkx's graphs_equal (same adjacency, node data and graph attributes).
        # networkx is not imported by this module, so referencing ``nx`` here raised NameError.
        return g1.adj == g2.adj and g1.nodes == g2.nodes and g1.graph == g2.graph
|
59
|
+
|
60
|
+
|
61
|
+
class NpArrayObj:
    """Fingerprint wrapper for a numpy array; two wrappers are equal when their stored digests are equal."""

    def __init__(self, arraystr):
        # Despite the name, this holds whatever digest value the caller supplies.
        self.arraystr = arraystr

    def __eq__(self, other):
        return isinstance(other, NpArrayObj) and self.arraystr == other.arraystr
|
70
|
+
|
71
|
+
|
72
|
+
class ScipyArrayObj:
    """Fingerprint wrapper for a scipy sparse matrix; two wrappers are equal when their digests are equal."""

    def __init__(self, arraystr):
        # Despite the name, this holds whatever digest value the caller supplies.
        self.arraystr = arraystr

    def __eq__(self, other):
        return isinstance(other, ScipyArrayObj) and self.arraystr == other.arraystr
|
81
|
+
|
82
|
+
|
83
|
+
class TorchTensorObj:
    """Fingerprint wrapper for a torch tensor; two wrappers are equal when their stored digests are equal."""

    def __init__(self, arraystr):
        # Despite the name, this holds whatever digest value the caller supplies.
        self.arraystr = arraystr

    def __eq__(self, other):
        return isinstance(other, TorchTensorObj) and self.arraystr == other.arraystr
|
92
|
+
|
93
|
+
|
94
|
+
class ModuleObj:
    """Marker fingerprint for module objects; any two instances compare equal."""

    def __init__(self):
        """Carries no state."""

    def __eq__(self, other):
        return isinstance(other, ModuleObj)
|
102
|
+
|
103
|
+
|
104
|
+
class UnserializableObj:
    """
    Marker fingerprint for objects that could be neither hashed nor deep-copied (a general unserializable
    class). Any two instances compare equal.
    """

    def __init__(self):
        """Carries no state."""

    def __eq__(self, other):
        return isinstance(other, UnserializableObj)
|
113
|
+
|
114
|
+
|
115
|
+
class UncomparableObj:
    """
    Marker fingerprint for objects whose contents are not compared, e.g. file handles or objects from
    matplotlib/transformers/networkx/keras/tensorflow modules. Any two instances compare equal.
    """

    def __init__(self):
        """Carries no state."""

    def __eq__(self, other):
        return isinstance(other, UncomparableObj)
|
123
|
+
|
124
|
+
|
125
|
+
def _xxh3_128_digest(*chunks):
    """Return the xxh3-128 integer digest of ``chunks`` (bytes or buffer-protocol objects), fed in order."""
    h = xxhash.xxh3_128()
    for chunk in chunks:
        h.update(chunk)
    return h.intdigest()


def construct_object_hash(obj, deepcopy=False):
    """
    Construct an object hash (fingerprint) for the object. Uses deep-copy as a fallback.

    Args:
        obj: object to fingerprint.
        deepcopy (bool): for objects that cannot be hashed, return a deep copy of ``obj`` if True, or
            ``obj`` itself if False.

    Returns:
        A value to be compared with ``==`` against a later fingerprint of the same variable. It is one of:
            * a marker object (ImmutableObj, DataframeObj, UncomparableObj, ModuleObj, UnserializableObj);
            * a digest wrapper (NpArrayObj, ScipyArrayObj, TorchTensorObj);
            * a type;
            * an integer xxh3-128 digest;
            * the (possibly deep-copied) object itself.
    """
    if type(obj) in BASE_TYPES:
        return ImmutableObj()

    if isclass(obj):
        return type(obj)

    # Flag hack for Pandas dataframes: each dataframe column is a numpy array.
    # All the writeable flags of these arrays are set to false; if after cell execution, any of these flags are
    # reset to True, we assume that the dataframe has been modified.
    if isinstance(obj, pd.DataFrame):
        for _, col in obj.items():
            col.__array__().flags.writeable = False
        return DataframeObj()

    if isinstance(obj, pd.Series):
        obj.__array__().flags.writeable = False
        return DataframeObj()

    attr_str = getattr(obj, "__module__", None)
    if attr_str and any(
        lib in attr_str
        for lib in ("matplotlib", "transformers", "networkx", "keras", "tensorflow")
    ):
        return UncomparableObj()

    # Object is file handle
    if isinstance(obj, io.IOBase):
        return UncomparableObj()

    # For array-like objects, shape and dtype are hashed alongside the raw values so that in-place reshapes
    # and dtype reinterpretations are also detected as modifications.
    if isinstance(obj, np.ndarray):
        return NpArrayObj(
            _xxh3_128_digest(
                repr((obj.shape, str(obj.dtype))).encode(),
                np.ascontiguousarray(obj.data),
            )
        )

    if isinstance(obj, scipy.sparse.csr_matrix):
        # numpy does not densify sparse matrices: np.ascontiguousarray(obj) only wraps the object reference
        # in an object array. Hash the CSR component arrays instead so that value changes are seen.
        return ScipyArrayObj(
            _xxh3_128_digest(
                repr((obj.shape, str(obj.dtype), str(obj.indices.dtype))).encode(),
                np.ascontiguousarray(obj.indptr),
                np.ascontiguousarray(obj.indices),
                np.ascontiguousarray(obj.data),
            )
        )

    if isinstance(obj, torch.Tensor):
        # detach()/cpu() are needed to convert tensors that require grad or live on an accelerator.
        values = np.ascontiguousarray(obj.detach().cpu().numpy())
        return TorchTensorObj(
            _xxh3_128_digest(
                repr((tuple(obj.shape), str(obj.dtype))).encode(),
                values,
            )
        )

    # Classes were already handled above, so only modules reach this branch.
    if isinstance(obj, ModuleType):
        return ModuleObj()

    # Polars dataframes are immutable.
    # if isinstance(obj, pl.DataFrame):
    #     return type(obj)

    # LightGBM dataframes are immutable.
    if isinstance(obj, lightgbm.Dataset):
        return type(obj)

    # Try to hash the object; if the object is unhashable, use deepcopy as fallback.
    # ``except Exception`` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    try:
        return _xxh3_128_digest(obj)
    except Exception:
        try:
            return copy.deepcopy(obj) if deepcopy else obj
        except Exception:
            # If object is not even deepcopy-able, mark it as unserializable and assume modified-on-write.
            return UnserializableObj()
|
File without changes
|
@@ -0,0 +1,222 @@
|
|
1
|
+
import time
|
2
|
+
from typing import Dict
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from ipykernel.zmqshell import ZMQInteractiveShell
|
6
|
+
|
7
|
+
from elastic_notebook.algorithm.selector import Selector
|
8
|
+
from elastic_notebook.core.common.profile_variable_size import profile_variable_size
|
9
|
+
from elastic_notebook.core.graph.graph import DependencyGraph
|
10
|
+
from elastic_notebook.core.io.migrate import migrate
|
11
|
+
from elastic_notebook.core.mutation.object_hash import UnserializableObj
|
12
|
+
|
13
|
+
|
14
|
+
def _append_to_log(write_log_location, notebook_name, optimizer_name, text):
    """Append ``text`` to ``<write_log_location>/output_<notebook_name>_<optimizer_name>.txt`` (experiments only)."""
    with open(
        write_log_location + "/output_" + notebook_name + "_" + optimizer_name + ".txt",
        "a",
    ) as f:
        f.write(text)


def checkpoint(
    graph: DependencyGraph,
    shell: ZMQInteractiveShell,
    fingerprint_dict: Dict,
    selector: Selector,
    udfs: set,
    filename: str,
    profile_dict,
    write_log_location=None,
    notebook_name=None,
    optimizer_name=None,
    elastic_notebook=None,
):
    """
    Checkpoints the notebook. The optimizer selects the VSs to migrate and recompute and the OEs to recompute,
    then writes the checkpoint as the specified filename.

    Args:
        graph (DependencyGraph): dependency graph representation of the notebook.
        shell (ZMQInteractiveShell): interactive Jupyter shell storing the state of the current session.
        fingerprint_dict (Dict): maps variable name to its fingerprint.
            * Index 1 is intersected between variables to detect shared underlying data (presumably the set
              of reachable object ids — confirm against the fingerprint module).
            * Index 2 is checked for UnserializableObj.
            Variables missing from it are skipped with a warning.
        selector (Selector): optimizer for computing the checkpointing configuration.
        udfs (set): set of user-declared functions.
        filename (str): location to write the file to.
        profile_dict: timings read under the keys "idgraph" and "representation", used only when
            ``write_log_location`` is set.
        write_log_location (str): location to write component runtimes to. For experimentation only.
        notebook_name (str): notebook name. For experimentation only.
        optimizer_name (str): optimizer name. For experimentation only.
        elastic_notebook (ElasticNotebook, optional): ElasticNotebook instance to update migration lists.

    Returns:
        bool: always True; errors raised by ``migrate`` propagate to the caller.
    """
    profile_start = time.time()

    # Retrieve active VSs from the graph. Active VSs correspond to the latest instances/versions of each variable.
    active_vss = set()
    print("---------------------------")
    print("all variables:")
    for vs_list in graph.variable_snapshots.values():
        latest_vs = vs_list[-1]
        if not latest_vs.deleted:
            print(f"name: {latest_vs.name}")
            active_vss.add(latest_vs)

    # Profile the size of each variable defined in the current session.
    for active_vs in active_vss:
        # Skip variables that have no fingerprint.
        if active_vs.name not in fingerprint_dict:
            print(f"Warning: Variable '{active_vs.name}' not found in fingerprint_dict")
            continue

        value = shell.user_ns[active_vs.name]
        attr_str = getattr(value, "__module__", None)
        if isinstance(fingerprint_dict[active_vs.name][2], UnserializableObj):
            # Object is unserializable.
            active_vs.size = np.inf
        elif attr_str and ("dataprep.eda" in attr_str or "bokeh" in attr_str):
            # Blacklisted object.
            active_vs.size = np.inf
        else:
            # Profile size of object.
            active_vs.size = profile_variable_size(value)

    # Check for pairwise variable intersections. Variables sharing underlying data must be migrated or recomputed
    # together. Both orderings (a, b) and (b, a) of an overlapping pair are recorded.
    overlapping_vss = []
    for active_vs1 in active_vss:
        for active_vs2 in active_vss:
            if active_vs1 == active_vs2:
                continue
            # Both variables need a fingerprint to be compared.
            if (
                active_vs1.name not in fingerprint_dict
                or active_vs2.name not in fingerprint_dict
            ):
                continue
            if fingerprint_dict[active_vs1.name][1].intersection(
                fingerprint_dict[active_vs2.name][1]
            ):
                overlapping_vss.append((active_vs1, active_vs2))

    profile_end = time.time()
    if write_log_location:
        _append_to_log(
            write_log_location,
            notebook_name,
            optimizer_name,
            "".join(
                [
                    "overlappings - " + repr(len(overlapping_vss)) + "\n",
                    # Elapsed time is end - start (it was logged as start - end, i.e. negative).
                    "Profile stage took - "
                    + repr(profile_end - profile_start)
                    + " seconds"
                    + "\n",
                    "Idgraph stage took - "
                    + repr(profile_dict["idgraph"])
                    + " seconds"
                    + "\n",
                    "Representation stage took - "
                    + repr(profile_dict["representation"])
                    + " seconds"
                    + "\n",
                ]
            ),
        )

    optimize_start = time.time()
    # Initialize the optimizer.
    add_start = time.time()
    selector.dependency_graph = graph
    selector.active_vss = active_vss
    selector.overlapping_vss = overlapping_vss
    add_end = time.time()

    # Use the optimizer to compute the checkpointing configuration.
    opt_start = time.time()
    vss_to_migrate, ces_to_recompute = selector.select_vss(
        write_log_location, notebook_name, optimizer_name
    )
    opt_end = time.time()
    print("---------------------------")
    print("variables to migrate:")
    for vs in vss_to_migrate:
        print(f"name: {vs.name}, size: {vs.size}")

    difference_start = time.time()
    vss_to_recompute = active_vss - vss_to_migrate
    difference_end = time.time()

    print("---------------------------")
    print("variables to recompute:")
    for vs in vss_to_recompute:
        print(f"name: {vs.name}, size: {vs.size}")
    print([vs.name for vs in vss_to_recompute])

    print("---------------------------")
    print("cells to recompute:")
    for ce in ces_to_recompute:
        print(f"cell num: {ce.cell_num}, cell runtime: {ce.cell_runtime}")
    print(sorted([ce.cell_num + 1 for ce in ces_to_recompute]))

    optimize_end = time.time()
    if write_log_location:
        _append_to_log(
            write_log_location,
            notebook_name,
            optimizer_name,
            "".join(
                [
                    "Optimize stage took - "
                    + repr(optimize_end - optimize_start)
                    + " seconds"
                    + "\n",
                    " Add stage took - " + repr(add_end - add_start) + " seconds" + "\n",
                    " Opt stage took - " + repr(opt_end - opt_start) + " seconds" + "\n",
                    " Diff stage took - "
                    + repr(difference_end - difference_start)
                    + " seconds"
                    + "\n",
                ]
            ),
        )

    # Update the variable lists on the ElasticNotebook instance.
    if elastic_notebook is not None:
        elastic_notebook.update_migration_lists(vss_to_migrate, vss_to_recompute)

    # Store the notebook checkpoint to the specified location.
    migrate_start = time.time()
    migrate_success = True
    migrate(
        graph,
        shell,
        vss_to_migrate,
        vss_to_recompute,
        ces_to_recompute,
        udfs,
        selector.recomputation_ces,
        selector.overlapping_vss,
        filename,
    )
    migrate_end = time.time()

    if write_log_location:
        _append_to_log(
            write_log_location,
            notebook_name,
            optimizer_name,
            "Migrate stage took - "
            + repr(migrate_end - migrate_start)
            + " seconds"
            + "\n",
        )
    return migrate_success
|
@@ -0,0 +1,117 @@
|
|
1
|
+
import ast
|
2
|
+
import inspect
|
3
|
+
from collections import deque
|
4
|
+
from typing import Tuple
|
5
|
+
|
6
|
+
from ipykernel.zmqshell import ZMQInteractiveShell
|
7
|
+
|
8
|
+
# Primitive value types consulted by Visitor when deciding whether a name read inside a function body
# counts as an input variable.
PRIMITIVES = {bool, float, int, str}
|
9
|
+
|
10
|
+
|
11
|
+
class Visitor(ast.NodeVisitor):
    """
    AST node visitor for finding input variables.

    For one piece of source code it collects:
        * the names it loads (candidate inputs);
        * the functions it defines;
        * the functions it calls by plain name;
        * the names it declares ``global``.
    """

    def __init__(self, shell, shell_udfs):
        """
        Args:
            shell: shell whose ``user_ns`` supplies the current values of names.
            shell_udfs (set): names of user-declared functions in the shell.
        """
        # Whether we are currently inside a function body (local scope).
        self.is_local = False

        # Names of functions (sync or async) declared in the visited code.
        self.functiondefs = set()
        # Names of functions called by plain name, e.g. ``f(...)`` but not ``obj.f(...)``.
        self.udfcalls = set()
        # Loaded names recorded as candidate input variables.
        self.loads = set()
        # Names declared ``global`` anywhere in the visited code.
        self.globals = set()
        self.shell = shell
        self.udfs = shell_udfs

    def generic_visit(self, node):
        ast.NodeVisitor.generic_visit(self, node)

    def _is_local_primitive(self, name):
        """
        Return True if ``name`` should not be recorded as an input. That is the case when it is read inside
        a function body, is not declared global, and is bound in the shell to a value of a PRIMITIVES type.
        """
        return (
            self.is_local
            and name not in self.globals
            and name in self.shell.user_ns
            and type(self.shell.user_ns[name]) in PRIMITIVES
        )

    def visit_Name(self, node):
        # Only add as input if variable exists in current scope.
        if isinstance(node.ctx, ast.Load) and not self._is_local_primitive(node.id):
            self.loads.add(node.id)
        self.generic_visit(node)

    def visit_AugAssign(self, node):
        # ``x += ...`` reads x as well as writing it.
        if isinstance(node.target, ast.Name) and not self._is_local_primitive(
            node.target.id
        ):
            self.loads.add(node.target.id)
        self.generic_visit(node)

    def visit_Global(self, node):
        self.globals.update(node.names)
        self.generic_visit(node)

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            self.udfcalls.add(node.func.id)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.functiondefs.add(node.name)
        # Save and restore the flag: leaving a nested function must not make the rest of the enclosing
        # function look like top-level code.
        was_local = self.is_local
        self.is_local = True
        try:
            self.generic_visit(node)
        finally:
            self.is_local = was_local

    # ``async def`` opens a local scope exactly like ``def``.
    visit_AsyncFunctionDef = visit_FunctionDef
|
68
|
+
|
69
|
+
|
70
|
+
def find_input_vars(
    cell: str, existing_variables: set, shell: ZMQInteractiveShell, shell_udfs: set
) -> Tuple[set, set]:
    """
    Capture the input variables of the cell via AST analysis.

    The cell is analyzed first. Every user-declared function it calls by plain name is then analyzed
    transitively, and their loaded names and function definitions are merged in. A called UDF whose name is
    no longer bound in the shell, or whose source cannot be retrieved or parsed, is skipped.

    Args:
        cell (str): Raw cell code.
        existing_variables (set): Set of user-defined variables in the current session.
        shell (ZMQInteractiveShell): Shell of current session. For inferring variable types and fetching UDFs.
        shell_udfs (set): Set of user-declared functions in the shell.

    Returns:
        tuple: ``(input_variables, function_defs)``, both sets of names. ``input_variables`` only contains
        names present in ``existing_variables``.
    """
    import textwrap

    # Parse the cell code and find top-level input variables and function declarations.
    v1 = Visitor(shell=shell, shell_udfs=shell_udfs)
    v1.visit(ast.parse(cell))
    input_variables = set(v1.loads)
    function_defs = set(v1.functiondefs)

    # Recurse into accessed UDFs (breadth-first, each UDF at most once).
    udf_calls = deque()
    udf_calls_visited = set()

    def enqueue(called_names):
        for called in called_names:
            if called not in udf_calls_visited and called in shell_udfs:
                udf_calls.append(called)
                udf_calls_visited.add(called)

    enqueue(v1.udfcalls)

    while udf_calls:
        udf_name = udf_calls.popleft()
        try:
            # dedent: a function declared inside an indented block (e.g. under an ``if``) has indented source,
            # which ast.parse rejects.
            udf_source = textwrap.dedent(inspect.getsource(shell.user_ns[udf_name]))
            udf_tree = ast.parse(udf_source)
        except (KeyError, TypeError, OSError, SyntaxError):
            # Name unbound, not a function, or source unavailable/unparseable: nothing to analyze.
            continue

        # Visit the nested UDF and merge its input variables and function definitions.
        v_nested = Visitor(shell=shell, shell_udfs=shell_udfs)
        v_nested.visit(udf_tree)
        input_variables |= v_nested.loads
        function_defs |= v_nested.functiondefs
        enqueue(v_nested.udfcalls)

    # A variable is an input only if it is in the shell before cell execution.
    return input_variables.intersection(existing_variables), function_defs
|
@@ -0,0 +1,18 @@
|
|
1
|
+
def find_created_deleted_vars(pre_execution, post_execution):
    """
    Find created and deleted variables by diffing the user namespace before and after execution.

    Names starting with an underscore are ignored.

    Args:
        pre_execution (set): variable names before execution.
        post_execution (set): variable names after execution.

    Returns:
        tuple: ``(created_variables, deleted_variables)`` as sets of names.
    """

    def public_names(names):
        return {name for name in names if not name.startswith("_")}

    created_variables = public_names(post_execution.difference(pre_execution))
    deleted_variables = public_names(pre_execution.difference(post_execution))
    return created_variables, deleted_variables
|