jupyterlab-codex-sidebar 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. package/.claude/settings.local.json +9 -0
  2. package/.github/workflows/unit-tests.yml +27 -0
  3. package/AGENTS.md +42 -0
  4. package/README.md +67 -9
  5. package/docs/images/codex-sidebar-screenshot.png +0 -0
  6. package/jupyterlab_codex/handlers.py +938 -297
  7. package/jupyterlab_codex/labextension/package.json +13 -3
  8. package/jupyterlab_codex/labextension/static/525.224526d045c727069de6.js +2 -0
  9. package/jupyterlab_codex/labextension/static/855.d20f6158cd81bb4c9056.js +1 -0
  10. package/jupyterlab_codex/labextension/static/{remoteEntry.b2fdc03a1c4582e79156.js → remoteEntry.c1e865f207776f7f24ff.js} +1 -1
  11. package/jupyterlab_codex/protocol.py +297 -0
  12. package/jupyterlab_codex/runner.py +137 -31
  13. package/jupyterlab_codex/sessions.py +582 -97
  14. package/lib/codexChat.d.ts +13 -0
  15. package/lib/codexChat.js +2410 -0
  16. package/lib/codexChat.js.map +1 -0
  17. package/lib/codexChatAttachmentDedup.d.ts +10 -0
  18. package/lib/codexChatAttachmentDedup.js +35 -0
  19. package/lib/codexChatAttachmentDedup.js.map +1 -0
  20. package/lib/codexChatAttachmentLimit.d.ts +8 -0
  21. package/lib/codexChatAttachmentLimit.js +61 -0
  22. package/lib/codexChatAttachmentLimit.js.map +1 -0
  23. package/lib/codexChatDocumentUtils.d.ts +68 -0
  24. package/lib/codexChatDocumentUtils.js +480 -0
  25. package/lib/codexChatDocumentUtils.js.map +1 -0
  26. package/lib/codexChatFormatting.d.ts +11 -0
  27. package/lib/codexChatFormatting.js +83 -0
  28. package/lib/codexChatFormatting.js.map +1 -0
  29. package/lib/codexChatNotice.d.ts +3 -0
  30. package/lib/codexChatNotice.js +74 -0
  31. package/lib/codexChatNotice.js.map +1 -0
  32. package/lib/codexChatPersistence.d.ts +35 -0
  33. package/lib/codexChatPersistence.js +158 -0
  34. package/lib/codexChatPersistence.js.map +1 -0
  35. package/lib/codexChatPrimitives.d.ts +41 -0
  36. package/lib/codexChatPrimitives.js +152 -0
  37. package/lib/codexChatPrimitives.js.map +1 -0
  38. package/lib/codexChatRender.d.ts +24 -0
  39. package/lib/codexChatRender.js +293 -0
  40. package/lib/codexChatRender.js.map +1 -0
  41. package/lib/codexChatSessionFactory.d.ts +15 -0
  42. package/lib/codexChatSessionFactory.js +45 -0
  43. package/lib/codexChatSessionFactory.js.map +1 -0
  44. package/lib/codexChatSessionKey.d.ts +3 -0
  45. package/lib/codexChatSessionKey.js +14 -0
  46. package/lib/codexChatSessionKey.js.map +1 -0
  47. package/lib/codexChatStorage.d.ts +4 -0
  48. package/lib/codexChatStorage.js +37 -0
  49. package/lib/codexChatStorage.js.map +1 -0
  50. package/lib/codexSessionResolver.d.ts +12 -0
  51. package/lib/codexSessionResolver.js +38 -0
  52. package/lib/codexSessionResolver.js.map +1 -0
  53. package/lib/handlers/activitySummarizer.d.ts +15 -0
  54. package/lib/handlers/activitySummarizer.js +327 -0
  55. package/lib/handlers/activitySummarizer.js.map +1 -0
  56. package/lib/handlers/codexMessageTypes.d.ts +30 -0
  57. package/lib/handlers/codexMessageTypes.js +2 -0
  58. package/lib/handlers/codexMessageTypes.js.map +1 -0
  59. package/lib/handlers/codexMessageUtils.d.ts +46 -0
  60. package/lib/handlers/codexMessageUtils.js +144 -0
  61. package/lib/handlers/codexMessageUtils.js.map +1 -0
  62. package/lib/handlers/handleCodexSocketMessage.d.ts +107 -0
  63. package/lib/handlers/handleCodexSocketMessage.js +78 -0
  64. package/lib/handlers/handleCodexSocketMessage.js.map +1 -0
  65. package/lib/handlers/sessionSyncHandler.d.ts +34 -0
  66. package/lib/handlers/sessionSyncHandler.js +181 -0
  67. package/lib/handlers/sessionSyncHandler.js.map +1 -0
  68. package/lib/hooks/useCodexSocket.d.ts +15 -0
  69. package/lib/hooks/useCodexSocket.js +84 -0
  70. package/lib/hooks/useCodexSocket.js.map +1 -0
  71. package/lib/index.js +1 -1
  72. package/lib/index.js.map +1 -1
  73. package/lib/panel.d.ts +1 -11
  74. package/lib/panel.js +1 -2768
  75. package/lib/panel.js.map +1 -1
  76. package/lib/protocol.d.ts +235 -0
  77. package/lib/protocol.js +278 -0
  78. package/lib/protocol.js.map +1 -0
  79. package/package.json +13 -3
  80. package/playwright.config.cjs +24 -0
  81. package/playwright.unit.config.cjs +19 -0
  82. package/pyproject.toml +1 -1
  83. package/release.sh +243 -0
  84. package/scripts/run_playwright_e2e.sh +96 -0
  85. package/scripts/run_playwright_freeze_repro.sh +58 -0
  86. package/scripts/run_playwright_queue_repro.sh +60 -0
  87. package/scripts/run_playwright_repro.sh +55 -0
  88. package/src/codexChat.tsx +3755 -0
  89. package/src/codexChatAttachmentDedup.ts +47 -0
  90. package/src/codexChatAttachmentLimit.ts +82 -0
  91. package/src/codexChatDocumentUtils.ts +612 -0
  92. package/src/codexChatFormatting.ts +94 -0
  93. package/src/codexChatNotice.ts +95 -0
  94. package/src/codexChatPersistence.ts +191 -0
  95. package/src/codexChatPrimitives.tsx +422 -0
  96. package/src/codexChatRender.tsx +376 -0
  97. package/src/codexChatSessionFactory.ts +79 -0
  98. package/src/codexChatSessionKey.ts +16 -0
  99. package/src/codexChatStorage.ts +36 -0
  100. package/src/codexSessionResolver.ts +56 -0
  101. package/src/handlers/activitySummarizer.ts +369 -0
  102. package/src/handlers/codexMessageTypes.ts +34 -0
  103. package/src/handlers/codexMessageUtils.ts +217 -0
  104. package/src/handlers/handleCodexSocketMessage.ts +204 -0
  105. package/src/handlers/sessionSyncHandler.ts +308 -0
  106. package/src/hooks/useCodexSocket.ts +109 -0
  107. package/src/index.ts +1 -1
  108. package/src/panel.tsx +1 -4131
  109. package/src/protocol.ts +582 -0
  110. package/style/index.css +424 -11
  111. package/tests/e2e/fixtures/notebooks/tab1.ipynb +322 -0
  112. package/tests/e2e/fixtures/notebooks/tab1.py +272 -0
  113. package/tests/e2e/fixtures/notebooks/tab2.ipynb +252 -0
  114. package/tests/e2e/fixtures/notebooks/tab2.py +231 -0
  115. package/tests/e2e/fixtures/notebooks/tab3.ipynb +403 -0
  116. package/tests/e2e/fixtures/notebooks/tab3.py +331 -0
  117. package/tests/e2e/fixtures/notebooks/tab4.py +339 -0
  118. package/tests/e2e/freeze-notebook-tabs-repro.spec.js +295 -0
  119. package/tests/e2e/mock-codex-cli-flood.py +127 -0
  120. package/tests/e2e/mock-codex-cli.py +95 -0
  121. package/tests/e2e/queue-multitab-repro.spec.js +189 -0
  122. package/tests/test_handlers.py +116 -0
  123. package/tests/test_protocol.py +169 -0
  124. package/tests/test_session_store_limits.py +50 -0
  125. package/tests/unit/codexChatAttachmentDedup.spec.ts +56 -0
  126. package/tests/unit/codexChatAttachmentLimit.spec.ts +42 -0
  127. package/tests/unit/codexChatLimit.spec.ts +18 -0
  128. package/tests/unit/codexChatNotice.spec.ts +45 -0
  129. package/tests/unit/codexChatPersistence.spec.ts +199 -0
  130. package/tests/unit/codexChatSessionFactory.spec.ts +94 -0
  131. package/tests/unit/codexChatSessionKey.spec.ts +18 -0
  132. package/tests/unit/codexMessageUtils.spec.ts +89 -0
  133. package/tests/unit/codexSessionResolver.spec.ts +92 -0
  134. package/tests/unit/handleCodexSocketMessage.spec.ts +476 -0
  135. package/tsconfig.tsbuildinfo +1 -1
  136. package/webpack.config.js +6 -0
  137. package/jupyterlab_codex/labextension/static/504.335f3447c84ba3d74517.js +0 -2
  138. package/jupyterlab_codex/labextension/static/972.d43137b7438a053eeb72.js +0 -1
  139. /package/jupyterlab_codex/labextension/static/{504.335f3447c84ba3d74517.js.LICENSE.txt → 525.224526d045c727069de6.js.LICENSE.txt} +0 -0
@@ -0,0 +1,331 @@
1
+ # ---
2
+ # jupyter:
3
+ # jupytext:
4
+ # formats: ipynb,py:percent
5
+ # text_representation:
6
+ # extension: .py
7
+ # format_name: percent
8
+ # format_version: '1.3'
9
+ # jupytext_version: 1.19.1
10
+ # kernelspec:
11
+ # display_name: Python 3 (ipykernel)
12
+ # language: python
13
+ # name: python3
14
+ # ---
15
+
16
+ # %%
17
+ import numpy as np
18
+ import matplotlib.pyplot as plt
19
+
20
+ # %%
21
+ # 1D time-dependent Schrodinger equation simulation (hbar = m = 1).
22
+ # Method: split-operator (FFT), which keeps time evolution numerically stable.
23
+
24
def split_operator_step(psi_state, phase_v_half, phase_t):
    """Advance a wavefunction by one split-operator time step.

    The update is ``V/2 -> T -> V/2``: half a potential phase kick in
    position space, a full kinetic phase in momentum space (reached via
    ``np.fft.fft``), back to position space with ``np.fft.ifft``, then the
    second half potential kick.

    Args:
        psi_state: Wavefunction samples on the position grid.
        phase_v_half: Position-space phase factor for half a potential step.
        phase_t: Momentum-space phase factor for a full kinetic step, in the
            same ordering as the ``np.fft.fft`` output.

    Returns:
        The evolved wavefunction on the position grid.
    """
    kicked = phase_v_half * psi_state
    momentum = np.fft.fft(kicked)
    momentum *= phase_t
    drifted = np.fft.ifft(momentum)
    return phase_v_half * drifted
32
+
33
+
34
+ # %%
35
+ # Simulation helpers
36
+ def _validate_simulation_params(
37
+ hbar, mass, grid_size, x_min, x_max, dt, steps, save_every, sigma, barrier_width, barrier_region
38
+ ):
39
+ if grid_size <= 0:
40
+ raise ValueError("grid_size must be positive.")
41
+ if x_max <= x_min:
42
+ raise ValueError("x_max must be greater than x_min.")
43
+ if dt <= 0:
44
+ raise ValueError("dt must be positive.")
45
+ if steps < 0:
46
+ raise ValueError("steps must be non-negative.")
47
+ if save_every <= 0:
48
+ raise ValueError("save_every must be positive.")
49
+ if mass <= 0:
50
+ raise ValueError("mass must be positive.")
51
+ if hbar <= 0:
52
+ raise ValueError("hbar must be positive.")
53
+ if sigma <= 0:
54
+ raise ValueError("sigma must be positive.")
55
+ if barrier_width <= 0:
56
+ raise ValueError("barrier_width must be positive.")
57
+ if barrier_region < 0:
58
+ raise ValueError("barrier_region must be non-negative.")
59
+
60
+
61
+ def _initialize_system(
62
+ hbar, mass, grid_size, x_min, x_max, dt, barrier_height, barrier_width, x0, sigma, k0
63
+ ):
64
+ x = np.linspace(x_min, x_max, grid_size, endpoint=False)
65
+ dx = x[1] - x[0]
66
+ k = 2.0 * np.pi * np.fft.fftfreq(grid_size, d=dx)
67
+
68
+ V = barrier_height * np.exp(-(x / barrier_width) ** 2)
69
+ psi = np.exp(-((x - x0) ** 2) / (2.0 * sigma**2)) * np.exp(1j * k0 * x)
70
+ psi /= np.sqrt(np.sum(np.abs(psi) ** 2) * dx)
71
+
72
+ phase_v_half = np.exp(-1j * V * dt / (2.0 * hbar))
73
+ T_k = (hbar**2) * (k**2) / (2.0 * mass)
74
+ phase_t = np.exp(-1j * T_k * dt / hbar)
75
+ return x, dx, k, V, psi, phase_v_half, phase_t
76
+
77
+
78
+ def _record_observables(psi, x, dx, k, hbar):
79
+ density = np.abs(psi) ** 2
80
+ norm = np.sum(density) * dx
81
+
82
+ x_mean = np.sum(x * density) * dx
83
+ x2_mean = np.sum((x**2) * density) * dx
84
+ psi_k = np.fft.fft(psi)
85
+ p_psi = np.fft.ifft(hbar * k * psi_k)
86
+ p2_psi = np.fft.ifft((hbar * k) ** 2 * psi_k)
87
+ p_mean = np.real(np.sum(np.conj(psi) * p_psi) * dx)
88
+ p2_mean = np.real(np.sum(np.conj(psi) * p2_psi) * dx)
89
+
90
+ x_var = max(x2_mean - x_mean**2, 0.0)
91
+ p_var = max(p2_mean - p_mean**2, 0.0)
92
+ uncertainty = np.sqrt(x_var) * np.sqrt(p_var)
93
+ return density, norm, x_mean, p_mean, uncertainty
94
+
95
+
96
+ def _compute_diagnostics(final_density, x, dx, barrier_region, uncertainty_history):
97
+ reflection = np.sum(final_density[x < -barrier_region]) * dx
98
+ transmission = np.sum(final_density[x > barrier_region]) * dx
99
+ near_barrier = np.sum(final_density[np.abs(x) <= barrier_region]) * dx
100
+
101
+ return {
102
+ "final_norm": np.sum(final_density) * dx,
103
+ "reflection": reflection,
104
+ "transmission": transmission,
105
+ "near_barrier": near_barrier,
106
+ "probability_sum": reflection + transmission + near_barrier,
107
+ "min_uncertainty": float(np.min(uncertainty_history)),
108
+ }
109
+
110
+
111
+ # %%
112
+ # Main simulation
113
def simulate_1d_quantum_scattering(
    hbar=1.0,
    mass=1.0,
    grid_size=2048,
    x_min=-100.0,
    x_max=100.0,
    dt=0.05,
    steps=900,
    save_every=90,
    barrier_height=1.6,
    barrier_width=2.2,
    x0=-35.0,
    sigma=3.5,
    k0=1.6,
    barrier_region=6.0,
):
    """Evolve a Gaussian packet against a Gaussian barrier for exactly ``steps`` steps.

    The state is recorded at ``t = 0``, every ``save_every`` steps, and at
    the final step ``t = steps * dt``. The diagnostics are computed from
    that same final state, so the last snapshot, the last entry of
    ``times`` and ``diagnostics`` describe one instant.

    Returns:
        dict with:

        - the grid ``x`` and potential ``V``;
        - per-record lists ``snapshots``, ``times``, ``norm_history``,
          ``x_mean_history``, ``p_mean_history`` and ``uncertainty_history``;
        - the ``diagnostics`` dict built by ``_compute_diagnostics``.

    Raises:
        ValueError: if a parameter is rejected by ``_validate_simulation_params``.
    """
    _validate_simulation_params(
        hbar,
        mass,
        grid_size,
        x_min,
        x_max,
        dt,
        steps,
        save_every,
        sigma,
        barrier_width,
        barrier_region,
    )
    x, dx, k, V, psi, phase_v_half, phase_t = _initialize_system(
        hbar, mass, grid_size, x_min, x_max, dt, barrier_height, barrier_width, x0, sigma, k0
    )

    snapshots = []
    times = []
    norm_history = []
    x_mean_history = []
    p_mean_history = []
    uncertainty_history = []

    for n in range(steps + 1):
        # Record on the save cadence and always at the last step, so the final
        # state appears in the history even when steps % save_every != 0.
        if n % save_every == 0 or n == steps:
            density, norm, x_mean, p_mean, uncertainty = _record_observables(psi, x, dx, k, hbar)
            snapshots.append(density.copy())
            times.append(n * dt)
            norm_history.append(norm)
            x_mean_history.append(x_mean)
            p_mean_history.append(p_mean)
            uncertainty_history.append(uncertainty)

        # Apply exactly `steps` updates; stepping after the final record would
        # push psi one dt past the last recorded time.
        if n < steps:
            psi = split_operator_step(psi, phase_v_half, phase_t)

    final_density = np.abs(psi) ** 2
    diagnostics = _compute_diagnostics(final_density, x, dx, barrier_region, uncertainty_history)

    return {
        "x": x,
        "V": V,
        "snapshots": snapshots,
        "times": times,
        "norm_history": norm_history,
        "x_mean_history": x_mean_history,
        "p_mean_history": p_mean_history,
        "uncertainty_history": uncertainty_history,
        "diagnostics": diagnostics,
    }
179
+
180
+
181
+ # %%
182
+ # Plot helpers
183
+ def _plot_density_panel(ax, x, snapshots, times, V):
184
+ for density, t in zip(snapshots, times):
185
+ ax.plot(x, density, label=f"t={t:.1f}")
186
+ vmax = np.max(V)
187
+ if vmax > 0:
188
+ scale = np.max(snapshots[0]) / vmax
189
+ ax.plot(x, V * scale, "--", linewidth=2, label="Barrier (scaled)")
190
+ ax.set_title("1D Quantum Wave Packet Scattering")
191
+ ax.set_xlabel("x")
192
+ ax.set_ylabel(r"Probability density $|\psi|^2$")
193
+ ax.legend()
194
+ ax.grid(alpha=0.25)
195
+
196
+
197
+ def _plot_norm_panel(ax, times, norm_history):
198
+ ax.plot(times, norm_history, marker="o")
199
+ ax.set_title("Normalization Conservation Check")
200
+ ax.set_xlabel("time")
201
+ ax.set_ylabel("Integral |psi|^2 dx")
202
+ ax.grid(alpha=0.25)
203
+
204
+
205
+ def _plot_expectation_panel(ax, times, x_mean_history, p_mean_history):
206
+ ax.plot(times, x_mean_history, marker="o", label="<x>")
207
+ ax.plot(times, p_mean_history, marker="s", label="<p>")
208
+ ax.set_title("Expectation Values Over Time")
209
+ ax.set_xlabel("time")
210
+ ax.set_ylabel("value")
211
+ ax.legend()
212
+ ax.grid(alpha=0.25)
213
+
214
+
215
+ def _plot_uncertainty_panel(ax, times, uncertainty_history, hbar):
216
+ ax.plot(times, uncertainty_history, marker="^", label="DxDp")
217
+ ax.axhline(0.5 * hbar, linestyle="--", color="red", label="hbar/2")
218
+ ax.set_title("Uncertainty Principle Check")
219
+ ax.set_xlabel("time")
220
+ ax.set_ylabel("DxDp")
221
+ ax.legend()
222
+ ax.grid(alpha=0.25)
223
+
224
+
225
+ # %%
226
def plot_density_3d(result):
    """Render all stored density snapshots as a surface over (x, time) and show it.

    ``result`` is the dict built by ``simulate_1d_quantum_scattering``. It
    uses its ``x``, ``times`` and ``snapshots`` entries, with one density row
    per recorded time.
    """
    time_axis = np.array(result["times"])
    density_grid = np.array(result["snapshots"])
    X, T = np.meshgrid(result["x"], time_axis)

    fig = plt.figure(figsize=(10, 6))
    surface_ax = fig.add_subplot(111, projection="3d")
    surface_ax.plot_surface(X, T, density_grid, cmap="viridis", linewidth=0, antialiased=True)
    surface_ax.set_title("Probability Density Surface")
    surface_ax.set_xlabel("x")
    surface_ax.set_ylabel("time")
    surface_ax.set_zlabel(r"$|\psi|^2$")
    plt.tight_layout()
    plt.show()
241
+
242
+
243
+ # %%
244
+ # Plot entrypoint
245
def plot_results(result, hbar=1.0):
    """Show the four-panel summary figure, then the 3D density surface.

    Panels, top to bottom:

    1. density snapshots with the scaled barrier;
    2. the norm history;
    3. the <x> and <p> histories;
    4. DxDp against hbar/2.

    Args:
        result: dict returned by ``simulate_1d_quantum_scattering``.
        hbar: Only used for the hbar/2 reference line; pass the value the
            simulation ran with.
    """
    x, V = result["x"], result["V"]
    snapshots, times = result["snapshots"], result["times"]
    norm_history = result["norm_history"]
    x_mean_history = result["x_mean_history"]
    p_mean_history = result["p_mean_history"]
    uncertainty_history = result["uncertainty_history"]

    _, (density_ax, norm_ax, expectation_ax, uncertainty_ax) = plt.subplots(
        4, 1, figsize=(10, 14), sharex=False
    )
    _plot_density_panel(density_ax, x, snapshots, times, V)
    _plot_norm_panel(norm_ax, times, norm_history)
    _plot_expectation_panel(expectation_ax, times, x_mean_history, p_mean_history)
    _plot_uncertainty_panel(uncertainty_ax, times, uncertainty_history, hbar)

    plt.tight_layout()
    plt.show()
    plot_density_3d(result)
264
+
265
+
266
+ # %%
267
def validate_diagnostics(diagnostics, hbar=1.0, tol=1e-3):
    """Return validation checks for core physical constraints.

    Each entry maps a check name to a ``bool``, and ``all_passed`` is the
    conjunction of the others. All comparisons use the absolute tolerance
    ``tol``. Uncertainty is checked against the bound ``hbar / 2``.

    Args:
        diagnostics: Mapping with the keys ``final_norm``, ``reflection``,
            ``transmission``, ``near_barrier``, ``probability_sum`` and
            ``min_uncertainty``, as produced by ``_compute_diagnostics``.
    """
    bound = 0.5 * hbar
    parts = (
        diagnostics["reflection"],
        diagnostics["transmission"],
        diagnostics["near_barrier"],
    )
    values = [
        diagnostics["final_norm"],
        *parts,
        diagnostics["probability_sum"],
        diagnostics["min_uncertainty"],
        bound,
    ]

    checks = {
        "diagnostics_finite": bool(np.all(np.isfinite(values))),
        "norm_close_to_one": bool(abs(diagnostics["final_norm"] - 1.0) <= tol),
        "probability_conserved": bool(abs(diagnostics["probability_sum"] - 1.0) <= tol),
        "probability_terms_nonnegative": all(p >= -tol for p in parts),
        "probability_terms_le_one": all(p <= 1.0 + tol for p in parts),
        "uncertainty_respected": bool(diagnostics["min_uncertainty"] + tol >= bound),
    }
    checks["all_passed"] = all(checks.values())
    return checks
302
+
303
+
304
+ # %%
305
def main():
    """Run the default scattering demo: simulate, validate, report, then plot.

    Prints the headline diagnostics and each individual validation check
    before showing the figures.
    """
    hbar = 1.0
    result = simulate_1d_quantum_scattering(hbar=hbar)
    diag = result["diagnostics"]
    checks = validate_diagnostics(diag, hbar=hbar)

    print(f"Final normalization (should be ~1): {diag['final_norm']:.6f}")
    print(f"Reflection probability : {diag['reflection']:.6f}")
    print(f"Transmission probability : {diag['transmission']:.6f}")
    print(f"Near-barrier probability : {diag['near_barrier']:.6f}")
    print(f"R + T + Near-barrier : {diag['probability_sum']:.6f}")
    lower_bound = 0.5 * hbar
    print(f"Min uncertainty DxDp : {diag['min_uncertainty']:.6f} (>= {lower_bound:.3f})")
    print(f"Validation all passed : {checks['all_passed']}")

    individual = [(name, ok) for name, ok in checks.items() if name != "all_passed"]
    for name, passed in individual:
        print(f" - {name:25}: {passed}")

    plot_results(result, hbar=hbar)


if __name__ == "__main__":
    main()
@@ -0,0 +1,339 @@
1
+ import numpy as np
2
+ import argparse
3
+
4
+
5
def normalize(psi: np.ndarray, dx: float) -> np.ndarray:
    """Return ``psi`` scaled so that ``sum(|psi|^2) * dx == 1``.

    Args:
        psi: Wavefunction samples on a uniform grid.
        dx: Grid spacing.

    Returns:
        A new array; ``psi`` itself is not modified.

    Raises:
        ValueError: if the discrete norm is zero or not finite. Without this
            check the division would silently produce NaN/inf samples.
    """
    norm = np.sqrt(np.sum(np.abs(psi) ** 2) * dx)
    if not np.isfinite(norm) or norm == 0:
        raise ValueError("Cannot normalize a wavefunction with zero or non-finite norm.")
    return psi / norm
8
+
9
+
10
def run_quantum_tunneling_simulation(
    n_grid: int = 2048,
    x_min: float = -200.0,
    x_max: float = 200.0,
    dt: float = 0.05,
    n_steps: int = 2200,
    snapshot_every: int = 250,
    v0: float = 0.12,
    barrier_half_width: float = 8.0,
    x0: float = -90.0,
    sigma: float = 10.0,
    k0: float = 0.55,
):
    """
    1D time-dependent Schrodinger equation simulation using split-operator FFT.
    Units are dimensionless with hbar = m = 1.

    A Gaussian packet is propagated towards a square barrier of height ``v0``
    occupying ``|x| < barrier_half_width``. The packet is centred at ``x0``,
    its ``|psi|^2`` has standard deviation ``sigma``, and its carrier
    wavenumber is ``k0``. The grid excludes ``x_max`` and is treated as
    periodic by the FFT kinetic step.

    NOTE(review): with the default parameters the packet does not make a
    clean tunnelling demo:

    - Its mean kinetic energy ``k0**2 / 2 ~= 0.151`` exceeds ``v0 = 0.12``,
      so it passes over the barrier rather than tunnelling through it.
    - Its centre travels only about ``k0 * n_steps * dt ~= 60.5`` from
      ``x0 = -90``, so most of it is still left of the barrier at the end.
      "R" then also counts probability that never interacted with the
      barrier.

    Tune ``k0`` / ``v0`` / ``n_steps`` for a genuine tunnelling run.

    Returns:
        dict with:

        - ``x`` and ``V``;
        - ``snapshots``: a list of ``(step, |psi|^2)`` taken at step 0, every
          ``snapshot_every`` steps and the final step;
        - ``norms``: one value per snapshot;
        - ``R`` / ``T``: the final probability left / right of the barrier.

    Raises:
        ValueError: if ``n_grid < 2``, ``x_max <= x_min``,
            ``snapshot_every <= 0`` or ``sigma <= 0``.
    """
    # Validate before allocating: n_grid < 2 would break x[1] - x[0], and
    # snapshot_every == 0 would raise ZeroDivisionError inside the loop.
    if n_grid < 2:
        raise ValueError("n_grid must be at least 2.")
    if x_max <= x_min:
        raise ValueError("x_max must be greater than x_min.")
    if snapshot_every <= 0:
        raise ValueError("snapshot_every must be positive.")
    if sigma <= 0:
        raise ValueError("sigma must be positive.")

    hbar = 1.0
    mass = 1.0

    x = np.linspace(x_min, x_max, n_grid, endpoint=False)
    dx = x[1] - x[0]
    k = 2.0 * np.pi * np.fft.fftfreq(n_grid, d=dx)

    # Square potential barrier in the middle (for tunneling demo)
    v = np.where(np.abs(x) < barrier_half_width, v0, 0.0)

    # Initial Gaussian wave packet moving right; the 4*sigma^2 denominator
    # makes sigma the standard deviation of |psi|^2.
    psi = np.exp(-((x - x0) ** 2) / (4.0 * sigma**2)) * np.exp(1j * k0 * x)
    psi = normalize(psi, dx)

    # Split-operator evolution factors: half potential step, full kinetic step.
    exp_v_half = np.exp(-1j * v * dt / (2.0 * hbar))
    exp_t = np.exp(-1j * (hbar * k**2) * dt / (2.0 * mass))

    snapshots = [(0, np.abs(psi) ** 2)]
    norms = [np.sum(np.abs(psi) ** 2) * dx]

    for step in range(1, n_steps + 1):
        psi *= exp_v_half
        psi_k = np.fft.fft(psi)
        psi_k *= exp_t
        psi = np.fft.ifft(psi_k)
        psi *= exp_v_half

        if step % snapshot_every == 0 or step == n_steps:
            snapshots.append((step, np.abs(psi) ** 2))
            norms.append(np.sum(np.abs(psi) ** 2) * dx)

    prob = np.abs(psi) ** 2
    reflected = np.sum(prob[x < -barrier_half_width]) * dx
    transmitted = np.sum(prob[x > barrier_half_width]) * dx

    return {
        "x": x,
        "V": v,
        "snapshots": snapshots,
        "norms": np.array(norms),
        "R": reflected,
        "T": transmitted,
    }
71
+
72
+
73
def plot_results(result: dict):
    """Plot the tunnelling density snapshots and norm history, then print R, T and R + T.

    ``result`` is the dict returned by ``run_quantum_tunneling_simulation``.
    The potential is drawn rescaled to the peak of the first snapshot. The
    ``1e-12`` in the denominator avoids dividing by zero when the potential
    is identically zero.
    """
    import matplotlib.pyplot as plt

    x = result["x"]
    v = result["V"]
    snapshots = result["snapshots"]
    norms = result["norms"]

    _, (density_ax, norm_ax) = plt.subplots(2, 1, figsize=(10, 8), constrained_layout=True)

    for step, dens in snapshots:
        density_ax.plot(x, dens, lw=1.8, label=f"step={step}")
    scaled_v = v / (np.max(v) + 1e-12) * np.max(snapshots[0][1])
    density_ax.plot(x, scaled_v, "k--", lw=2, label="scaled V(x)")
    density_ax.set_title("1D Quantum Tunneling (Probability Density)")
    density_ax.set_xlabel("x")
    density_ax.set_ylabel(r"$|\psi(x,t)|^2$")
    density_ax.legend(loc="upper right", fontsize=8)
    density_ax.grid(alpha=0.25)

    norm_ax.plot(norms, marker="o", lw=1.5)
    norm_ax.set_title("Norm Check (should stay near 1)")
    norm_ax.set_xlabel("snapshot index")
    norm_ax.set_ylabel(r"$\int |\psi|^2 dx$")
    norm_ax.grid(alpha=0.25)

    plt.show()

    print(f"Reflection probability R ≈ {result['R']:.4f}")
    print(f"Transmission probability T ≈ {result['T']:.4f}")
    print(f"R + T ≈ {result['R'] + result['T']:.4f}")
105
+
106
+
107
def solve_harmonic_oscillator(
    n_grid: int = 500,
    x_min: float = -8.0,
    x_max: float = 8.0,
    omega: float = 1.0,
    n_states: int = 4,
):
    """
    Solve stationary states of 1D quantum harmonic oscillator using finite differences.
    Units are dimensionless with hbar = m = 1.

    The Hamiltonian ``-1/2 d^2/dx^2 + omega^2 x^2 / 2`` is discretised with
    the three-point second-difference stencil. Truncating the stencil at the
    grid ends imposes ``psi = 0`` just outside ``[x_min, x_max]``. The exact
    continuum energies for comparison are ``(n + 1/2) * omega``.

    Returns:
        dict with:

        - ``x`` and ``V``;
        - ``energies``: the lowest ``n_states`` eigenvalues, ascending, as
          returned by ``np.linalg.eigh``;
        - ``states``: the matching columns, each scaled so
          ``sum(psi**2) * dx == 1``;
        - ``omega``: the frequency used, so callers can compute the exact
          spectrum.
    """
    x = np.linspace(x_min, x_max, n_grid)
    dx = x[1] - x[0]

    main_diag = -2.0 * np.ones(n_grid)
    off_diag = np.ones(n_grid - 1)
    d2 = (
        np.diag(main_diag, 0)
        + np.diag(off_diag, 1)
        + np.diag(off_diag, -1)
    ) / (dx**2)

    v = 0.5 * (omega**2) * (x**2)
    hamiltonian = -0.5 * d2 + np.diag(v)

    eigenvalues, eigenvectors = np.linalg.eigh(hamiltonian)
    eigenvalues = eigenvalues[:n_states]
    states = eigenvectors[:, :n_states]

    for i in range(states.shape[1]):
        states[:, i] = normalize(states[:, i], dx)

    return {"x": x, "V": v, "energies": eigenvalues, "states": states, "omega": omega}
140
+
141
+
142
def plot_harmonic_oscillator(result: dict):
    """Plot oscillator eigenstates offset by their energies and print numerical vs exact energies.

    The exact reference energies are ``(n + 1/2) * omega``. ``omega`` is read
    from ``result`` when present; otherwise it falls back to 1.0, the
    solver's default.
    """
    import matplotlib.pyplot as plt

    x = result["x"]
    v = result["V"]
    energies = result["energies"]
    states = result["states"]
    omega = result.get("omega", 1.0)

    # Keep the original label for the default omega; otherwise show omega^2.
    potential_label = "V(x)=x^2/2" if omega == 1.0 else f"V(x)={omega**2:g}*x^2/2"

    plt.figure(figsize=(10, 6))
    plt.plot(x, v, "k--", lw=2, label=potential_label)

    scale = 0.7  # vertical scale of each wavefunction around its energy level
    for i, energy in enumerate(energies):
        psi = states[:, i]
        plt.plot(x, scale * psi + energy, lw=1.8, label=f"n={i}, E={energy:.3f}")
        plt.hlines(energy, x[0], x[-1], colors="gray", linestyles=":", lw=0.8)

    plt.title("Quantum Harmonic Oscillator: Eigenstates and Energies")
    plt.xlabel("x")
    plt.ylabel("Energy / shifted wavefunction")
    plt.legend(loc="upper left", fontsize=9)
    plt.grid(alpha=0.25)
    plt.show()

    print("Lowest energies (numerical):")
    for i, energy in enumerate(energies):
        exact = (i + 0.5) * omega
        print(f"n={i}: E ≈ {energy:.6f} (exact: {exact:.6f})")
169
+
170
+
171
def solve_finite_square_well(
    n_grid: int = 700,
    x_min: float = -12.0,
    x_max: float = 12.0,
    well_depth: float = 8.0,
    well_width: float = 4.0,
    n_states: int = 4,
):
    """
    Solve bound states of a 1D finite square well using finite differences.
    Potential: V(x) = -well_depth for |x| <= well_width/2, else 0.

    Units are hbar = m = 1 (kinetic term -1/2 d^2/dx^2). The lowest
    ``n_states`` eigenstates with ``E < 0`` are returned. If none exist, the
    lowest ``n_states`` states of any energy are returned instead and
    ``bound_only`` is False.

    Returns:
        dict with ``x``, ``V``, ``energies``, ``states`` (columns, each
        scaled so ``sum(psi**2) * dx == 1``) and ``bound_only``.
    """
    x = np.linspace(x_min, x_max, n_grid)
    dx = x[1] - x[0]

    laplacian = (
        np.diag(-2.0 * np.ones(n_grid))
        + np.diag(np.ones(n_grid - 1), 1)
        + np.diag(np.ones(n_grid - 1), -1)
    ) / (dx**2)

    v = np.where(np.abs(x) <= (well_width / 2.0), -well_depth, 0.0)
    eigenvalues, eigenvectors = np.linalg.eigh(-0.5 * laplacian + np.diag(v))

    negative = np.flatnonzero(eigenvalues < 0.0)
    only_bound = negative.size > 0
    if only_bound:
        chosen = negative[:n_states]
    else:
        chosen = np.arange(min(n_states, len(eigenvalues)))

    energies = eigenvalues[chosen]
    states = eigenvectors[:, chosen]
    for column in range(states.shape[1]):
        states[:, column] = normalize(states[:, column], dx)

    return {
        "x": x,
        "V": v,
        "energies": energies,
        "states": states,
        "bound_only": only_bound,
    }
220
+
221
+
222
def plot_finite_square_well(result: dict):
    """Plot square-well states offset by their energies, then list the energies.

    ``result`` is the dict from ``solve_finite_square_well``. Wavefunctions
    are drawn with a vertical scale of at least 0.5, growing to 15% of the
    deepest point of the potential.
    """
    import matplotlib.pyplot as plt

    x = result["x"]
    v = result["V"]
    energies = result["energies"]
    states = result["states"]

    plt.figure(figsize=(10, 6))
    plt.plot(x, v, "k--", lw=2.0, label="V(x)")

    state_scale = max(0.5, 0.15 * np.max(np.abs(v)))
    for i, (energy, psi) in enumerate(zip(energies, states.T)):
        plt.plot(x, state_scale * psi + energy, lw=1.7, label=f"state={i}, E={energy:.3f}")
        plt.hlines(energy, x[0], x[-1], colors="gray", linestyles=":", lw=0.8)

    plt.title("Finite Square Well: Bound-State Spectrum")
    plt.xlabel("x")
    plt.ylabel("Energy / shifted wavefunction")
    plt.legend(loc="upper right", fontsize=9)
    plt.grid(alpha=0.25)
    plt.show()

    header = (
        "Bound-state energies (E < 0):"
        if result["bound_only"]
        else "No bound state found with current parameters. Showing lowest states:"
    )
    print(header)
    for i, energy in enumerate(energies):
        print(f"state={i}: E ≈ {energy:.6f}")
252
+
253
+
254
def simulate_two_level_rabi(
    omega: float = 1.0,
    detuning: float = 0.2,
    t_max: float = 30.0,
    n_steps: int = 1200,
):
    """
    Two-level system (qubit) Rabi oscillation simulation.
    Hamiltonian: H = 0.5 * [[detuning, omega], [omega, -detuning]]
    Units are dimensionless with hbar = 1.

    The system starts in |0> and is propagated exactly via the spectral
    decomposition ``H = U diag(E) U^dagger``. The initial state is expanded
    in the eigenbasis once, and all ``n_steps`` sample times in
    ``[0, t_max]`` are then evaluated in one vectorised pass.

    Returns:
        dict with sample times ``t``, populations ``P0`` / ``P1`` (float
        arrays of length ``n_steps``), and the input ``omega`` and
        ``detuning``.
    """
    times = np.linspace(0.0, t_max, n_steps)
    hamiltonian = 0.5 * np.array(
        [[detuning, omega], [omega, -detuning]],
        dtype=np.complex128,
    )

    eigvals, eigvecs = np.linalg.eigh(hamiltonian)
    psi0 = np.array([1.0 + 0.0j, 0.0 + 0.0j], dtype=np.complex128)

    # psi(t) = U diag(exp(-i E t)) U^dagger psi0. Compute U^dagger psi0 once;
    # each time then only needs a per-eigenvalue phase.
    coeffs = eigvecs.conj().T @ psi0                    # shape (2,)
    phases = np.exp(-1j * np.outer(times, eigvals))     # shape (n_steps, 2)
    amplitudes = (phases * coeffs) @ eigvecs.T          # shape (n_steps, 2)

    p0 = np.abs(amplitudes[:, 0]) ** 2
    p1 = np.abs(amplitudes[:, 1]) ** 2

    return {"t": times, "P0": p0, "P1": p1, "omega": omega, "detuning": detuning}
285
+
286
+
287
def plot_rabi(result: dict):
    """Plot the |0> and |1> populations over time, then print the parameters and peak P(|1>).

    ``result`` is the dict returned by ``simulate_two_level_rabi``.
    """
    import matplotlib.pyplot as plt

    t = result["t"]
    p0 = result["P0"]
    p1 = result["P1"]

    plt.figure(figsize=(10, 5))
    for population, label in ((p0, "P(|0>)"), (p1, "P(|1>)")):
        plt.plot(t, population, lw=2.0, label=label)
    plt.title("Two-Level Quantum Rabi Oscillation")
    plt.xlabel("time")
    plt.ylabel("probability")
    plt.ylim(-0.02, 1.02)
    plt.grid(alpha=0.25)
    plt.legend()
    plt.show()

    print(f"omega={result['omega']:.3f}, detuning={result['detuning']:.3f}")
    print(f"max P(|1>) ≈ {np.max(result['P1']):.4f}")
307
+
308
+
309
def parse_args(argv=None):
    """Parse the command-line options that select which simulation to run.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with a ``mode`` attribute. Unrecognised
        arguments are ignored rather than rejected.
    """
    parser = argparse.ArgumentParser(description="Quantum mechanics simulations in 1D.")
    parser.add_argument(
        "--mode",
        choices=["tunnel", "oscillator", "well", "rabi"],
        default="tunnel",
        help="Simulation mode: 'tunnel', 'oscillator', 'well', or 'rabi'.",
    )
    # Use parse_known_args for notebook/IPython compatibility: extra arguments
    # injected by the host process are collected and discarded.
    args, _ = parser.parse_known_args(argv)
    return args
320
+
321
+
322
def main(mode: str = "tunnel"):
    """Run the selected simulation and show its plot.

    Args:
        mode: One of ``"tunnel"``, ``"oscillator"``, ``"well"`` or ``"rabi"``,
            the same choices ``parse_args`` accepts.

    Raises:
        ValueError: for any other mode, so a typo is not silently mapped to a
            different simulation.
    """
    if mode == "tunnel":
        simulation_result = run_quantum_tunneling_simulation()
        plot_results(simulation_result)
    elif mode == "oscillator":
        oscillator_result = solve_harmonic_oscillator()
        plot_harmonic_oscillator(oscillator_result)
    elif mode == "well":
        well_result = solve_finite_square_well()
        plot_finite_square_well(well_result)
    elif mode == "rabi":
        rabi_result = simulate_two_level_rabi()
        plot_rabi(rabi_result)
    else:
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'tunnel', 'oscillator', 'well' or 'rabi'."
        )


if __name__ == "__main__":
    args = parse_args()
    main(args.mode)