elastic-kernel 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- elastic_kernel/__init__.py +0 -0
- elastic_kernel/command.py +43 -0
- elastic_kernel/kernel.json +5 -0
- elastic_kernel/kernel.py +258 -0
- elastic_kernel-0.0.2.dist-info/METADATA +291 -0
- elastic_kernel-0.0.2.dist-info/RECORD +42 -0
- elastic_kernel-0.0.2.dist-info/WHEEL +5 -0
- elastic_kernel-0.0.2.dist-info/entry_points.txt +2 -0
- elastic_kernel-0.0.2.dist-info/licenses/LICENSE +201 -0
- elastic_kernel-0.0.2.dist-info/top_level.txt +2 -0
- elastic_notebook/__init__.py +0 -0
- elastic_notebook/algorithm/__init__.py +0 -0
- elastic_notebook/algorithm/baseline.py +31 -0
- elastic_notebook/algorithm/optimizer_exact.py +121 -0
- elastic_notebook/algorithm/selector.py +41 -0
- elastic_notebook/core/__init__.py +0 -0
- elastic_notebook/core/common/__init__.py +0 -0
- elastic_notebook/core/common/checkpoint_file.py +129 -0
- elastic_notebook/core/common/profile_graph_size.py +39 -0
- elastic_notebook/core/common/profile_migration_speed.py +69 -0
- elastic_notebook/core/common/profile_variable_size.py +66 -0
- elastic_notebook/core/graph/__init__.py +0 -0
- elastic_notebook/core/graph/cell_execution.py +39 -0
- elastic_notebook/core/graph/graph.py +75 -0
- elastic_notebook/core/graph/variable_snapshot.py +31 -0
- elastic_notebook/core/io/__init__.py +0 -0
- elastic_notebook/core/io/adapter.py +18 -0
- elastic_notebook/core/io/filesystem_adapter.py +30 -0
- elastic_notebook/core/io/migrate.py +98 -0
- elastic_notebook/core/io/pickle.py +71 -0
- elastic_notebook/core/io/recover.py +51 -0
- elastic_notebook/core/mutation/__init__.py +0 -0
- elastic_notebook/core/mutation/fingerprint.py +184 -0
- elastic_notebook/core/mutation/id_graph.py +147 -0
- elastic_notebook/core/mutation/object_hash.py +204 -0
- elastic_notebook/core/notebook/__init__.py +0 -0
- elastic_notebook/core/notebook/checkpoint.py +222 -0
- elastic_notebook/core/notebook/find_input_vars.py +117 -0
- elastic_notebook/core/notebook/find_output_vars.py +18 -0
- elastic_notebook/core/notebook/restore_notebook.py +91 -0
- elastic_notebook/core/notebook/update_graph.py +46 -0
- elastic_notebook/elastic_notebook.py +336 -0
@@ -0,0 +1,66 @@
|
|
1
|
+
import sys
|
2
|
+
import types
|
3
|
+
|
4
|
+
|
5
|
+
def get_total_size(data):
    """
    Compute the estimated total size of a variable.

    The estimate is the sum of ``sys.getsizeof`` over ``data`` and every object
    reachable from it through lists, tuples, sets, dicts and the ``__dict__`` of
    instances whose class does not provide its own ``__sizeof__``. Each object is
    counted at most once (keyed by ``id``), so shared references are not double
    counted.

    Args:
        data: The variable to measure.

    Returns:
        The estimated size. An object whose ``sys.getsizeof`` call raises
        contributes ``float("inf")``, which makes the whole estimate infinite.
    """
    primitive_types = (int, float, str, bool, type(None))
    sequence_types = (list, tuple, set)
    # Functions, methods, builtins and modules are measured shallowly only.
    shallow_types = (
        types.FunctionType,
        types.MethodType,
        types.BuiltinFunctionType,
        types.ModuleType,
    )

    total_size = 0
    # same memory space should be calculated only once
    visited = set()
    # An explicit stack replaces recursion so deeply nested data cannot exhaust
    # the interpreter's recursion limit. Each entry is (object, is_root).
    pending = [(data, True)]

    while pending:
        obj, is_root = pending.pop()
        obj_id = id(obj)
        if obj_id in visited:
            continue
        visited.add(obj_id)

        obj_type = type(obj)
        if obj_type in primitive_types and not is_root:
            # Primitives reached through a container contribute nothing; only a
            # primitive passed in directly is measured.
            continue

        try:
            total_size += sys.getsizeof(obj)
        except Exception:
            total_size += float("inf")

        if obj_type in primitive_types:
            continue

        if obj_type in sequence_types:
            pending.extend((element, False) for element in obj)
        elif obj_type is dict:
            for key, value in obj.items():
                pending.append((key, False))
                pending.append((value, False))
        elif obj_type in shallow_types or isinstance(obj, type):
            # function, method, module or class: shallow size only
            continue
        elif obj_type.__sizeof__ is object.__sizeof__ and hasattr(obj, "__dict__"):
            # Custom class instance without a size hook of its own: getsizeof does
            # not follow its attributes, so add them here. (Every object has a
            # __sizeof__ attribute, hence the comparison against object's.)
            for key, value in obj.__dict__.items():
                pending.append((key, False))
                pending.append((value, False))

    return total_size
|
58
|
+
|
59
|
+
|
60
|
+
def profile_variable_size(x) -> int:
    """
    Estimate the size of variable x.

    The work is delegated to get_total_size, which is expected to look inside lists, sets and
    dictionaries rather than report only their shallow size.

    Args:
        x: The variable to profile.

    Returns:
        Whatever get_total_size(x) reports.
    """
    estimated_size = get_total_size(x)
    return estimated_size
|
File without changes
|
@@ -0,0 +1,39 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
#
|
4
|
+
# Copyright 2021-2022 University of Illinois
|
5
|
+
from typing import List
|
6
|
+
|
7
|
+
|
8
|
+
class CellExecution:
    """
    Record of one cell run in the notebook session (one "press play" action).
    """

    def __init__(
        self,
        cell_num: int,
        cell: str,
        cell_runtime: float,
        start_time: float,
        src_vss: List,
        dst_vss: List,
    ):
        """
        Build a cell execution from the metrics captured for one run.

        Args:
            cell_num (int): Position of this execution within the current session.
            cell (str): Raw source text of the cell.
            cell_runtime (float): Cell runtime.
            start_time (float): When the cell started executing, which differs from when it was
                queued.
            src_vss (List[VariableSnapshot]): Input VSs of the cell execution.
            dst_vss (List[VariableSnapshot]): Output VSs of the cell execution.
        """
        # Identity and source of the executed cell.
        self.cell_num, self.cell = cell_num, cell

        # Timing metrics.
        self.cell_runtime, self.start_time = cell_runtime, start_time

        # Variable snapshots consumed and produced by this execution.
        self.src_vss, self.dst_vss = src_vss, dst_vss
|
@@ -0,0 +1,75 @@
|
|
1
|
+
from collections import defaultdict
|
2
|
+
from typing import List
|
3
|
+
|
4
|
+
from elastic_notebook.core.graph.cell_execution import CellExecution
|
5
|
+
from elastic_notebook.core.graph.variable_snapshot import VariableSnapshot
|
6
|
+
|
7
|
+
|
8
|
+
class DependencyGraph:
    """
    Snapshot of a notebook instance's history.

    Variable snapshots act as the nodes of the graph and cell executions as the edges
    connecting them.
    """

    def __init__(self):
        """
        Start with an empty graph (no cell executions, no variable snapshots).
        """
        # Cell executions, in the order they were added.
        self.cell_executions = []

        # Variable name -> list of that variable's snapshots, one per version,
        # i.e. {"x": [(x, 1), (x, 2)], "y": [(y, 1), (y, 2), (y, 3)]}
        self.variable_snapshots = defaultdict(list)

    def create_variable_snapshot(
        self, variable_name: str, deleted: bool
    ) -> VariableSnapshot:
        """
        Register and return the next snapshot of a variable.

        Args:
            variable_name (str): Name of the variable.
            deleted (bool): Whether this VS is created for the deletion of a variable, i.e. 'del x'.

        Returns:
            VariableSnapshot: The newly stored snapshot.
        """
        # The version is the number of snapshots already stored for this name (0 for a new name).
        # .get() is used so that a lookup alone never inserts an entry into the defaultdict.
        next_version = len(self.variable_snapshots.get(variable_name, ()))

        snapshot = VariableSnapshot(variable_name, next_version, deleted)
        self.variable_snapshots[variable_name].append(snapshot)
        return snapshot

    def add_cell_execution(
        self, cell, cell_runtime: float, start_time: float, src_vss: List, dst_vss: List
    ):
        """
        Record a cell execution and link it to its input and output variable snapshots.

        Args:
            cell (str): Raw source text of the cell.
            cell_runtime (float): Cell runtime.
            start_time (float): When the cell started executing, which differs from when it was
                queued.
            src_vss (List): Input VSs of the cell execution.
            dst_vss (List): Output VSs of the cell execution.
        """
        # The execution's number is its index in the execution list.
        execution = CellExecution(
            cell_num=len(self.cell_executions),
            cell=cell,
            cell_runtime=cell_runtime,
            start_time=start_time,
            src_vss=src_vss,
            dst_vss=dst_vss,
        )
        self.cell_executions.append(execution)

        # Every input snapshot records this execution as one of its consumers.
        for consumed in src_vss:
            consumed.input_ces.append(execution)

        # Every output snapshot records this execution as its producer.
        for produced in dst_vss:
            produced.output_ce = execution
|
@@ -0,0 +1,31 @@
|
|
1
|
+
class VariableSnapshot:
    """
    One version of a variable in the dependency graph.

    For example, a variable 'x' that has been assigned three times (x = 1, x = 2, x = 3) has three
    variable snapshots.
    """

    def __init__(self, name: str, version: int, deleted: bool):
        """
        Build a variable snapshot.

        Args:
            name (str): Variable name.
            version (int): The nth update to the corresponding variable.
            deleted (bool): Whether this VS is created for the deletion of a variable, i.e. 'del x'.
        """
        self.name, self.version = name, version

        # True when this VS stands for a deleted variable ('del x'), so that the variable is
        # explicitly left out of migration.
        self.deleted = deleted

        # Cell executions that need this snapshot in order to run.
        self.input_ces = []

        # The single cell execution that produced this snapshot (unset until linked).
        self.output_ce = None

        # Size of the variable this VS points to; estimated at migration time.
        self.size = 0
|
File without changes
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
|
4
|
+
class Adapter:
    """
    Base class declaring the storage operations; every operation raises NotImplementedError
    here and is meant to be overridden.
    """

    def __init__(self):
        """Nothing to initialise in the base class."""

    def read_all(self, path: Path):
        """Read everything stored at path (not implemented here)."""
        raise NotImplementedError()

    def create(self, path: Path):
        """Create the storage entry at path (not implemented here)."""
        raise NotImplementedError()

    def write_all(self, path: Path, buf):
        """Write buf to path (not implemented here)."""
        raise NotImplementedError()

    def remove(self, path: Path):
        """Remove the storage entry at path (not implemented here)."""
        raise NotImplementedError()
|
@@ -0,0 +1,30 @@
|
|
1
|
+
import gc
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
import dill
|
5
|
+
|
6
|
+
from elastic_notebook.core.io.adapter import Adapter
|
7
|
+
|
8
|
+
|
9
|
+
class FilesystemAdapter(Adapter):
    """
    Adapter that stores dill-serialized objects as files on the local filesystem.
    """

    def __init__(self):
        super().__init__()

    def read_all(self, path: Path):
        """
        Load and return the object serialized at path.

        The following (read then decode) is faster vs. directly returning dill.load when network
        speed is low.

        Args:
            path (Path): File to read.

        Returns:
            The deserialized object.
        """
        # The context manager guarantees the handle is closed even if reading fails.
        with open(path, "rb") as source:
            contents_bytestring = source.read()

        # Garbage collection is paused only while decoding and is restored to its previous state
        # afterwards, even when decoding raises.
        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            # NOTE(security): dill, like pickle, can run arbitrary code while loading; only read
            # files from a trusted source.
            return dill.loads(contents_bytestring)
        finally:
            if gc_was_enabled:
                gc.enable()

    def create(self, path: Path):
        """Create an empty file at path (or update its timestamp if it already exists)."""
        path.touch()

    def write_all(self, path: Path, buf):
        """
        Serialize buf with dill into the file at path, replacing any previous contents.

        Args:
            path (Path): File to write.
            buf: Object to serialize.
        """
        # Closing the handle deterministically ensures the data is flushed to disk.
        with open(path, "wb") as destination:
            dill.dump(buf, destination)

    def remove(self, path: Path):
        """Delete the file at path."""
        path.unlink()
|
@@ -0,0 +1,98 @@
|
|
1
|
+
from collections import defaultdict
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
import dill
|
5
|
+
|
6
|
+
# from elastic_notebook.core.io.filesystem_adapter import FilesystemAdapter
|
7
|
+
from ipykernel.zmqshell import ZMQInteractiveShell
|
8
|
+
|
9
|
+
from elastic_notebook.core.common.checkpoint_file import CheckpointFile
|
10
|
+
from elastic_notebook.core.graph.graph import DependencyGraph
|
11
|
+
|
12
|
+
# Default checkpoint location if a file path isn't specified.
FILENAME = "./notebook.pickle"


def _build_serialization_order(vss_to_migrate, overlapping_vss):
    """Group VSs into lists that are pickled together; every VS ends up in exactly one list."""
    group_of = {}
    for vs1, vs2 in overlapping_vss:
        group = group_of.get(vs1)
        other = group_of.get(vs2)
        if group is None:
            group = other if other is not None else set()
        elif other is not None and other is not group:
            # The pair links two existing groups: merge them so that all of their members are
            # serialized in a single dump.
            group.update(other)
            for member in other:
                group_of[member] = group
        group.add(vs1)
        group.add(vs2)
        # Register both members, so that neither is given a second, separate group later.
        group_of[vs1] = group
        group_of[vs2] = group

    # VSs without any overlap are serialized on their own.
    for vs in vss_to_migrate:
        if vs not in group_of:
            group_of[vs] = {vs}

    # Several keys share one set object; emit each distinct set once.
    serialization_order = []
    emitted = set()
    for group in group_of.values():
        if id(group) not in emitted:
            emitted.add(id(group))
            serialization_order.append(list(group))
    return serialization_order


def migrate(
    graph: DependencyGraph,
    shell: ZMQInteractiveShell,
    vss_to_migrate: set,
    vss_to_recompute: set,
    ces_to_recompute: set,
    udfs,
    recomputation_ces,
    overlapping_vss,
    filename: str,
):
    """
    Writes the graph representation of the notebook, migrated variables, and instructions for recomputation as the
    specified file.

    The file holds one dill record with the checkpoint metadata, followed by one dill record per
    entry of the serialization order; each of those records is the list of objects (taken from
    shell.user_ns by variable name) for the VSs of that entry, in the same order.

    Args:
        graph (DependencyGraph): dependency graph representation of the notebook.
        shell (ZMQInteractiveShell): interactive Jupyter shell storing the state of the current session.
        vss_to_migrate (set): set of VSs to migrate.
        vss_to_recompute (set): set of VSs to recompute.
        ces_to_recompute (set): set of CEs to recompute post-migration.
        udfs (set): set of user-declared functions.
        recomputation_ces: stored in the checkpoint metadata as-is.
        overlapping_vss: iterable of (vs1, vs2) pairs; VSs connected through such pairs are
            serialized together in one record.
        filename (str): the location to write the checkpoint to. The default location is used
            when this is empty.
    """
    # Group the VSs to migrate by the cell execution that produced them.
    variables = defaultdict(list)
    for vs in vss_to_migrate:
        variables[vs.output_ce].append(vs)

    serialization_order = _build_serialization_order(vss_to_migrate, overlapping_vss)

    # Assemble the checkpoint metadata.
    metadata = (
        CheckpointFile()
        .with_dependency_graph(graph)
        .with_variables(variables)
        .with_vss_to_migrate(vss_to_migrate)
        .with_vss_to_recompute(vss_to_recompute)
        .with_ces_to_recompute(ces_to_recompute)
        .with_recomputation_ces(recomputation_ces)
        .with_serialization_order(serialization_order)
        .with_udfs(udfs)
    )

    write_path = filename if filename else FILENAME

    with open(Path(write_path), "wb") as output_file:
        dill.dump(metadata, output_file)
        for vs_list in serialization_order:
            obj_list = [shell.user_ns[vs.name] for vs in vs_list]
            dill.dump(obj_list, output_file)

    # Report success only once the file has actually been written.
    if filename:
        print("Checkpoint saved to:", filename)
|
93
|
+
# Write the JSON file to the specified location. Uses the default location if a file path isn't specified.
|
94
|
+
# if filename:
|
95
|
+
# print("Checkpoint saved to:", filename)
|
96
|
+
# adapter.write_all(Path(filename), metadata)
|
97
|
+
# else:
|
98
|
+
# adapter.write_all(Path(FILENAME), metadata)
|
@@ -0,0 +1,71 @@
|
|
1
|
+
# /usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
#
|
4
|
+
# Copyright 2021-2022 University of Illinois
|
5
|
+
|
6
|
+
import hashlib
|
7
|
+
|
8
|
+
# import polars as pl
|
9
|
+
import inspect
|
10
|
+
import mmap
|
11
|
+
import pickle
|
12
|
+
import types
|
13
|
+
|
14
|
+
import dill
|
15
|
+
import matplotlib.pyplot as plt
|
16
|
+
import networkx
|
17
|
+
import pandas as pd
|
18
|
+
import seaborn
|
19
|
+
from scipy import sparse
|
20
|
+
|
21
|
+
|
22
|
+
def is_picklable_fast(obj):
    """
    Cheap pre-check that rejects a few kinds of object known not to be picklable.

    Args:
        obj: The object to check.

    Returns:
        bool: False for generators, mmap objects and hash objects of the type produced by
        hashlib.sha256(); True for everything else (which does not prove the object is picklable).
    """
    # hashlib.sha256 is a factory function rather than a class, so the type of the hash objects
    # it produces has to be taken from an instance.
    sha256_type = type(hashlib.sha256())
    return not isinstance(obj, (types.GeneratorType, mmap.mmap, sha256_type))
|
26
|
+
|
27
|
+
|
28
|
+
def is_picklable(obj):
    """
    Decide whether an object can be pickled.

    Objects accepted by is_exception and modules are reported as picklable without any further
    check. Otherwise the dill-based check decides; the pickle-based check is consulted only
    when the dill-based one raises.
    """
    known_exception = is_exception(obj)
    if known_exception or inspect.ismodule(obj):
        return True

    try:
        # This function can crash.
        verdict = _is_picklable_dill(obj)
    except Exception:
        try:
            # Double check with function from pickle module
            verdict = _is_picklable_raw(obj)
        except Exception:
            verdict = False
    return verdict
|
43
|
+
|
44
|
+
|
45
|
+
def is_exception(obj):
    """
    List of objects which _is_picklable_dill returns false (or crashes) but are picklable.

    Args:
        obj: The object to check.

    Returns:
        bool: True when the top-level package in obj.__module__ is matplotlib, seaborn, networkx
        or pandas, or when obj is exactly a pandas DataFrame or a SciPy csr_matrix.
    """
    module_name = getattr(obj, "__module__", None)
    # __module__ may be None or otherwise not a string; such objects are simply not matched here.
    if isinstance(module_name, str):
        # Compare top-level package names on both sides: plt.__name__ is "matplotlib.pyplot",
        # which would never equal the first component of a dotted module path.
        known_packages = {
            library.__name__.split(".")[0] for library in (plt, seaborn, networkx, pd)
        }
        if module_name.split(".")[0] in known_packages:
            return True

    # csr_matrix is referenced through the public scipy.sparse namespace.
    return type(obj) in (pd.DataFrame, sparse.csr_matrix)
|
55
|
+
|
56
|
+
|
57
|
+
def _is_picklable_raw(obj):
    """
    Check picklability by actually pickling obj with the standard pickle module.

    Returns:
        bool: True if pickle.dumps(obj) succeeds, False if it raises any Exception.
    """
    try:
        # dumps can be slow for large objects that can be pickled
        pickle.dumps(obj)
    except Exception:
        picklable = False
    else:
        picklable = True
    return picklable
|
64
|
+
|
65
|
+
|
66
|
+
def _is_picklable_dill(obj):
    """
    Check picklability with dill.pickles and return its verdict unchanged.
    """
    # compared to _is_picklable_raw, this may be faster
    # however, dill's correctness is worrying because
    # it currently considers Pandas DataFrame as not
    # picklable, which is not true
    verdict = dill.pickles(obj)
    return verdict
|
@@ -0,0 +1,51 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
#
|
4
|
+
# Copyright 2021-2022 University of Illinois
|
5
|
+
|
6
|
+
import _pickle
|
7
|
+
import dill
|
8
|
+
|
9
|
+
from elastic_notebook.core.graph.graph import DependencyGraph
|
10
|
+
|
11
|
+
|
12
|
+
def resume(filename):
    """
    Load the notebook state stored in a checkpoint file.

    Any failure is reported with print() and answered with empty state instead of an exception.

    Args:
        filename (str): location of the checkpoint file.
    Returns:
        tuple: (dependency_graph, fingerprint_dict, udfs, recomputation_ces, overlapping_vss)
    """
    try:
        with open(filename, "rb") as checkpoint:
            # Read data saved in the new format.
            # NOTE(security): dill.load can run arbitrary code; only load trusted checkpoints.
            graph, user_ns, udfs, metadata = dill.load(checkpoint)

            # Rebuild fingerprint_dict with a simplified (placeholder) fingerprint for every
            # restored variable that the graph knows about.
            fingerprint_dict = {
                var_name: (None, set(), None)
                for var_name in user_ns
                if var_name in graph.variable_snapshots
            }

            # The remaining values come from the metadata.
            return (
                graph,
                fingerprint_dict,
                udfs,
                metadata.get("recomputation_ces", {}),
                metadata.get("overlapping_vss", []),
            )
    except _pickle.UnpicklingError as e:
        print(f"Warning: Checkpoint file is corrupted: {e}")
    except EOFError as e:
        print(f"Warning: Checkpoint file is incomplete: {e}")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")

    # Every failure path falls back to an empty session state.
    return DependencyGraph(), {}, set(), {}, []
|
File without changes
|
@@ -0,0 +1,184 @@
|
|
1
|
+
import time
|
2
|
+
from collections.abc import Iterable
|
3
|
+
from types import FunctionType
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
|
7
|
+
from elastic_notebook.core.mutation.id_graph import (
|
8
|
+
construct_id_graph,
|
9
|
+
is_root_equals,
|
10
|
+
is_structure_equals,
|
11
|
+
)
|
12
|
+
from elastic_notebook.core.mutation.object_hash import (
|
13
|
+
DataframeObj,
|
14
|
+
ImmutableObj,
|
15
|
+
ModuleObj,
|
16
|
+
NoneObj,
|
17
|
+
NpArrayObj,
|
18
|
+
NxGraphObj,
|
19
|
+
ScipyArrayObj,
|
20
|
+
TorchTensorObj,
|
21
|
+
UncomparableObj,
|
22
|
+
UnserializableObj,
|
23
|
+
construct_object_hash,
|
24
|
+
)
|
25
|
+
|
26
|
+
# Types that base_typed returns unchanged.
BASE_TYPES = [
    # Built-in types.
    str,
    int,
    float,
    bool,
    type(None),
    FunctionType,
    # Wrapper types from object_hash.
    ImmutableObj,
    UncomparableObj,
    NoneObj,
    NxGraphObj,
    TorchTensorObj,
    ModuleObj,
    UnserializableObj,
    NpArrayObj,
    ScipyArrayObj,
    DataframeObj,
]


def base_typed(obj, visited):
    """
    Recursively turn an object into a form that can be compared with ==.

    Base types, callables, numpy-typed objects and iterables are returned as they are; any
    other object is replaced by a dict of its attributes, converted the same way. Attribute
    values whose id is already in visited are left out of that dict.

    From: https://stackoverflow.com/questions/1227121/compare-object-instances-for-equality-by-their-attributes
    """
    obj_type = type(obj)
    defined_in_numpy = obj_type.__module__ == "numpy"

    if obj_type in BASE_TYPES or callable(obj):
        return obj
    # NOTE(review): the Iterable test is applied to the type object, not to obj itself.
    if defined_in_numpy and not isinstance(obj_type, Iterable):
        return obj

    visited.add(id(obj))

    if isinstance(obj, Iterable):
        return obj

    if obj_type is dict:
        attributes = obj
    else:
        attributes = obj.__dict__

    return {
        key: base_typed(value, visited)
        for key, value in attributes.items()
        if id(value) not in visited
    }
|
69
|
+
|
70
|
+
|
71
|
+
def deep_equals(*args):
    """
    Compare the first argument with each of the others after converting both sides with
    base_typed, so that certain user-defined objects with no equality (__eq__) defined can
    still be compared by their attributes.

    Returns True when every comparison is truthy (also when fewer than two arguments are given).
    """
    for other in args[1:]:
        # Both sides are converted afresh, each with its own visited set.
        if not (base_typed(args[0], set()) == base_typed(other, set())):
            return False
    return True
|
79
|
+
|
80
|
+
|
81
|
+
def construct_fingerprint(obj, profile_dict):
    """
    Build the fingerprint of an object: [ID graph, ID set, object hash].

    The time spent on each part is added to profile_dict under "idgraph" and "representation".
    """
    # ID graph and ID set.
    id_graph_started = time.time()
    id_graph, id_set = construct_id_graph(obj)
    id_graph_finished = time.time()
    profile_dict["idgraph"] += id_graph_finished - id_graph_started

    # Object hash.
    hash_started = time.time()
    object_representation = construct_object_hash(obj, deepcopy=True)
    hash_finished = time.time()
    profile_dict["representation"] += hash_finished - hash_started

    fingerprint = [id_graph, id_set, object_representation]
    return fingerprint
|
96
|
+
|
97
|
+
|
98
|
+
def compare_fingerprint(
|
99
|
+
fingerprint_list, new_obj, profile_dict, input_variables_id_graph_union
|
100
|
+
):
|
101
|
+
"""
|
102
|
+
Check whether an object has been changed by comparing it to its previous fingerprint.
|
103
|
+
"""
# Returns the pair (changed, overwritten).
# fingerprint_list is updated in place: slot 0 holds the ID graph, slot 1 the ID set and
# slot 2 the object representation (hash) or an UncomparableObj marker.
# profile_dict accumulates elapsed time.time() differences under the keys "idgraph" and
# "representation".
# input_variables_id_graph_union is intersected with this object's ID set when deciding
# whether an object that cannot be compared counts as changed.
# NOTE(review): this listing has lost its indentation, so the exact nesting of the branches
# below must be confirmed against the original file before relying on it.
|
104
|
+
changed = False
|
105
|
+
overwritten = False
|
106
|
+
uncomparable = False
|
107
|
+
|
108
|
+
# Hack: check for pandas dataframes and series: if the flag has been flipped back on, the object has been changed.
# The flag in question is the "writeable" flag of the underlying array(s).
|
109
|
+
if isinstance(new_obj, pd.DataFrame):
|
110
|
+
for _, col in new_obj.items():
|
111
|
+
if col.__array__().flags.writeable:
|
112
|
+
changed = True
|
113
|
+
break
|
114
|
+
|
115
|
+
elif isinstance(new_obj, pd.Series):
|
116
|
+
if new_obj.__array__().flags.writeable:
|
117
|
+
changed = True
|
118
|
+
|
119
|
+
# ID graph check: check whether the structure of the object has changed (i.e. its relation with other objects)
|
120
|
+
start = time.time()
|
121
|
+
|
122
|
+
id_graph, id_set = construct_id_graph(new_obj)
|
123
|
+
|
124
|
+
if id_set != fingerprint_list[1] or not is_structure_equals(
|
125
|
+
id_graph, fingerprint_list[0]
|
126
|
+
):
|
127
|
+
# Distinguish between overwritten variables and modified variables (i.e., x = 1 vs. x[0] = 1)
|
128
|
+
if not is_root_equals(id_graph, fingerprint_list[0]):
|
129
|
+
overwritten = True
|
130
|
+
changed = True
|
131
|
+
fingerprint_list[0] = id_graph
|
132
|
+
fingerprint_list[1] = id_set
|
133
|
+
|
134
|
+
end = time.time()
|
135
|
+
profile_dict["idgraph"] += end - start
|
136
|
+
|
137
|
+
# Value check via object hash: check whether the object's value has changed
|
138
|
+
if not changed:
|
139
|
+
start = time.time()
|
140
|
+
try:
|
141
|
+
new_repr = construct_object_hash(new_obj, deepcopy=False)
|
142
|
+
|
143
|
+
# Variable is uncomparable
|
144
|
+
if isinstance(new_repr, UncomparableObj):
|
145
|
+
if id_set.intersection(input_variables_id_graph_union):
|
146
|
+
changed = True
|
147
|
+
uncomparable = True
|
148
|
+
fingerprint_list[2] = UncomparableObj()
|
149
|
+
else:
|
150
|
+
if not deep_equals(new_repr, fingerprint_list[2]):
|
151
|
+
# Variable has equality defined; the variable has been modified.
# "Equality defined" means "__eq__" or "eq" appears in the __dict__ of new_repr's type.
|
152
|
+
if (
|
153
|
+
"__eq__" in type(new_repr).__dict__.keys()
|
154
|
+
or "eq" in type(new_repr).__dict__.keys()
|
155
|
+
):
|
156
|
+
changed = True
|
157
|
+
else:
|
158
|
+
# Object is uncomparable
|
159
|
+
if id_set.intersection(input_variables_id_graph_union):
|
160
|
+
changed = True
|
161
|
+
uncomparable = True
|
162
|
+
fingerprint_list[2] = UncomparableObj()
|
163
|
+
except Exception as e:
|
164
|
+
print(e)
|
165
|
+
# Variable is uncomparable
# Any failure while hashing or comparing is printed and handled like the uncomparable case.
|
166
|
+
if id_set.intersection(input_variables_id_graph_union):
|
167
|
+
changed = True
|
168
|
+
uncomparable = True
|
169
|
+
fingerprint_list[2] = UncomparableObj()
|
170
|
+
|
171
|
+
# Update the object hash if either:
|
172
|
+
# 1. the object has been completely overwritten
|
173
|
+
# 2. the object has been modified, and is of a comparable type (i.e., hashable or unhashable but has equality
|
174
|
+
# defined)
|
175
|
+
if overwritten or (
|
176
|
+
changed
|
177
|
+
and not uncomparable
|
178
|
+
and not isinstance(fingerprint_list[2], UncomparableObj)
|
179
|
+
):
|
180
|
+
fingerprint_list[2] = construct_object_hash(new_obj, deepcopy=True)
|
181
|
+
end = time.time()
|
182
|
+
profile_dict["representation"] += end - start
|
183
|
+
|
184
|
+
return changed, overwritten
|