edtrace 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edtrace-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.4
2
+ Name: edtrace
3
+ Version: 0.1.0
4
+ Summary: Library for tracing through Python programs for educational presentation
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: torch
8
+ Requires-Dist: sympy
9
+ Requires-Dist: numpy
10
+ Requires-Dist: requests
11
+ Requires-Dist: altair>=5.5.0
@@ -0,0 +1,13 @@
1
+ [project]
2
+ name = "edtrace"
3
+ version = "0.1.0"
4
+ description = "Library for tracing through Python programs for educational presentation"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "torch",
9
+ "sympy",
10
+ "numpy",
11
+ "requests",
12
+ "altair>=5.5.0",
13
+ ]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,50 @@
1
+ import re
2
+ import xml.etree.ElementTree as ET
3
+ from file_util import cached
4
+ from reference import Reference
5
+
6
+
7
+ def canonicalize(text: str):
8
+ """Remove newlines and extra whitespace with one space."""
9
+ text = text.replace("\n", " ")
10
+ text = re.sub(r"\s+", " ", text)
11
+ text = text.strip()
12
+ return text
13
+
14
+
15
def is_arxiv_link(url: str) -> bool:
    """
    Return whether `url` points at arxiv.org.

    Both http and https are accepted, with or without a leading "www.", so every form of
    arXiv URL is routed to `arxiv_reference` instead of being treated as a generic link.
    """
    return re.match(r"https?://(www\.)?arxiv\.org/", url) is not None
17
+
18
def arxiv_reference(url: str, **kwargs) -> Reference:
    """
    Build a `Reference` for an arXiv paper from its URL (e.g., https://arxiv.org/abs/2005.14165).

    The URL may use any single path segment before the ID (e.g. /abs/, /pdf/, /html/), an
    optional version suffix of any length (v2, v12) and an optional ".pdf" extension.
    The metadata is fetched from the arXiv export API; the response is cached on disk via
    `cached`. Extra `kwargs` are passed through to `Reference`.

    Raises:
        ValueError: if no paper ID can be found in `url`, or the API response has no usable
            entry for the paper.
    """
    # Figure out the paper ID
    m = re.search(r"arxiv\.org/[^/]+/(\d+\.\d+)(v\d+)?(\.pdf)?$", url)
    if not m:
        raise ValueError(f"Cannot handle this URL: {url}")
    paper_id = m.group(1)

    metadata_url = f"http://export.arxiv.org/api/query?id_list={paper_id}"
    metadata_path = cached(metadata_url, "arxiv")
    # Let the XML parser read the raw bytes so the document's declared encoding is honored,
    # instead of decoding with the platform's default text encoding.
    root = ET.parse(metadata_path).getroot()

    # Extract the relevant metadata
    atom = "{http://www.w3.org/2005/Atom}"
    entry = root.find(atom + "entry")
    if entry is None:
        raise ValueError(f"No arXiv metadata found for {paper_id} (from {url})")
    entry_id = entry.findtext(atom + "id") or ""
    if "/api/errors" in entry_id:
        # arXiv reports API errors (e.g. a malformed ID) as an entry pointing at its errors page.
        raise ValueError(f"arXiv API error for {paper_id} (from {url}): {entry_id}")

    def child_text(tag: str) -> str:
        """Return the text of the Atom child `tag` of `entry`, failing clearly if it is missing."""
        value = entry.findtext(atom + tag)
        if value is None:
            raise ValueError(f"arXiv metadata for {paper_id} has no <{tag}> (from {url})")
        return value

    title = canonicalize(child_text("title"))
    authors = [
        canonicalize(author.findtext(atom + "name") or "")
        for author in entry.findall(atom + "author")
    ]
    summary = canonicalize(child_text("summary"))
    published = child_text("published")

    return Reference(
        title=title,
        authors=authors,
        url=url,
        date=published,
        description=summary,
        **kwargs,
    )
@@ -0,0 +1,100 @@
1
+ import numpy as np
2
+ import sympy
3
+ import torch
4
+ from datetime import datetime
5
+ from dataclasses import dataclass
6
+ from execute_util import text, image, link, system_text
7
+
8
+
9
def main():
    """Entry point called by the tracer: run the rendering demos, then the value-inspection demos."""
    display()
    inspect_values()
12
+
13
def compute():
    """Busy-work loop used to demonstrate stepping over a call; returns 99 * 99 = 9801."""
    x = 0
    for i in range(100):
        # Overwritten on every iteration, so only the last square is returned.
        x = i * i
    return x
18
+
19
+
20
def display():
    """
    Demonstrate the rendering helpers (markdown text, images, links, command output) and the
    inspect, stepover, clear, and hide directives.
    """
    # Markdown: plain text, math, bullets, headings, and emphasis.
    text("Hello, world!")
    text("Math: $x^2$")
    text("- Bullet 1")
    text("- Bullet 2: this is a long thing that should wrap at some point because it will keep on going on and on")
    text("# Heading 1")
    text("## Heading 2")
    text("### Heading 3")
    text("**Bold** *italic*")
    # The continuation line below carries a hide directive, so it is hidden in the presentation.
    text("Multiline text: "
         "wrapped around") # @hide
    # Three renderings produced by a single line.
    text("One text th"), text("at is made up multi"), text("ple text calls")
    # Images, links (a plain reference and an arXiv paper), and the output of a shell command.
    image("https://www.google.com/logos/doodles/2025/labor-day-2025-6753651837110707.4-l.webp", width=200)
    link(title="Google", url="https://www.google.com")
    link("https://arxiv.org/abs/2005.14165")
    system_text(["date"])

    # Step over compute() while inspecting its result, then clear it two lines later.
    x = compute() # @inspect x @stepover
    text("Should still show the value of x.")
    text("Let's move on (value of x should not be shown).") # @clear x
40
+
41
+
42
def inspect_values():
    """
    Show how different kinds of values are displayed when inspected: NumPy arrays and PyTorch
    tensors of several dtypes and ranks, scalars, SymPy expressions, lists, dicts, dataclasses,
    and datetimes.
    """
    # Numpy arrays of different dtypes
    x = np.array([1, 2, 3]) # @inspect x
    x = np.array([1, 2, 3], dtype=np.int8) # @inspect x
    x = np.array([1, 2, 3], dtype=np.int16) # @inspect x
    x = np.array([1, 2, 3], dtype=np.int32) # @inspect x
    x = np.array([1, 2, 3], dtype=np.int64) # @inspect x
    x = np.array([1, 2, 3], dtype=np.float16) # @inspect x
    x = np.array([1, 2, 3], dtype=np.float32) # @inspect x
    x = np.array([1, 2, 3], dtype=np.float64) # @inspect x

    # PyTorch tensors of different dtypes
    x = torch.tensor([1, 2, 3]) # @inspect x
    x = torch.tensor([1, 2, 3], dtype=torch.int8) # @inspect x
    x = torch.tensor([1, 2, 3], dtype=torch.int16) # @inspect x
    x = torch.tensor([1, 2, 3], dtype=torch.int32) # @inspect x
    x = torch.tensor([1, 2, 3], dtype=torch.int64) # @inspect x
    x = torch.tensor([1, 2, 3], dtype=torch.float16) # @inspect x
    x = torch.tensor([1, 2, 3], dtype=torch.float32) # @inspect x
    x = torch.tensor([1, 2, 3], dtype=torch.float64) # @inspect x

    # Different scalars
    x = torch.tensor(1, dtype=torch.int64) # @inspect x
    x = torch.tensor(1, dtype=torch.float64) # @inspect x
    x = np.int64(1) # @inspect x
    x = np.float64(1) # @inspect x

    # Multi-dimensional arrays
    x = np.zeros((2, 3)) # @inspect x
    x = np.zeros((2, 2, 3)) # @inspect x
    x = torch.zeros((2, 3)) # @inspect x
    x = torch.zeros((2, 2, 3)) # @inspect x

    # Sympy
    x = sympy.symbols('x') # @inspect x
    x = 0 * sympy.symbols('x') # @inspect x
    x = 0.5 * sympy.symbols('x') # @inspect x

    # Lists
    x = [] # @inspect x
    x = [1, 2, 3] # @inspect x
    x = [[1, 2, 3], [4, 5, "hello"]] # @inspect x

    # Dicts
    x = {} # @inspect x
    x = {"a": [1, 2, 3], "b": [4, 5, "hello"]} # @inspect x

    # Dataclasses
    @dataclass(frozen=True)
    class MyDataclass:
        a: int
        b: list[int]
    x = MyDataclass(a=1, b=[2, 3]) # @inspect x

    # Datetimes
    x = datetime.now() # @inspect x
98
+
99
# Allow running the examples directly, outside of the tracer (which imports this module and calls main itself).
if __name__ == "__main__":
    main()
@@ -0,0 +1,376 @@
1
+ import io
2
+ import argparse
3
+ import math
4
+ import importlib
5
+ import inspect
6
+ import sys
7
+ import json
8
+ import traceback
9
+ import numpy as np
10
+ import torch
11
+ import sympy
12
+ from dataclasses import dataclass, asdict, field, is_dataclass, fields
13
+ import os
14
+ import re
15
+ from execute_util import Rendering, pop_renderings
16
+ from file_util import ensure_directory_exists, relativize
17
+
18
+
19
@dataclass(frozen=True)
class StackElement:
    """One frame of the traced call stack: which file and line, in which function, running what code."""

    # Path to the file containing the code (made relative to the working directory by `execute`).
    path: str
    # The line number being executed in that file.
    line_number: int
    # Name of the function that the line belongs to.
    function_name: str
    # The (stripped) source text of the line being executed.
    code: str
32
+
33
+
34
+ @dataclass(frozen=True)
35
+ class Value:
36
+ """Represents the value of an environment variable."""
37
+ type: str
38
+ """The type of the value."""
39
+
40
+ contents: any
41
+ """The contents itself."""
42
+
43
+ dtype: str | None = None
44
+ """If `contents` is a tensor/array, then this is its dtype (e.g., "float32")."""
45
+
46
+ shape: list[int] | None = None
47
+ """If `contents` is a tensor/array, then this is its shape (e.g., [2, 3] for a 2x3 matrix)."""
48
+
49
+
50
@dataclass
class Step:
    """
    One step of the trace: the call stack at a line of code, the inspected variables, and
    whatever that line rendered.

    Not frozen because `execute` fills in `env` and `renderings` after the step is created.
    """

    # The stack of function calls (outermost first; the current line is last).
    stack: list[StackElement]
    # The local variables, including function arguments, that are being inspected.
    # A value of None means the variable was cleared (via the clear directive) and should
    # no longer be shown.
    env: dict[str, Value | None]
    # The output of the code (see execute_util.py).
    renderings: list[Rendering] = field(default_factory=list)
61
+
62
+
63
@dataclass(frozen=True)
class Trace:
    """The complete result of `execute`: the traced source files, lines to hide, and the steps."""

    # Maps each traced file's path to its full contents.
    files: dict[str, str]
    # Maps each file path to the line numbers marked with the hide directive.
    hidden_line_numbers: dict[str, list[int]]
    # The recorded steps, in execution order.
    steps: list[Step]
73
+
74
+
75
# Directives are written in a trailing comment on a line of traced code.
DIRECTIVE_INSPECT = "@inspect"    # Show (and update) the value of a variable
DIRECTIVE_CLEAR = "@clear"        # Stop showing the value of a variable
DIRECTIVE_STEPOVER = "@stepover"  # Don't trace into the current line
DIRECTIVE_HIDE = "@hide"          # Don't show this line at all

# Anything else starting with "@" triggers a warning in parse_directives.
ACCEPTED_DIRECTIVES = [
    DIRECTIVE_INSPECT,
    DIRECTIVE_CLEAR,
    DIRECTIVE_STEPOVER,
    DIRECTIVE_HIDE,
]
80
+
81
+
82
@dataclass(frozen=True)
class Directive:
    """A directive parsed from a comment, e.g. `@inspect x y` gives name "@inspect" and args ["x", "y"]."""

    # The directive itself, including the leading "@".
    name: str
    # The words following the directive (the list is appended to while parsing).
    args: list[str]
88
+
89
+
90
def _comment_text(line: str) -> str:
    """
    Return the text of the comment on `line` (without the leading "#"), or "" if there is none.

    The line is run through Python's tokenizer so that a "#" inside a string literal
    (e.g. `text("# Heading")`) is not mistaken for the start of a comment. A single line can't
    always be tokenized on its own (e.g. the tail of a call spanning several lines); in that
    case fall back to the text between the first and second "#".
    """
    import io
    import tokenize

    try:
        for token in tokenize.generate_tokens(io.StringIO(line).readline):
            if token.type == tokenize.COMMENT:
                return token.string[1:]
    except (tokenize.TokenError, SyntaxError):
        return line.split("#")[1] if "#" in line else ""
    return ""


def parse_directives(line: str) -> list[Directive]:
    """
    Parse the directives from the comment on `line`.

    Examples:
        "... # @inspect x y @hide" -> [Directive(name="@inspect", args=["x", "y"]), Directive(name="@hide", args=[])]
        'text("# @hide")'          -> []  (a "#" inside a string literal does not start a comment)

    Words following a directive become its arguments; words before the first directive are
    ignored. Unknown directives are still returned, but a warning is printed. `line` may be
    None or empty (e.g. when the source line is unavailable), which yields no directives.
    """
    if not line or "#" not in line:
        return []
    directives: list[Directive] = []
    for token in _comment_text(line).split():
        if token.startswith("@"):
            if token not in ACCEPTED_DIRECTIVES:
                print(f"WARNING: {token} is not a valid directive.")
            directives.append(Directive(name=token, args=[]))
        elif directives:
            directives[-1].args.append(token)
    return directives
112
+
113
+
114
def get_inspect_variables(directives: list[Directive]) -> list[str]:
    """
    Collect, in order, the variable names given to every inspect directive.

    Example code:
        x, y = str.split("a,b") # @inspect x @inspect y
    yields ["x", "y"].
    """
    return [
        name
        for directive in directives
        if directive.name == DIRECTIVE_INSPECT
        for name in directive.args
    ]
126
+
127
+
128
def get_clear_variables(directives: list[Directive]) -> list[str]:
    """
    Collect, in order, the variable names given to every clear directive.

    Example code:
        y = np.array([1, 2, 3]) # @clear y
    yields ["y"].
    """
    return [
        name
        for directive in directives
        if directive.name == DIRECTIVE_CLEAR
        for name in directive.args
    ]
140
+
141
+
142
def to_primitive(value: object) -> object:
    """
    Return `value` unchanged if it is an int/float/str/bool, otherwise its `str()` form.

    Used for dictionary keys so the result can be written with `json.dump`, which rejects
    keys of other types (e.g. tuples).
    """
    if isinstance(value, (int, float, str, bool)):
        return value
    # Force anything else (None, tuples, arbitrary objects, ...) to be a string
    return str(value)
147
+
148
+
149
def _replace_non_finite(obj: object) -> object:
    """Recursively replace NaN/infinite floats in (nested) lists by their string form ("nan", "inf", "-inf")."""
    if isinstance(obj, float):
        return obj if math.isfinite(obj) else str(obj)
    if isinstance(obj, list):
        return [_replace_non_finite(item) for item in obj]
    return obj


def to_serializable_value(value: object) -> Value:
    """
    Convert `value` to a `Value` whose contents can be written with `json.dump`.

    - bool/int/float/str are kept as-is; NaN and infinities become strings because JSON has no
      representation for them.
    - NumPy scalars of any width become Python bool/int/float.
    - NumPy arrays and torch tensors become nested lists, with their dtype and shape recorded;
      non-finite entries become strings, as above.
    - SymPy Integer/Float become int/float; other SymPy objects become their string form.
    - Lists, dicts and dataclass instances are converted recursively (dict keys via `to_primitive`).
    - Anything else falls back to `str(value)`.
    """
    value_type = get_type_str(value)

    # Primitive types (np.float64 subclasses float, so it is handled here as well)
    if isinstance(value, (bool, int, float, str)):
        # Serialize inf and nan values specially since JSON doesn't support them
        if isinstance(value, float) and not math.isfinite(value):
            return Value(type=value_type, contents=str(value))
        return Value(type=value_type, contents=value)

    # NumPy scalars of every width (np.int8 ... np.int64, np.float16, np.float32, np.bool_, ...)
    if isinstance(value, np.bool_):
        return Value(type=value_type, contents=bool(value))
    if isinstance(value, np.integer):
        return Value(type=value_type, contents=int(value))
    if isinstance(value, np.floating):
        return Value(type=value_type, contents=_replace_non_finite(float(value)))

    # Tensors
    if isinstance(value, (np.ndarray, torch.Tensor)):
        return Value(
            type=value_type,
            dtype=str(value.dtype),
            shape=list(value.shape),
            contents=_replace_non_finite(value.tolist()),
        )

    # Symbols
    if value_type.startswith("sympy.core."):
        if isinstance(value, sympy.core.numbers.Integer):
            return Value(type=value_type, contents=int(value))
        if isinstance(value, sympy.core.numbers.Float):
            return Value(type=value_type, contents=_replace_non_finite(float(value)))
        return Value(type=value_type, contents=str(value))

    # Recursive types
    if isinstance(value, list):
        return Value(type=value_type, contents=[to_serializable_value(item) for item in value])
    if isinstance(value, dict):
        return Value(type=value_type, contents={to_primitive(k): to_serializable_value(v) for k, v in value.items()})
    # is_dataclass() is also true for dataclass *classes*, which have no field values to read
    if is_dataclass(value) and not isinstance(value, type):
        return Value(type=value_type, contents={
            f.name: to_serializable_value(getattr(value, f.name))
            for f in fields(value)
        })

    # Force contents to be a string to avoid serialization errors
    return Value(type=type(value).__name__, contents=str(value))
191
+
192
+
193
def get_type_str(value: any) -> str:
    """
    Describe the type of `value` as a string.

    Builtin types are reported by bare name ("int", "list"); all others are qualified with
    their module ("numpy.ndarray", "torch.Tensor").
    """
    cls = type(value)
    module = cls.__module__
    return cls.__name__ if module == "builtins" else f"{module}.{cls.__name__}"
199
+
200
+
201
def execute(module_name: str, inspect_all_variables: bool) -> Trace:
    """
    Import `module_name`, run its `main()` under `sys.settrace`, and return a trace of the execution.

    Only lines in the module's own file are recorded. Each recorded line becomes a `Step` that
    holds the call stack and the values of the inspected variables after the line ran. Those
    variables are all locals if `inspect_all_variables` is set, otherwise the ones named by
    inspect directives. The step also holds the renderings the line produced. The previously
    installed trace function is restored even if `main()` raises.
    """
    steps: list[Step] = []

    # Figure out which files we're actually tracing
    visible_paths = []

    # Stack of (path, line_number) locations that we're stepping over
    stepovers = []

    # Number of frames at and above this `execute` call; set just before tracing starts.
    # get_stack() drops them so that only frames of the traced program remain.
    num_outer_frames = 0

    # Frames that belong to the tracer itself rather than to the traced program.
    tracer_frame_names = ("trace_func", "local_trace_func", "get_stack")

    # Comprehension scopes just stay on the line that created them, so don't step into them.
    # (Python 3.12+ inlines list/set/dict comprehensions, but generator expressions, and all of
    # them on older versions, get their own frame.)
    comprehension_names = ("<listcomp>", "<setcomp>", "<dictcomp>", "<genexpr>")

    def get_stack() -> list[StackElement]:
        """Return the traced program's call stack (outermost first), skipping the tracer's own frames."""
        # The stack looks like this:
        #   [frames up to execute] [good stuff to return] local_trace_func trace_func get_stack
        stack = []
        for item in traceback.extract_stack()[num_outer_frames:]:
            if item.name in tracer_frame_names:
                continue
            stack.append(StackElement(
                path=relativize(item.filename),
                line_number=item.lineno,
                function_name=item.name,
                code=item.line,
            ))
        return stack

    def trace_func(frame, event, arg):
        """
        Global trace function; together with local_trace_func it records the steps.

        - trace_func is called *before* a line of code is executed (when a new frame is
          entered, or from local_trace_func for the next line of the same frame).
        - local_trace_func is called *after* a line of code has been executed
          and will have the values of the variables.
        We generally keep the local_trace_func version. However, when you have
        a function call that you're tracing through, you want to keep both
        versions. "return" events are ignored.
        """
        # Skip frames outside the visible paths (e.g. library code) to avoid tracing deep into
        # imports. Returning None disables line tracing inside that frame entirely, which keeps
        # tracing fast; calls back into visible code still reach this global trace function.
        if frame.f_code.co_filename not in visible_paths:
            return None

        stack = get_stack()

        if event == "return":
            return trace_func

        # The line of code about to be executed
        item = stack[-1]

        # Don't step into comprehensions since they're redundant and just stay on the line
        if item.function_name in comprehension_names:
            return trace_func

        # Handle @stepover (don't recurse)
        directives = parse_directives(item.code)
        if any(directive.name == DIRECTIVE_STEPOVER for directive in directives):
            if stepovers and stepovers[-1] == (item.path, item.line_number):
                # Stop skipping since we're back to this line
                stepovers.pop()
            else:
                # Just starting to skip starting here
                stepovers.append((item.path, item.line_number))

        # Skip everything that is strictly under a stepover location
        caller_locations = {(caller.path, caller.line_number) for caller in stack[:-1]}
        if any(stepover in caller_locations for stepover in stepovers):
            return trace_func

        print(f" [{len(steps)} {os.path.basename(item.path)}:{item.line_number}] {item.code}")

        open_step = Step(stack=stack, env={})
        if not steps or open_step.stack != steps[-1].stack:  # Only add a step if it's not redundant
            steps.append(open_step)
        open_step_index = len(steps) - 1

        def local_trace_func(frame, event, arg):
            """This is called *after* a line of code has been executed."""
            # If the last step is still the one for this line, then just use the same one.
            # Otherwise, create a new step (e.g., returning from a function we traced into).
            if open_step_index == len(steps) - 1:
                close_step = steps[-1]
            else:
                print(f" [{len(steps)} {os.path.basename(item.path)}:{item.line_number}] {item.code}")
                close_step = Step(stack=stack, env={})
                steps.append(close_step)

            # Update the environment with the actual values
            frame_locals = frame.f_locals
            if inspect_all_variables:
                variables = frame_locals.keys()
            else:
                variables = get_inspect_variables(directives)
            for var in variables:
                if var in frame_locals:
                    close_step.env[var] = to_serializable_value(frame_locals[var])
                else:
                    print(f"WARNING: variable {var} not found in locals")
                print(f" env: {var} = {close_step.env.get(var)}")

            for var in get_clear_variables(directives):
                if var in frame_locals:
                    close_step.env[var] = None

            # Capture the renderings of the last line
            close_step.renderings = pop_renderings()

            # Pass control back to the global trace function (for the next line)
            return trace_func(frame, event, arg)

        # Pass control to local_trace_func to update the environment
        return local_trace_func

    # Run the module
    module = importlib.import_module(module_name)
    visible_paths.append(inspect.getfile(module))
    num_outer_frames = len(traceback.extract_stack())
    previous_trace = sys.gettrace()
    sys.settrace(trace_func)
    try:
        module.main()
    finally:
        # Always stop tracing (restoring whatever tracer was active before, usually none),
        # even if the traced program raises.
        sys.settrace(previous_trace)

    # Python source files are UTF-8 by default; close each file promptly.
    files = {}
    for path in visible_paths:
        with open(path, encoding="utf-8") as f:
            files[relativize(path)] = f.read()
    hidden_line_numbers = compute_hidden_line_numbers(files)
    return Trace(steps=steps, files=files, hidden_line_numbers=hidden_line_numbers)
343
+
344
+
345
def compute_hidden_line_numbers(files: dict[str, str]) -> dict[str, list[int]]:
    """
    For each file, list the 1-based numbers of the lines whose comment carries the hide directive.

    Every path in `files` appears in the result, with an empty list if nothing is hidden.
    """
    def is_hidden(line: str) -> bool:
        return any(directive.name == DIRECTIVE_HIDE for directive in parse_directives(line))

    return {
        path: [number for number, line in enumerate(contents.split("\n"), start=1) if is_hidden(line)]
        for path, contents in files.items()
    }
356
+
357
+
358
+
359
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required: without at least one module there is nothing to execute.
    parser.add_argument("-m", "--module", help="List of modules to execute (e.g., lecture_01)", type=str, nargs="+", required=True)
    parser.add_argument("-o", "--output_path", help="Path to save the trace", type=str, default="var/traces")
    parser.add_argument("-I", "--inspect-all-variables", help="Inspect all variables (default: only inspect variables mentioned in @inspect comments)", action="store_true")
    args = parser.parse_args()

    ensure_directory_exists(args.output_path)

    for module_arg in args.module:
        # Accept "lecture_01.py" as well as "lecture_01"; only strip a trailing extension.
        module_name = module_arg.removesuffix(".py")
        print(f"Executing {module_name}...")
        trace = execute(module_name=module_name, inspect_all_variables=args.inspect_all_variables)
        print(f"{len(trace.steps)} steps")
        output_path = os.path.join(args.output_path, f"{module_name}.json")
        print(f"Saving trace to {output_path}...")
        with open(output_path, "w") as f:
            json.dump(asdict(trace), f, indent=2)
@@ -0,0 +1,138 @@
1
+ """
2
+ Functions such as text, image, link, plot, and note populate the list of renderings,
3
+ which will be shown in place of the line of code in the interface.
4
+ """
5
+
6
+ import os
7
+ import inspect
8
+ import re
9
+ import subprocess
10
+ from file_util import cached, relativize
11
+ from dataclasses import dataclass
12
+ from arxiv_util import is_arxiv_link, arxiv_reference
13
+ from reference import Reference
14
+
15
@dataclass(frozen=True)
class CodeLocation:
    """Points at one specific line of source code (the target of an internal link)."""

    # Path of the file containing the code.
    path: str
    # Line number within that file.
    line_number: int
+
21
+
22
@dataclass(frozen=True)
class Rendering:
    """
    Specifies what to display instead of a line of code. Types:
    - markdown: text to be rendered as markdown (`data` = the text)
    - image: an image (`data` = path of the image file, downloaded first if given as a URL)
    - link: a link to internal code (`internal_link`) or to an external reference (`external_link`)
    - plot: a plot (`data` = the plot specification)
    - note: a note (`data` = the message)
    """

    type: str
    # Payload whose meaning depends on `type` (see above); not always a string (e.g. plot specs).
    data: object = None
    # Display style options (e.g. {"width": 200}).
    style: dict | None = None
    # Target of an external link.
    external_link: Reference | None = None
    # Target of an internal link to code.
    internal_link: CodeLocation | None = None
36
+
37
+ ############################################################
38
+
39
def text(message: str, style: dict | None = None, verbatim: bool = False):
    """
    Render `message` as markdown in place of the current line.

    With `verbatim=True`, every line of `message` becomes its own rendering, shown in a
    monospace font with whitespace preserved; entries in `style` override those defaults.
    """
    style = style or {}
    if not verbatim:
        _current_renderings.append(Rendering(type="markdown", data=message, style=style))
        return

    verbatim_style = {"fontFamily": "monospace", "whiteSpace": "pre", **style}
    _current_renderings.extend(
        Rendering(type="markdown", data=line, style=verbatim_style)
        for line in message.split("\n")
    )
54
+
55
+
56
def image(url: str, style: dict | None = None, width: int | str | None = None):
    """
    Show the image at `url`, which is either an http(s) URL (downloaded and cached) or a local path.

    `width`, if given, is added to the style. The caller's `style` dict is copied, never modified.

    Raises:
        ValueError: if the image file does not exist.
    """
    # Copy so that a style dict shared between calls isn't mutated by `width`.
    style = dict(style) if style else {}
    if width is not None:
        style["width"] = width

    path = cached(url, "image") if is_url(url) else url
    if not os.path.exists(path):
        raise ValueError(f"Image not found: {path}")

    _current_renderings.append(Rendering(type="image", data=path, style=style))
70
+
71
+
72
def is_url(url: str) -> bool:
    """Check if `url` looks like an HTTP(S) URL rather than a local file path."""
    # Require the full scheme so local files such as "http_diagram.png" aren't treated as URLs.
    return url.startswith(("http://", "https://"))
75
+
76
+
77
def link(arg: type | Reference | str | None = None, style: dict | None = None, **kwargs):
    """
    Show a link. There are four ways to call it:
    1. link(title="...", url="...")  builds a new `Reference` from the keyword arguments
    2. link(arg: Reference)          shows an existing reference
    3. link(arg: type or function)   links to the definition of `arg` in the source code
    4. link(arg: str)                builds a reference for the URL `arg` (arXiv URLs get their metadata)

    Raises:
        ValueError: if `arg` is none of the above.
    """
    style = style or {}

    if arg is None:
        rendering = Rendering(type="link", style=style, external_link=Reference(**kwargs))
    elif isinstance(arg, Reference):
        rendering = Rendering(type="link", style=style, external_link=arg)
    elif isinstance(arg, type) or callable(arg):
        source_path = inspect.getfile(arg)
        _, first_line = inspect.getsourcelines(arg)
        target = CodeLocation(relativize(source_path), first_line)
        rendering = Rendering(type="link", data=arg.__name__, style=style, internal_link=target)
    elif isinstance(arg, str):
        if is_arxiv_link(arg):
            reference = arxiv_reference(arg, **kwargs)
        else:
            reference = Reference(url=arg, **kwargs)
        rendering = Rendering(type="link", style=style, external_link=reference)
    else:
        raise ValueError(f"Invalid argument: {arg}")

    _current_renderings.append(rendering)
106
+
107
+
108
def plot(spec: object):
    """Show a plot given `spec` (stored unchanged as the rendering's data)."""
    _current_renderings.append(Rendering(type="plot", data=spec))
111
+
112
+
113
def note(message: str):
    """Attach `message` as a note rendering to the current line."""
    rendering = Rendering(type="note", data=message)
    _current_renderings.append(rendering)
116
+
117
+
118
+ ############################################################
119
+
120
# Renderings produced while the current line runs; the tracer drains them after each line
# through pop_renderings().
_current_renderings: list[Rendering] = []

def pop_renderings() -> list[Rendering]:
    """Hand back a copy of everything rendered so far and empty the buffer in place."""
    renderings = list(_current_renderings)
    del _current_renderings[:]
    return renderings
128
+
129
+
130
def system_text(command: list[str]):
    """
    Run `command` (an argument list, executed without a shell) and show its output verbatim.

    ANSI escape codes are stripped, and trailing newlines are dropped so that the output's
    final newline doesn't render as an extra empty line.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    output = subprocess.check_output(command).decode('utf-8')
    output = remove_ansi_escape_sequences(output)
    text(output.rstrip("\n"), verbatim=True)
134
+
135
+
136
# Any ANSI CSI escape sequence: ESC "[", parameter bytes (0x30-0x3F), intermediate bytes
# (0x20-0x2F) and a final byte (0x40-0x7E). Covers colors/styles (e.g. "\x1b[31m") as well as
# cursor-movement and erase commands (e.g. "\x1b[2K"). Compiled once at import time.
_ANSI_CSI_PATTERN = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")


def remove_ansi_escape_sequences(text):
    """Return `text` with ANSI CSI escape sequences (colors, cursor movement, erasing) removed."""
    return _ANSI_CSI_PATTERN.sub('', text)
@@ -0,0 +1,43 @@
1
+ import os
2
+ import re
3
+ import hashlib
4
+ import shutil
5
+ from io import BytesIO
6
+ import requests
7
+
8
+
9
def ensure_directory_exists(path: str):
    """Create directory at `path`, including any missing parent directories, if it doesn't already exist."""
    # makedirs handles nested paths such as "var/traces" when "var" doesn't exist yet;
    # exist_ok avoids a race between checking for the directory and creating it.
    os.makedirs(path, exist_ok=True)
13
+
14
+
15
def download_file(url: str, filename: str, timeout: float = 60):
    """
    Download `url` and save the contents to `filename`. Skip if `filename` already exists.

    The content is written to a temporary file and renamed into place only after a successful
    download. A failed or interrupted download therefore never leaves behind a partial file or
    an error page that would later be mistaken for a valid cache entry.

    Raises:
        requests.HTTPError: if the server responds with an error status (4xx/5xx).
        requests.RequestException: on connection problems, or if the server doesn't respond
            within `timeout` seconds.
    """
    if os.path.exists(filename):
        return

    print(f"Downloading {url} to {filename}")
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:123.0) Gecko/20100101 Firefox/123.0"
    }
    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()

    temp_filename = filename + ".part"
    with open(temp_filename, "wb") as f:
        f.write(response.content)
    os.replace(temp_filename, filename)
25
+
26
+
27
# Common filesystems limit a file name to 255 bytes; stay below that, leaving headroom for
# temporary-file suffixes.
_MAX_CACHE_FILENAME_BYTES = 240


def cached(url: str, prefix: str) -> str:
    """
    Download `url` if needed and return the location of the cached file.

    Files live in "var/files" and are named "<prefix>-<md5 of url>-<url with runs of unsafe
    characters replaced by '_'>". The readable tail is truncated when the name would be too
    long for the filesystem; the hash keeps names unique.
    """
    name = re.sub(r"[^\w_-]+", "_", url)
    url_hash = hashlib.md5(url.encode('utf-8')).hexdigest()

    filename = prefix + "-" + url_hash + "-" + name
    encoded = filename.encode("utf-8")
    if len(encoded) > _MAX_CACHE_FILENAME_BYTES:
        # Cut at the byte limit and drop any multi-byte character split by the cut.
        filename = encoded[:_MAX_CACHE_FILENAME_BYTES].decode("utf-8", errors="ignore")

    ensure_directory_exists("var/files")
    path = os.path.join("var/files", filename)
    download_file(url, path)
    return path
36
+
37
+
38
def relativize(path: str) -> str:
    """Express `path` relative to the current working directory."""
    cwd = os.getcwd()
    return os.path.relpath(path, start=cwd)
43
+
@@ -0,0 +1,15 @@
1
+ from dataclasses import dataclass
2
+
3
+ @dataclass(frozen=True)
4
+ class Reference:
5
+ title: str | None = None
6
+ authors: list[str] | None = None
7
+ organization: str | None = None
8
+ date: str | None = None
9
+ url: str | None = None
10
+ description: str | None = None
11
+ notes: str | None = None
12
+
13
+
14
def join(*lines: str) -> str:
    """Join the given lines into one string, separated by newlines."""
    return "\n".join(lines)
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.4
2
+ Name: edtrace
3
+ Version: 0.1.0
4
+ Summary: Library for tracing through Python programs for educational presentation
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: torch
8
+ Requires-Dist: sympy
9
+ Requires-Dist: numpy
10
+ Requires-Dist: requests
11
+ Requires-Dist: altair>=5.5.0
@@ -0,0 +1,12 @@
1
+ pyproject.toml
2
+ src/edtrace/arxiv_util.py
3
+ src/edtrace/examples.py
4
+ src/edtrace/execute.py
5
+ src/edtrace/execute_util.py
6
+ src/edtrace/file_util.py
7
+ src/edtrace/reference.py
8
+ src/edtrace.egg-info/PKG-INFO
9
+ src/edtrace.egg-info/SOURCES.txt
10
+ src/edtrace.egg-info/dependency_links.txt
11
+ src/edtrace.egg-info/requires.txt
12
+ src/edtrace.egg-info/top_level.txt
@@ -0,0 +1,5 @@
1
+ torch
2
+ sympy
3
+ numpy
4
+ requests
5
+ altair>=5.5.0
@@ -0,0 +1 @@
1
+ edtrace