elastic-kernel 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. elastic_kernel/__init__.py +0 -0
  2. elastic_kernel/command.py +43 -0
  3. elastic_kernel/kernel.json +5 -0
  4. elastic_kernel/kernel.py +258 -0
  5. elastic_kernel-0.0.2.dist-info/METADATA +291 -0
  6. elastic_kernel-0.0.2.dist-info/RECORD +42 -0
  7. elastic_kernel-0.0.2.dist-info/WHEEL +5 -0
  8. elastic_kernel-0.0.2.dist-info/entry_points.txt +2 -0
  9. elastic_kernel-0.0.2.dist-info/licenses/LICENSE +201 -0
  10. elastic_kernel-0.0.2.dist-info/top_level.txt +2 -0
  11. elastic_notebook/__init__.py +0 -0
  12. elastic_notebook/algorithm/__init__.py +0 -0
  13. elastic_notebook/algorithm/baseline.py +31 -0
  14. elastic_notebook/algorithm/optimizer_exact.py +121 -0
  15. elastic_notebook/algorithm/selector.py +41 -0
  16. elastic_notebook/core/__init__.py +0 -0
  17. elastic_notebook/core/common/__init__.py +0 -0
  18. elastic_notebook/core/common/checkpoint_file.py +129 -0
  19. elastic_notebook/core/common/profile_graph_size.py +39 -0
  20. elastic_notebook/core/common/profile_migration_speed.py +69 -0
  21. elastic_notebook/core/common/profile_variable_size.py +66 -0
  22. elastic_notebook/core/graph/__init__.py +0 -0
  23. elastic_notebook/core/graph/cell_execution.py +39 -0
  24. elastic_notebook/core/graph/graph.py +75 -0
  25. elastic_notebook/core/graph/variable_snapshot.py +31 -0
  26. elastic_notebook/core/io/__init__.py +0 -0
  27. elastic_notebook/core/io/adapter.py +18 -0
  28. elastic_notebook/core/io/filesystem_adapter.py +30 -0
  29. elastic_notebook/core/io/migrate.py +98 -0
  30. elastic_notebook/core/io/pickle.py +71 -0
  31. elastic_notebook/core/io/recover.py +51 -0
  32. elastic_notebook/core/mutation/__init__.py +0 -0
  33. elastic_notebook/core/mutation/fingerprint.py +184 -0
  34. elastic_notebook/core/mutation/id_graph.py +147 -0
  35. elastic_notebook/core/mutation/object_hash.py +204 -0
  36. elastic_notebook/core/notebook/__init__.py +0 -0
  37. elastic_notebook/core/notebook/checkpoint.py +222 -0
  38. elastic_notebook/core/notebook/find_input_vars.py +117 -0
  39. elastic_notebook/core/notebook/find_output_vars.py +18 -0
  40. elastic_notebook/core/notebook/restore_notebook.py +91 -0
  41. elastic_notebook/core/notebook/update_graph.py +46 -0
  42. elastic_notebook/elastic_notebook.py +336 -0
@@ -0,0 +1,147 @@
1
+ from inspect import isclass
2
+ from types import ModuleType
3
+
4
# Types whose values are leaves of an ID graph: they are never added as child nodes.
PRIMITIVES = {type(None), int, float, bool, str}
# Container types whose members are traversed when building an ID graph.
ITERABLES = {tuple, set, list}
# Filter for (attribute_name, value) pairs: keep public attributes whose value is not primitive.
# `type(v) not in PRIMITIVES` is exactly what is_primitive(type(v)) computes; it is inlined so these
# filters do not depend on a function defined later in the module.
OBJECT_FILTER_FUNC = lambda x: not x[0].startswith("_") and type(x[1]) not in PRIMITIVES
# Filter for container members: keep members whose value is not primitive.
# The primitive check must be done on the member's *type*; the previous `is_primitive(x)` tested the
# value itself against a set of types, so primitive members (ints, strs, ...) were not filtered out.
ITERABLE_FILTER_FUNC = lambda x: type(x) not in PRIMITIVES
8
+
9
+
10
class IdGraphNode:
    """
    One node of an ID graph: the identity and type of a single object, plus the nodes of the
    objects that can be reached directly from it.
    """

    def __init__(self, obj_id, obj_type, child_nodes):
        """
        Args:
            obj_id: id() (memory address) of the represented object.
            obj_type: type of the object.
            child_nodes: nodes for objects directly reachable from this object
                (e.g. instance attributes, list/set members).
        """
        self.obj_id, self.obj_type, self.child_nodes = obj_id, obj_type, child_nodes
25
+
26
+
27
def is_primitive(obj_type):
    """Return True when obj_type is one of the PRIMITIVES types (ID-graph leaf types)."""
    found = obj_type in PRIMITIVES
    return found
29
+
30
+
31
def is_root_equals(node1, node2):
    """
    Compare only the root nodes of two ID graphs.

    Two missing roots (both None) are equal and exactly one missing root is unequal. Otherwise the
    roots are equal when they record the same object id and the same type; children are ignored.
    """
    if node1 is None or node2 is None:
        return node1 is None and node2 is None
    return (node1.obj_id, node1.obj_type) == (node2.obj_id, node2.obj_type)
40
+
41
+
42
def is_structure_equals_helper(node1, node2, visited):
    """
    Recursively (depth-first) compare the ID graphs rooted at node1 and node2.

    Args:
        node1: node of the first graph.
        node2: node of the second graph at the same position.
        visited (set): nodes of the first graph that have already been compared; updated in place.
            Prevents infinite recursion on cyclic graphs.

    Returns:
        True if, at every position, both graphs hold the same object id and type and every node has
        the same number of children; False otherwise.
    """
    if (
        node1.obj_id != node2.obj_id
        or node1.obj_type != node2.obj_type
        or len(node1.child_nodes) != len(node2.child_nodes)
    ):
        return False
    visited.add(node1)

    for child1, child2 in zip(node1.child_nodes, node2.child_nodes):
        if child1 in visited:
            # child1's subgraph was already compared (shared or cyclic reference), so don't recurse
            # again -- but node2 must still reference the same object at this position. Skipping
            # this check made e.g. [a, a] compare equal to [a, b].
            if child1.obj_id != child2.obj_id or child1.obj_type != child2.obj_type:
                return False
            continue
        if not is_structure_equals_helper(child1, child2, visited):
            return False

    return True
63
+
64
+
65
def is_structure_equals(node1, node2):
    """
    Compare two ID graphs for equality (object ids, types and child structure).

    Both graphs missing (None) counts as equal; only one missing counts as unequal.
    """
    if node1 is None or node2 is None:
        # At least one is None: equal only when both are.
        return node1 is node2
    return is_structure_equals_helper(node1, node2, set())
74
+
75
+
76
def construct_id_graph_node(obj, visited):
    """
    Build the ID graph node for obj and, recursively (depth-first), for every object reachable from it.

    Reachable objects are the members of tuple/set/list containers, the public non-primitive
    attributes of object instances (modules and classes are not expanded), and the keys and values
    of plain dicts. Values whose type is in PRIMITIVES are leaves and are not added as children.

    Args:
        obj: object to model.
        visited (dict): maps id() of every object already modelled to its IdGraphNode. Updated in
            place; reusing existing nodes handles shared and cyclic references.

    Returns:
        The IdGraphNode for obj, or None if obj is None.
    """
    if obj is None:
        return None

    T = type(obj)
    child_list = []
    id_graph_node = IdGraphNode(id(obj), T.__name__, child_list)

    # Register before recursing so that references cycling back to obj reuse this node.
    visited[id(obj)] = id_graph_node

    def add_child(child):
        # Reuse the node of an already-modelled object instead of building a duplicate.
        existing = visited.get(id(child))
        if existing is None:
            existing = construct_id_graph_node(child, visited)
        child_list.append(existing)

    if T in ITERABLES:
        # Special note for sets: this works as the order of iterating through sets in the same
        # session is deterministic.
        for child_obj in obj:
            if not is_primitive(type(child_obj)):
                add_child(child_obj)

    # Object instance: recurse into its public, non-primitive attributes.
    elif (
        hasattr(obj, "__dict__")
        and not isinstance(obj, ModuleType)
        and not isclass(obj)
    ):
        for _, v in filter(OBJECT_FILTER_FUNC, vars(obj).items()):
            add_child(v)

    # Dictionary: recurse into both keys and values.
    elif T is dict:
        for k, v in obj.items():
            # Keys and values are filtered independently: a primitive key (e.g. any str) must not
            # cause its value to be skipped, otherwise objects held in str-keyed dicts are never
            # tracked.
            if not is_primitive(type(k)):
                add_child(k)
            if not is_primitive(type(v)):
                add_child(v)

    return id_graph_node
136
+
137
+
138
def construct_id_graph(obj):
    """
    Construct the ID graph for an object.

    Args:
        obj: root object.

    Returns:
        A tuple (graph, ids):
            graph (IdGraphNode or None): root node of the in-memory ID graph; None if obj is None.
            ids (set): ids of all objects added to the graph (obj and everything reachable from it
                that is not primitive). Used for checking object overlaps between variables.
    """
    visited = {}
    graph = construct_id_graph_node(obj, visited)
    return graph, set(visited)
@@ -0,0 +1,204 @@
1
+ import copy
2
+ import io
3
+ from inspect import isclass
4
+ from types import FunctionType, ModuleType
5
+
6
+ import dill
7
+
8
+ # import polars as pl
9
+ import lightgbm
10
+ import numpy as np
11
+ import pandas as pd
12
+ import scipy
13
+ import torch
14
+ import xxhash
15
+
16
# Types whose instances construct_object_hash maps to ImmutableObj (never considered modified).
BASE_TYPES = [
    type(None),
    FunctionType,
]
17
+
18
+
19
class ImmutableObj:
    """Marker hash for values of BASE_TYPES; any two markers compare equal."""

    def __eq__(self, other):
        return isinstance(other, ImmutableObj)
27
+
28
+
29
class NoneObj:
    """Marker object representing None; any two markers compare equal."""

    def __eq__(self, other):
        return isinstance(other, NoneObj)
38
+
39
+
40
class DataframeObj:
    """
    Marker object representing a pandas DataFrame/Series; any two markers compare equal.

    Content changes are not detected through this marker. construct_object_hash instead clears the
    columns' writeable flags so that later modifications can be noticed.
    """

    def __eq__(self, other):
        return isinstance(other, DataframeObj)
49
+
50
+
51
class NxGraphObj:
    """Snapshot wrapper for a networkx graph, compared by structure and attributes."""

    def __init__(self, graph):
        self.graph = graph

    def __eq__(self, other):
        if not isinstance(other, NxGraphObj):
            return False
        # Same comparison as networkx.utils.graphs_equal: adjacency (including edge data), node
        # data and graph attributes. Done inline because `nx` is not imported in this module, so
        # the former `nx.graphs_equal(...)` call raised NameError.
        mine, theirs = self.graph, other.graph
        return (
            mine.adj == theirs.adj
            and mine.nodes == theirs.nodes
            and mine.graph == theirs.graph
        )
59
+
60
+
61
class NpArrayObj:
    """Hash of a numpy array's contents; two instances are equal when their digests are equal."""

    def __init__(self, arraystr):
        # Digest of the array contents (construct_object_hash passes an xxh3-128 int digest).
        self.arraystr = arraystr

    def __eq__(self, other):
        return isinstance(other, NpArrayObj) and self.arraystr == other.arraystr
70
+
71
+
72
class ScipyArrayObj:
    """Hash of a scipy sparse matrix; two instances are equal when their digests are equal."""

    def __init__(self, arraystr):
        # Digest of the matrix (construct_object_hash passes an xxh3-128 int digest).
        self.arraystr = arraystr

    def __eq__(self, other):
        return isinstance(other, ScipyArrayObj) and self.arraystr == other.arraystr
81
+
82
+
83
class TorchTensorObj:
    """Hash of a torch tensor's contents; two instances are equal when their digests are equal."""

    def __init__(self, arraystr):
        # Digest of the tensor contents (construct_object_hash passes an xxh3-128 int digest).
        self.arraystr = arraystr

    def __eq__(self, other):
        return isinstance(other, TorchTensorObj) and self.arraystr == other.arraystr
92
+
93
+
94
class ModuleObj:
    """Marker object representing a module; any two markers compare equal."""

    def __eq__(self, other):
        return isinstance(other, ModuleObj)
102
+
103
+
104
class UnserializableObj:
    """
    Marker object for values that could not be hashed or deep-copied; any two markers compare equal.
    """

    def __eq__(self, other):
        return isinstance(other, UnserializableObj)
113
+
114
+
115
class UncomparableObj:
    """
    Marker object for values whose contents are not compared (e.g. file handles, objects from some
    ML/plotting libraries); any two markers compare equal.
    """

    def __eq__(self, other):
        return isinstance(other, UncomparableObj)
123
+
124
+
125
def _xxh3_128_digest(*buffers):
    """Return the xxh3-128 integer digest of the given objects, fed to the hasher in order."""
    h = xxhash.xxh3_128()
    for buf in buffers:
        h.update(buf)
    return h.intdigest()


def construct_object_hash(obj, deepcopy=False):
    """
    Construct an object "hash" for obj. Uses deep-copy as a fallback.

    The result is meant to be compared (==) with the result for the same object at a later time
    (e.g. after a cell executes) to decide whether the object was modified. Depending on obj it is:
      - ImmutableObj for values of BASE_TYPES,
      - type(obj) for classes and LightGBM datasets,
      - DataframeObj for pandas DataFrames/Series (after clearing their writeable flags),
      - UncomparableObj for file handles and objects from matplotlib/transformers/networkx/keras/
        tensorflow modules,
      - NpArrayObj / ScipyArrayObj / TorchTensorObj wrapping a digest of the contents,
      - ModuleObj for modules,
      - an xxh3-128 int digest for objects xxhash can consume directly,
      - otherwise obj itself (or a deep copy when deepcopy=True), or UnserializableObj when the
        deep copy fails.

    Args:
        obj: object to hash.
        deepcopy (bool): whether the final fallback returns a deep copy of obj instead of obj.
    """
    if type(obj) in BASE_TYPES:
        return ImmutableObj()

    # For a class, type(obj) is its metaclass (normally `type`), so classes compare equal.
    if isclass(obj):
        return type(obj)

    # Flag hack for Pandas dataframes: each dataframe column is a numpy array.
    # All the writeable flags of these arrays are set to false; if after cell execution, any of these flags are
    # reset to True, we assume that the dataframe has been modified.
    if isinstance(obj, pd.DataFrame):
        for _, col in obj.items():
            col.__array__().flags.writeable = False
        return DataframeObj()

    if isinstance(obj, pd.Series):
        obj.__array__().flags.writeable = False
        return DataframeObj()

    attr_str = getattr(obj, "__module__", None)
    if attr_str and any(
        lib in attr_str
        for lib in ("matplotlib", "transformers", "networkx", "keras", "tensorflow")
    ):
        return UncomparableObj()

    # Object is file handle
    if isinstance(obj, io.IOBase):
        return UncomparableObj()

    if isinstance(obj, np.ndarray):
        # Shape and dtype are hashed too, so in-place reshapes / dtype reinterpretations count as
        # modifications even when the raw bytes are unchanged.
        return NpArrayObj(
            _xxh3_128_digest(
                repr((obj.shape, obj.dtype.str)).encode(),
                np.ascontiguousarray(obj.data),
            )
        )

    if isinstance(obj, scipy.sparse.csr_matrix):
        # np.ascontiguousarray() of a sparse matrix does not yield its values (it wraps the matrix
        # object in a 0-d object array), so hash the CSR component arrays explicitly.
        return ScipyArrayObj(
            _xxh3_128_digest(
                repr((obj.shape, obj.dtype.str)).encode(),
                np.ascontiguousarray(obj.data),
                np.ascontiguousarray(obj.indices),
                np.ascontiguousarray(obj.indptr),
            )
        )

    if isinstance(obj, torch.Tensor):
        # Converting to numpy fails for tensors that require grad or live on an accelerator;
        # detach and move to CPU first.
        return TorchTensorObj(
            _xxh3_128_digest(np.ascontiguousarray(obj.detach().cpu().numpy()))
        )

    if isinstance(obj, ModuleType):
        return ModuleObj()

    # Polars dataframes are immutable as well; polars support is currently disabled (not imported).

    # LightGBM datasets are immutable.
    if isinstance(obj, lightgbm.Dataset):
        return type(obj)

    # Try to hash the object directly (works for buffer-protocol objects).
    try:
        return _xxh3_128_digest(obj)
    except Exception:
        # xxhash cannot consume this object; fall back to keeping the object / a copy of it.
        pass

    if not deepcopy:
        return obj
    try:
        return copy.deepcopy(obj)
    except Exception:
        # If object is not even deepcopy-able, mark it as unserializable and assume modified-on-write.
        return UnserializableObj()
File without changes
@@ -0,0 +1,222 @@
1
+ import time
2
+ from typing import Dict
3
+
4
+ import numpy as np
5
+ from ipykernel.zmqshell import ZMQInteractiveShell
6
+
7
+ from elastic_notebook.algorithm.selector import Selector
8
+ from elastic_notebook.core.common.profile_variable_size import profile_variable_size
9
+ from elastic_notebook.core.graph.graph import DependencyGraph
10
+ from elastic_notebook.core.io.migrate import migrate
11
+ from elastic_notebook.core.mutation.object_hash import UnserializableObj
12
+
13
+
14
def _log_file_path(write_log_location, notebook_name, optimizer_name):
    """Return the experiment log file path for this notebook/optimizer combination."""
    return (
        write_log_location + "/output_" + notebook_name + "_" + optimizer_name + ".txt"
    )


def _append_log_lines(write_log_location, notebook_name, optimizer_name, lines):
    """Append the given (already newline-terminated) lines to the experiment log file."""
    with open(
        _log_file_path(write_log_location, notebook_name, optimizer_name), "a"
    ) as f:
        f.writelines(lines)


def checkpoint(
    graph: DependencyGraph,
    shell: ZMQInteractiveShell,
    fingerprint_dict: Dict,
    selector: Selector,
    udfs: set,
    filename: str,
    profile_dict,
    write_log_location=None,
    notebook_name=None,
    optimizer_name=None,
    elastic_notebook=None,
):
    """
    Checkpoints the notebook. The optimizer selects the VSs to migrate and recompute and the OEs to recompute, then
    writes the checkpoint as the specified filename.

    Args:
        graph (DependencyGraph): dependency graph representation of the notebook.
        shell (ZMQInteractiveShell): interactive Jupyter shell storing the state of the current session.
        fingerprint_dict (Dict): per-variable fingerprints keyed by variable name. Element [1] must
            support set intersection (presumably the ids reachable from the variable, used for
            overlap detection) and element [2] is checked for UnserializableObj.
        selector (Selector): optimizer for computing the checkpointing configuration.
        udfs (set): set of user-declared functions.
        filename (str): location to write the file to.
        profile_dict: timings (seconds) with keys "idgraph" and "representation"; only written to
            the experiment log.
        write_log_location (str): location to write component runtimes to. For experimentation only.
        notebook_name (str): notebook name. For experimentation only.
        optimizer_name (str): optimizer name. For experimentation only.
        elastic_notebook (ElasticNotebook, optional): ElasticNotebook instance to update migration lists.

    Returns:
        bool: always True; exceptions raised by migrate() propagate to the caller.
    """
    profile_start = time.time()

    # Retrieve active VSs from the graph. Active VSs correspond to the latest, non-deleted
    # instances/versions of each variable.
    active_vss = set()
    print("---------------------------")
    print("all variables:")
    for vs_list in graph.variable_snapshots.values():
        if not vs_list[-1].deleted:
            print(f"name: {vs_list[-1].name}")
            active_vss.add(vs_list[-1])

    # Profile the size of each variable defined in the current session.
    for active_vs in active_vss:
        # Skip variables that have no fingerprint; their size is not set here.
        if active_vs.name not in fingerprint_dict:
            print(f"Warning: Variable '{active_vs.name}' not found in fingerprint_dict")
            continue

        attr_str = getattr(shell.user_ns[active_vs.name], "__module__", None)
        # Object is unserializable
        if isinstance(fingerprint_dict[active_vs.name][2], UnserializableObj):
            active_vs.size = np.inf
        # Blacklisted object
        elif attr_str and ("dataprep.eda" in attr_str or "bokeh" in attr_str):
            active_vs.size = np.inf
        # Profile size of object.
        else:
            active_vs.size = profile_variable_size(shell.user_ns[active_vs.name])

    # Check for pairwise variable intersections. Variables sharing underlying data must be migrated or recomputed
    # together.
    overlapping_vss = []
    for active_vs1 in active_vss:
        for active_vs2 in active_vss:
            if active_vs1 != active_vs2:
                # Both variables must have fingerprints to be compared.
                if (
                    active_vs1.name not in fingerprint_dict
                    or active_vs2.name not in fingerprint_dict
                ):
                    continue
                if fingerprint_dict[active_vs1.name][1].intersection(
                    fingerprint_dict[active_vs2.name][1]
                ):
                    overlapping_vss.append((active_vs1, active_vs2))

    profile_end = time.time()
    if write_log_location:
        _append_log_lines(
            write_log_location,
            notebook_name,
            optimizer_name,
            [
                "overlappings - " + repr(len(overlapping_vss)) + "\n",
                # Elapsed time is end - start (previously logged as start - end, i.e. negative).
                "Profile stage took - "
                + repr(profile_end - profile_start)
                + " seconds"
                + "\n",
                "Idgraph stage took - "
                + repr(profile_dict["idgraph"])
                + " seconds"
                + "\n",
                "Representation stage took - "
                + repr(profile_dict["representation"])
                + " seconds"
                + "\n",
            ],
        )

    optimize_start = time.time()
    # Initialize the optimizer.
    add_start = time.time()
    selector.dependency_graph = graph
    selector.active_vss = active_vss
    selector.overlapping_vss = overlapping_vss
    add_end = time.time()

    # Use the optimizer to compute the checkpointing configuration.
    opt_start = time.time()
    vss_to_migrate, ces_to_recompute = selector.select_vss(
        write_log_location, notebook_name, optimizer_name
    )
    opt_end = time.time()
    print("---------------------------")
    print("variables to migrate:")
    for vs in vss_to_migrate:
        print(f"name: {vs.name}, size: {vs.size}")

    difference_start = time.time()
    vss_to_recompute = active_vss - vss_to_migrate
    difference_end = time.time()

    print("---------------------------")
    print("variables to recompute:")
    for vs in vss_to_recompute:
        print(f"name: {vs.name}, size: {vs.size}")
    print([vs.name for vs in vss_to_recompute])

    print("---------------------------")
    print("cells to recompute:")
    for ce in ces_to_recompute:
        print(f"cell num: {ce.cell_num}, cell runtime: {ce.cell_runtime}")
    # cell_num + 1: presumably to display 1-based cell numbers -- confirm.
    print(sorted([ce.cell_num + 1 for ce in ces_to_recompute]))

    optimize_end = time.time()
    if write_log_location:
        _append_log_lines(
            write_log_location,
            notebook_name,
            optimizer_name,
            [
                "Optimize stage took - "
                + repr(optimize_end - optimize_start)
                + " seconds"
                + "\n",
                " Add stage took - " + repr(add_end - add_start) + " seconds" + "\n",
                " Opt stage took - " + repr(opt_end - opt_start) + " seconds" + "\n",
                " Diff stage took - "
                + repr(difference_end - difference_start)
                + " seconds"
                + "\n",
            ],
        )

    # Update the migration/recomputation variable lists on the ElasticNotebook instance, if given.
    if elastic_notebook is not None:
        elastic_notebook.update_migration_lists(vss_to_migrate, vss_to_recompute)

    # Store the notebook checkpoint to the specified location.
    migrate_start = time.time()
    migrate_success = True
    migrate(
        graph,
        shell,
        vss_to_migrate,
        vss_to_recompute,
        ces_to_recompute,
        udfs,
        selector.recomputation_ces,
        selector.overlapping_vss,
        filename,
    )
    migrate_end = time.time()

    if write_log_location:
        _append_log_lines(
            write_log_location,
            notebook_name,
            optimizer_name,
            [
                "Migrate stage took - "
                + repr(migrate_end - migrate_start)
                + " seconds"
                + "\n"
            ],
        )
    return migrate_success
@@ -0,0 +1,117 @@
1
+ import ast
2
+ import inspect
3
+ from collections import deque
4
+ from typing import Tuple
5
+
6
+ from ipykernel.zmqshell import ZMQInteractiveShell
7
+
8
# Value types treated as primitives when deciding whether a name read inside a function body is an
# input variable (see Visitor).
PRIMITIVES = {int, bool, str, float}


# Node visitor for finding input variables.
class Visitor(ast.NodeVisitor):
    """
    AST visitor that collects, for one piece of code:
      - loads: names that are read (candidate input variables),
      - functiondefs: names of functions defined (sync or async, including nested ones),
      - udfcalls: names of functions called by bare name, e.g. `f(...)`,
      - globals: names declared with a `global` statement.

    Inside a function body, a read name is NOT recorded when it is not declared global and the shell
    currently holds a value of a PRIMITIVES type under that name.
    """

    def __init__(self, shell, shell_udfs):
        # Whether we are currently inside a function body (local scope).
        self.is_local = False
        # Names of functions defined in the visited code.
        self.functiondefs = set()
        # Names of functions called by bare name (attribute calls like `obj.f()` are not recorded).
        self.udfcalls = set()
        # Names read (Load context) or augmented-assigned.
        self.loads = set()
        # Names declared `global` anywhere in the visited code.
        self.globals = set()
        self.shell = shell
        self.udfs = shell_udfs

    def _is_local_primitive(self, name):
        """True if `name`, read at the current position, should not be counted as an input."""
        return (
            self.is_local
            and name not in self.globals
            and name in self.shell.user_ns
            and type(self.shell.user_ns[name]) in PRIMITIVES
        )

    def visit_Name(self, node):
        # Only add as input if variable exists in current scope.
        if isinstance(node.ctx, ast.Load) and not self._is_local_primitive(node.id):
            self.loads.add(node.id)
        self.generic_visit(node)

    def visit_AugAssign(self, node):
        # `x += ...` reads x before writing it, so x is an input as well.
        if isinstance(node.target, ast.Name) and not self._is_local_primitive(
            node.target.id
        ):
            self.loads.add(node.target.id)
        self.generic_visit(node)

    def visit_Global(self, node):
        self.globals.update(node.names)
        self.generic_visit(node)

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            self.udfcalls.add(node.func.id)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.functiondefs.add(node.name)
        # Save and restore the scope flag: leaving a nested function must not make the rest of the
        # enclosing function body look like global scope.
        was_local = self.is_local
        self.is_local = True
        self.generic_visit(node)
        self.is_local = was_local

    # `async def` bodies are function scopes too.
    visit_AsyncFunctionDef = visit_FunctionDef
68
+
69
+
70
def find_input_vars(
    cell: str, existing_variables: set, shell: ZMQInteractiveShell, shell_udfs: set
) -> Tuple[set, set]:
    """
    Capture the input variables of the cell via AST analysis.

    Names read by the cell are collected. The analysis then recurses, breadth-first, into the source
    of every user-declared function the cell calls, directly or transitively.

    Args:
        cell (str): raw cell code.
        existing_variables (set): set of user-defined variables in the current session before the
            cell executes.
        shell (ZMQInteractiveShell): shell of current session; used to look up variable values and
            UDF objects.
        shell_udfs (set): set of user-declared function names in the shell.

    Returns:
        (input_variables, function_defs): the cell's input variables, restricted to
        existing_variables, and the names of all functions defined in the cell or in visited UDFs.
    """
    import textwrap

    # Find top-level input variables and function declarations.
    v1 = Visitor(shell=shell, shell_udfs=shell_udfs)
    v1.visit(ast.parse(cell))
    input_variables = set(v1.loads)
    function_defs = set(v1.functiondefs)

    # Recurse into accessed UDFs, visiting each at most once.
    udf_calls = deque()
    udf_calls_visited = set()

    def enqueue(called_names):
        # Calls to names that are not UDFs (builtins, imported functions) are ignored.
        for name in called_names:
            if name not in udf_calls_visited and name in shell_udfs:
                udf_calls.append(name)
                udf_calls_visited.add(name)

    enqueue(v1.udfcalls)
    while udf_calls:
        udf = udf_calls.popleft()
        v_nested = Visitor(shell=shell, shell_udfs=shell_udfs)
        # getsource() keeps the definition's original indentation; dedent so that functions defined
        # inside an indented block (e.g. under `if`) still parse.
        source = textwrap.dedent(inspect.getsource(shell.user_ns[udf]))
        v_nested.visit(ast.parse(source))

        input_variables |= v_nested.loads
        function_defs |= v_nested.functiondefs
        enqueue(v_nested.udfcalls)

    # A variable is an input only if it is in the shell before cell execution.
    return input_variables & existing_variables, function_defs
@@ -0,0 +1,18 @@
1
def find_created_deleted_vars(pre_execution, post_execution):
    """
    Find the variables created and deleted by an execution, by diffing the user namespace.

    Args:
        pre_execution (set): variable names in the user namespace before execution.
        post_execution (set): variable names in the user namespace after execution.

    Returns:
        (created, deleted): names present only after / only before execution. Names starting with
        "_" are ignored.
    """

    def without_underscored(names):
        return {name for name in names if not name.startswith("_")}

    created = without_underscored(post_execution.difference(pre_execution))
    deleted = without_underscored(pre_execution.difference(post_execution))
    return created, deleted