jupyterlab-codex-sidebar 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. package/.claude/settings.local.json +9 -0
  2. package/.github/workflows/unit-tests.yml +27 -0
  3. package/.jupyterlab-playwright.log +0 -0
  4. package/README.md +83 -9
  5. package/docs/images/codex-sidebar-screenshot.png +0 -0
  6. package/jupyterlab_codex/handlers.py +938 -297
  7. package/jupyterlab_codex/labextension/package.json +13 -3
  8. package/jupyterlab_codex/labextension/static/525.224526d045c727069de6.js +2 -0
  9. package/jupyterlab_codex/labextension/static/737.e7de3ad9dd6ded798340.js +1 -0
  10. package/jupyterlab_codex/labextension/static/remoteEntry.6ef5e7167763a316c000.js +1 -0
  11. package/jupyterlab_codex/protocol.py +297 -0
  12. package/jupyterlab_codex/runner.py +58 -15
  13. package/jupyterlab_codex/sessions.py +582 -97
  14. package/lib/codexChat.d.ts +13 -0
  15. package/lib/codexChat.js +2506 -0
  16. package/lib/codexChat.js.map +1 -0
  17. package/lib/codexChatAttachmentDedup.d.ts +10 -0
  18. package/lib/codexChatAttachmentDedup.js +35 -0
  19. package/lib/codexChatAttachmentDedup.js.map +1 -0
  20. package/lib/codexChatAttachmentLimit.d.ts +18 -0
  21. package/lib/codexChatAttachmentLimit.js +50 -0
  22. package/lib/codexChatAttachmentLimit.js.map +1 -0
  23. package/lib/codexChatAttachmentState.d.ts +15 -0
  24. package/lib/codexChatAttachmentState.js +16 -0
  25. package/lib/codexChatAttachmentState.js.map +1 -0
  26. package/lib/codexChatDocumentUtils.d.ts +70 -0
  27. package/lib/codexChatDocumentUtils.js +506 -0
  28. package/lib/codexChatDocumentUtils.js.map +1 -0
  29. package/lib/codexChatFormatting.d.ts +11 -0
  30. package/lib/codexChatFormatting.js +83 -0
  31. package/lib/codexChatFormatting.js.map +1 -0
  32. package/lib/codexChatNotice.d.ts +3 -0
  33. package/lib/codexChatNotice.js +74 -0
  34. package/lib/codexChatNotice.js.map +1 -0
  35. package/lib/codexChatPersistence.d.ts +35 -0
  36. package/lib/codexChatPersistence.js +158 -0
  37. package/lib/codexChatPersistence.js.map +1 -0
  38. package/lib/codexChatPrimitives.d.ts +44 -0
  39. package/lib/codexChatPrimitives.js +156 -0
  40. package/lib/codexChatPrimitives.js.map +1 -0
  41. package/lib/codexChatRender.d.ts +24 -0
  42. package/lib/codexChatRender.js +293 -0
  43. package/lib/codexChatRender.js.map +1 -0
  44. package/lib/codexChatSessionFactory.d.ts +15 -0
  45. package/lib/codexChatSessionFactory.js +45 -0
  46. package/lib/codexChatSessionFactory.js.map +1 -0
  47. package/lib/codexChatSessionKey.d.ts +3 -0
  48. package/lib/codexChatSessionKey.js +14 -0
  49. package/lib/codexChatSessionKey.js.map +1 -0
  50. package/lib/codexChatStorage.d.ts +4 -0
  51. package/lib/codexChatStorage.js +37 -0
  52. package/lib/codexChatStorage.js.map +1 -0
  53. package/lib/codexSessionResolver.d.ts +12 -0
  54. package/lib/codexSessionResolver.js +38 -0
  55. package/lib/codexSessionResolver.js.map +1 -0
  56. package/lib/handlers/activitySummarizer.d.ts +15 -0
  57. package/lib/handlers/activitySummarizer.js +327 -0
  58. package/lib/handlers/activitySummarizer.js.map +1 -0
  59. package/lib/handlers/codexMessageTypes.d.ts +30 -0
  60. package/lib/handlers/codexMessageTypes.js +2 -0
  61. package/lib/handlers/codexMessageTypes.js.map +1 -0
  62. package/lib/handlers/codexMessageUtils.d.ts +46 -0
  63. package/lib/handlers/codexMessageUtils.js +144 -0
  64. package/lib/handlers/codexMessageUtils.js.map +1 -0
  65. package/lib/handlers/handleCodexSocketMessage.d.ts +107 -0
  66. package/lib/handlers/handleCodexSocketMessage.js +78 -0
  67. package/lib/handlers/handleCodexSocketMessage.js.map +1 -0
  68. package/lib/handlers/sessionSyncHandler.d.ts +34 -0
  69. package/lib/handlers/sessionSyncHandler.js +181 -0
  70. package/lib/handlers/sessionSyncHandler.js.map +1 -0
  71. package/lib/hooks/useCodexSocket.d.ts +15 -0
  72. package/lib/hooks/useCodexSocket.js +84 -0
  73. package/lib/hooks/useCodexSocket.js.map +1 -0
  74. package/lib/index.js +1 -1
  75. package/lib/index.js.map +1 -1
  76. package/lib/panel.d.ts +1 -11
  77. package/lib/panel.js +1 -2815
  78. package/lib/panel.js.map +1 -1
  79. package/lib/protocol.d.ts +235 -0
  80. package/lib/protocol.js +278 -0
  81. package/lib/protocol.js.map +1 -0
  82. package/package.json +13 -3
  83. package/playwright.config.cjs +27 -0
  84. package/playwright.unit.config.cjs +19 -0
  85. package/pyproject.toml +1 -1
  86. package/release.sh +52 -14
  87. package/scripts/run_playwright_e2e.sh +96 -0
  88. package/scripts/run_playwright_freeze_repro.sh +58 -0
  89. package/scripts/run_playwright_queue_repro.sh +60 -0
  90. package/scripts/run_playwright_repro.sh +55 -0
  91. package/src/codexChat.tsx +3914 -0
  92. package/src/codexChatAttachmentDedup.ts +47 -0
  93. package/src/codexChatAttachmentLimit.ts +81 -0
  94. package/src/codexChatAttachmentState.ts +37 -0
  95. package/src/codexChatDocumentUtils.ts +644 -0
  96. package/src/codexChatFormatting.ts +94 -0
  97. package/src/codexChatNotice.ts +95 -0
  98. package/src/codexChatPersistence.ts +191 -0
  99. package/src/codexChatPrimitives.tsx +446 -0
  100. package/src/codexChatRender.tsx +376 -0
  101. package/src/codexChatSessionFactory.ts +79 -0
  102. package/src/codexChatSessionKey.ts +16 -0
  103. package/src/codexChatStorage.ts +36 -0
  104. package/src/codexSessionResolver.ts +56 -0
  105. package/src/handlers/activitySummarizer.ts +369 -0
  106. package/src/handlers/codexMessageTypes.ts +34 -0
  107. package/src/handlers/codexMessageUtils.ts +217 -0
  108. package/src/handlers/handleCodexSocketMessage.ts +204 -0
  109. package/src/handlers/sessionSyncHandler.ts +308 -0
  110. package/src/hooks/useCodexSocket.ts +109 -0
  111. package/src/index.ts +1 -1
  112. package/src/panel.tsx +1 -4184
  113. package/src/protocol.ts +582 -0
  114. package/style/index.css +480 -11
  115. package/test-results/.last-run.json +4 -0
  116. package/test.py +0 -0
  117. package/tests/e2e/cell-output-error-tail.spec.js +156 -0
  118. package/tests/e2e/codex-ui-test-helpers.js +138 -0
  119. package/tests/e2e/fixtures/notebooks/error-output-tail.ipynb +58 -0
  120. package/tests/e2e/fixtures/notebooks/error-output-tail.py +19 -0
  121. package/tests/e2e/fixtures/notebooks/tab1.ipynb +322 -0
  122. package/tests/e2e/fixtures/notebooks/tab1.py +272 -0
  123. package/tests/e2e/fixtures/notebooks/tab2.ipynb +252 -0
  124. package/tests/e2e/fixtures/notebooks/tab2.py +231 -0
  125. package/tests/e2e/fixtures/notebooks/tab3.ipynb +403 -0
  126. package/tests/e2e/fixtures/notebooks/tab3.py +331 -0
  127. package/tests/e2e/fixtures/notebooks/tab4.py +339 -0
  128. package/tests/e2e/freeze-notebook-tabs-repro.spec.js +295 -0
  129. package/tests/e2e/mock-codex-cli-flood.py +127 -0
  130. package/tests/e2e/mock-codex-cli-prompt-echo.py +88 -0
  131. package/tests/e2e/mock-codex-cli.py +95 -0
  132. package/tests/e2e/queue-multitab-repro.spec.js +189 -0
  133. package/tests/test_handlers.py +116 -0
  134. package/tests/test_protocol.py +169 -0
  135. package/tests/test_session_store_limits.py +50 -0
  136. package/tests/unit/codexChatAttachmentDedup.spec.ts +56 -0
  137. package/tests/unit/codexChatAttachmentLimit.spec.ts +57 -0
  138. package/tests/unit/codexChatAttachmentState.spec.ts +71 -0
  139. package/tests/unit/codexChatDocumentUtils.spec.ts +63 -0
  140. package/tests/unit/codexChatLimit.spec.ts +18 -0
  141. package/tests/unit/codexChatNotice.spec.ts +45 -0
  142. package/tests/unit/codexChatPersistence.spec.ts +199 -0
  143. package/tests/unit/codexChatSessionFactory.spec.ts +94 -0
  144. package/tests/unit/codexChatSessionKey.spec.ts +18 -0
  145. package/tests/unit/codexMessageUtils.spec.ts +89 -0
  146. package/tests/unit/codexSessionResolver.spec.ts +92 -0
  147. package/tests/unit/handleCodexSocketMessage.spec.ts +476 -0
  148. package/tsconfig.tsbuildinfo +1 -1
  149. package/webpack.config.js +6 -0
  150. package/jupyterlab_codex/labextension/static/504.335f3447c84ba3d74517.js +0 -2
  151. package/jupyterlab_codex/labextension/static/972.8e856719e40acc1ef4cb.js +0 -1
  152. package/jupyterlab_codex/labextension/static/remoteEntry.a2982f776a1f0f515640.js +0 -1
  153. /package/jupyterlab_codex/labextension/static/{504.335f3447c84ba3d74517.js.LICENSE.txt → 525.224526d045c727069de6.js.LICENSE.txt} +0 -0
@@ -0,0 +1,252 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "tab2-cell",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import numpy as np\n",
11
+ "\n",
12
+ "\n",
13
+ "def simulate_schrodinger_1d(\n",
14
+ " n_grid: int = 1024,\n",
15
+ " x_min: float = -20.0,\n",
16
+ " x_max: float = 20.0,\n",
17
+ " dt: float = 0.005,\n",
18
+ " steps: int = 2500,\n",
19
+ " save_every: int = 250,\n",
20
+ " potential_type: str = \"harmonic\",\n",
21
+ " omega: float = 0.2,\n",
22
+ " barrier_height: float = 1.5,\n",
23
+ " barrier_width: float = 1.5,\n",
24
+ " x0: float = -7.0,\n",
25
+ " sigma: float = 1.0,\n",
26
+ " k0: float = 2.0,\n",
27
+ ") -> dict[str, object]:\n",
28
+ " \"\"\"\n",
29
+ " Simulate a 1D wave packet under the time-dependent Schrodinger equation\n",
30
+ " using a split-operator FFT method in natural units (hbar = m = 1).\n",
31
+ " \"\"\"\n",
32
+ " if not isinstance(n_grid, int) or isinstance(n_grid, bool):\n",
33
+ " raise TypeError(\"n_grid must be int\")\n",
34
+ " if n_grid < 64:\n",
35
+ " raise ValueError(\"n_grid must be >= 64\")\n",
36
+ " if x_max <= x_min:\n",
37
+ " raise ValueError(\"x_max must be greater than x_min\")\n",
38
+ " if dt <= 0:\n",
39
+ " raise ValueError(\"dt must be > 0\")\n",
40
+ " if not isinstance(steps, int) or isinstance(steps, bool) or steps < 1:\n",
41
+ " raise ValueError(\"steps must be an int >= 1\")\n",
42
+ " if not isinstance(save_every, int) or isinstance(save_every, bool) or save_every < 1:\n",
43
+ " raise ValueError(\"save_every must be an int >= 1\")\n",
44
+ " if sigma <= 0:\n",
45
+ " raise ValueError(\"sigma must be > 0\")\n",
46
+ "\n",
47
+ " x = np.linspace(x_min, x_max, n_grid, endpoint=False)\n",
48
+ " dx = x[1] - x[0]\n",
49
+ "\n",
50
+ " if potential_type == \"harmonic\":\n",
51
+ " if omega < 0:\n",
52
+ " raise ValueError(\"omega must be >= 0\")\n",
53
+ " potential = 0.5 * (omega**2) * (x**2)\n",
54
+ " elif potential_type == \"barrier\":\n",
55
+ " if barrier_width < 0:\n",
56
+ " raise ValueError(\"barrier_width must be >= 0\")\n",
57
+ " potential = np.where(np.abs(x) < barrier_width / 2.0, barrier_height, 0.0)\n",
58
+ " else:\n",
59
+ " raise ValueError(\"potential_type must be 'harmonic' or 'barrier'\")\n",
60
+ "\n",
61
+ " psi = np.exp(-((x - x0) ** 2) / (2.0 * sigma**2)) * np.exp(1j * k0 * x)\n",
62
+ "\n",
63
+ " def norm(wavefunc: np.ndarray) -> float:\n",
64
+ " return float(np.sum(np.abs(wavefunc) ** 2) * dx)\n",
65
+ "\n",
66
+ " def expected_x(wavefunc: np.ndarray) -> float:\n",
67
+ " density = np.abs(wavefunc) ** 2\n",
68
+ " return float(np.sum(x * density) * dx)\n",
69
+ "\n",
70
+ " def total_energy(wavefunc: np.ndarray) -> float:\n",
71
+ " grad = np.gradient(wavefunc, dx)\n",
72
+ " kinetic = 0.5 * np.sum(np.abs(grad) ** 2) * dx\n",
73
+ " potential_energy = np.sum(potential * (np.abs(wavefunc) ** 2)) * dx\n",
74
+ " return float(np.real(kinetic + potential_energy))\n",
75
+ "\n",
76
+ " psi /= np.sqrt(norm(psi))\n",
77
+ " k = 2.0 * np.pi * np.fft.fftfreq(n_grid, d=dx)\n",
78
+ " kinetic_phase = np.exp(-0.5j * (k**2) * dt)\n",
79
+ " potential_half_phase = np.exp(-0.5j * potential * dt)\n",
80
+ "\n",
81
+ " snapshots: list[tuple[float, np.ndarray]] = [(0.0, np.abs(psi) ** 2)]\n",
82
+ " norm_history = [norm(psi)]\n",
83
+ " x_expect_history = [expected_x(psi)]\n",
84
+ " energy_history = [total_energy(psi)]\n",
85
+ "\n",
86
+ " for step in range(1, steps + 1):\n",
87
+ " psi = potential_half_phase * psi\n",
88
+ " psi_k = np.fft.fft(psi)\n",
89
+ " psi_k = kinetic_phase * psi_k\n",
90
+ " psi = np.fft.ifft(psi_k)\n",
91
+ " psi = potential_half_phase * psi\n",
92
+ " psi /= np.sqrt(norm(psi))\n",
93
+ "\n",
94
+ " current_norm = norm(psi)\n",
95
+ " norm_history.append(current_norm)\n",
96
+ " x_expect_history.append(expected_x(psi))\n",
97
+ " energy_history.append(total_energy(psi))\n",
98
+ "\n",
99
+ " if step % save_every == 0:\n",
100
+ " snapshots.append((step * dt, np.abs(psi) ** 2))\n",
101
+ "\n",
102
+ " if steps % save_every != 0:\n",
103
+ " snapshots.append((steps * dt, np.abs(psi) ** 2))\n",
104
+ "\n",
105
+ " return {\n",
106
+ " \"x\": x,\n",
107
+ " \"V\": potential,\n",
108
+ " \"psi\": psi,\n",
109
+ " \"dx\": dx,\n",
110
+ " \"dt\": dt,\n",
111
+ " \"steps\": steps,\n",
112
+ " \"snapshots\": snapshots,\n",
113
+ " \"final_norm\": norm(psi),\n",
114
+ " \"norm_history\": np.array(norm_history),\n",
115
+ " \"x_expect_history\": np.array(x_expect_history),\n",
116
+ " \"energy_history\": np.array(energy_history),\n",
117
+ " }\n",
118
+ "\n",
119
+ "\n",
120
+ "def validate_simulation(\n",
121
+ " result: dict[str, object],\n",
122
+ " norm_tolerance: float = 1e-6,\n",
123
+ " energy_tolerance: float = 0.6,\n",
124
+ ") -> None:\n",
125
+ " norms = np.asarray(result[\"norm_history\"], dtype=float)\n",
126
+ " if not np.all(np.isclose(norms, 1.0, atol=norm_tolerance)):\n",
127
+ " min_norm = float(np.min(norms))\n",
128
+ " max_norm = float(np.max(norms))\n",
129
+ " raise AssertionError(\n",
130
+ " f\"norm preservation failed: min={min_norm:.8f}, max={max_norm:.8f}\"\n",
131
+ " )\n",
132
+ " energies = np.asarray(result[\"energy_history\"], dtype=float)\n",
133
+ " if energies.size > 1:\n",
134
+ " drift = float(np.max(energies) - np.min(energies))\n",
135
+ " if drift > energy_tolerance:\n",
136
+ " raise AssertionError(f\"energy drift too large: {drift:.6f}\")\n",
137
+ "\n",
138
+ "\n",
139
+ "def transmission_reflection_probabilities(\n",
140
+ " result: dict[str, object],\n",
141
+ " split_x: float = 0.0,\n",
142
+ ") -> tuple[float, float]:\n",
143
+ " x = np.asarray(result[\"x\"], dtype=float)\n",
144
+ " psi = np.asarray(result[\"psi\"], dtype=np.complex128)\n",
145
+ " dx = float(result[\"dx\"])\n",
146
+ " density = np.abs(psi) ** 2\n",
147
+ " reflection = float(np.sum(density[x < split_x]) * dx)\n",
148
+ " transmission = float(np.sum(density[x >= split_x]) * dx)\n",
149
+ " return transmission, reflection\n",
150
+ "\n",
151
+ "\n",
152
+ "def plot_simulation(result: dict[str, object]) -> None:\n",
153
+ " import matplotlib.pyplot as plt\n",
154
+ "\n",
155
+ " x = np.asarray(result[\"x\"])\n",
156
+ " potential = np.asarray(result[\"V\"])\n",
157
+ " snapshots = result[\"snapshots\"]\n",
158
+ "\n",
159
+ " max_density = max(float(np.max(density)) for _, density in snapshots)\n",
160
+ " shifted = potential - float(np.min(potential))\n",
161
+ " vmax = float(np.max(shifted))\n",
162
+ " if vmax > 0:\n",
163
+ " scaled_potential = (shifted / vmax) * (0.8 * max_density)\n",
164
+ " else:\n",
165
+ " scaled_potential = np.zeros_like(potential)\n",
166
+ "\n",
167
+ " plt.figure(figsize=(10, 5))\n",
168
+ " plt.plot(x, scaled_potential, \"k--\", linewidth=1.2, label=\"Scaled potential\")\n",
169
+ " for time_point, density in snapshots:\n",
170
+ " plt.plot(x, density, label=f\"t = {time_point:.2f}\")\n",
171
+ " plt.title(\"1D Time-Dependent Schrodinger Simulation\")\n",
172
+ " plt.xlabel(\"x\")\n",
173
+ " plt.ylabel(r\"Probability density $|\\psi(x,t)|^2$\")\n",
174
+ " plt.legend(ncol=2)\n",
175
+ " plt.grid(alpha=0.25)\n",
176
+ " plt.tight_layout()\n",
177
+ " plt.show()\n",
178
+ "\n",
179
+ "\n",
180
+ "def run_demo(enable_plot: bool = False) -> None:\n",
181
+ " harmonic_result = simulate_schrodinger_1d(\n",
182
+ " potential_type=\"harmonic\",\n",
183
+ " omega=0.2,\n",
184
+ " x0=-7.0,\n",
185
+ " k0=2.0,\n",
186
+ " )\n",
187
+ " validate_simulation(harmonic_result)\n",
188
+ "\n",
189
+ " final_norm = float(harmonic_result[\"final_norm\"])\n",
190
+ " x_expect = np.asarray(harmonic_result[\"x_expect_history\"], dtype=float)\n",
191
+ " print(\"Harmonic oscillator\")\n",
192
+ " print(f\"Final norm: {final_norm:.6f}\")\n",
193
+ " print(\n",
194
+ " \"Expectation x range:\",\n",
195
+ " f\"[{float(np.min(x_expect)):.3f}, {float(np.max(x_expect)):.3f}]\",\n",
196
+ " )\n",
197
+ "\n",
198
+ " barrier_result = simulate_schrodinger_1d(\n",
199
+ " potential_type=\"barrier\",\n",
200
+ " barrier_height=1.5,\n",
201
+ " barrier_width=1.5,\n",
202
+ " x0=-10.0,\n",
203
+ " k0=2.2,\n",
204
+ " steps=3000,\n",
205
+ " )\n",
206
+ " validate_simulation(barrier_result)\n",
207
+ " transmission, reflection = transmission_reflection_probabilities(\n",
208
+ " barrier_result,\n",
209
+ " split_x=0.0,\n",
210
+ " )\n",
211
+ " assert np.isclose(transmission + reflection, 1.0, atol=1e-5)\n",
212
+ " print(\"Barrier tunneling\")\n",
213
+ " print(\n",
214
+ " \"Transmission / Reflection:\",\n",
215
+ " f\"{transmission:.4f} / {reflection:.4f}\",\n",
216
+ " )\n",
217
+ " print(\"Validation passed.\")\n",
218
+ "\n",
219
+ " if enable_plot:\n",
220
+ " plot_simulation(harmonic_result)\n",
221
+ " plot_simulation(barrier_result)\n",
222
+ "\n",
223
+ "\n",
224
+ "run_demo(enable_plot=False)"
225
+ ]
226
+ }
227
+ ],
228
+ "metadata": {
229
+ "jupytext": {
230
+ "formats": "ipynb,py:percent"
231
+ },
232
+ "kernelspec": {
233
+ "display_name": "Python 3 (ipykernel)",
234
+ "language": "python",
235
+ "name": "python3"
236
+ },
237
+ "language_info": {
238
+ "codemirror_mode": {
239
+ "name": "ipython",
240
+ "version": 3
241
+ },
242
+ "file_extension": ".py",
243
+ "mimetype": "text/x-python",
244
+ "name": "python",
245
+ "nbconvert_exporter": "python",
246
+ "pygments_lexer": "ipython3",
247
+ "version": "3.12.10"
248
+ }
249
+ },
250
+ "nbformat": 4,
251
+ "nbformat_minor": 5
252
+ }
@@ -0,0 +1,231 @@
1
+ # ---
2
+ # jupyter:
3
+ # jupytext:
4
+ # formats: ipynb,py:percent
5
+ # text_representation:
6
+ # extension: .py
7
+ # format_name: percent
8
+ # format_version: '1.3'
9
+ # jupytext_version: 1.17.2
10
+ # kernelspec:
11
+ # display_name: Python 3 (ipykernel)
12
+ # language: python
13
+ # name: python3
14
+ # ---
15
+
16
+ # %%
17
+ import numpy as np
18
+
19
+
20
def simulate_schrodinger_1d(
    n_grid: int = 1024,
    x_min: float = -20.0,
    x_max: float = 20.0,
    dt: float = 0.005,
    steps: int = 2500,
    save_every: int = 250,
    potential_type: str = "harmonic",
    omega: float = 0.2,
    barrier_height: float = 1.5,
    barrier_width: float = 1.5,
    x0: float = -7.0,
    sigma: float = 1.0,
    k0: float = 2.0,
) -> dict[str, object]:
    """
    Simulate a 1D wave packet under the time-dependent Schrodinger equation
    using a split-operator FFT method in natural units (hbar = m = 1).

    The grid has ``n_grid`` points on ``[x_min, x_max)`` (``endpoint=False``).
    The kinetic step is applied with FFTs, so the domain is periodic:
    probability that leaves one edge re-enters at the other. Choose a domain
    large enough that the packet does not reach the edges within
    ``steps * dt``.

    The initial state is the normalized packet
    ``exp(-(x - x0)**2 / (2 * sigma**2)) * exp(1j * k0 * x)``. The potential
    is one of:
        "harmonic": ``0.5 * omega**2 * x**2``;
        "barrier": ``barrier_height`` where ``|x| < barrier_width / 2``,
            else 0.

    Returns:
        A dict with the following keys.
        "x", "V", "psi", "dx", "dt", "steps": the grid, the potential, the
            final wave function and the run parameters.
        "snapshots": list of ``(time, |psi|**2)`` pairs, taken at t = 0,
            every ``save_every`` steps, and at the last step.
        "final_norm": norm of the final state.
        "norm_history", "x_expect_history", "energy_history": arrays of
            length ``steps + 1`` (the initial state plus one entry per step).

    Raises:
        TypeError: if ``n_grid`` is not an int (bool is rejected).
        ValueError: for an out-of-range parameter or an unknown
            ``potential_type``.
    """
    if not isinstance(n_grid, int) or isinstance(n_grid, bool):
        raise TypeError("n_grid must be int")
    if n_grid < 64:
        raise ValueError("n_grid must be >= 64")
    if x_max <= x_min:
        raise ValueError("x_max must be greater than x_min")
    if dt <= 0:
        raise ValueError("dt must be > 0")
    if not isinstance(steps, int) or isinstance(steps, bool) or steps < 1:
        raise ValueError("steps must be an int >= 1")
    if not isinstance(save_every, int) or isinstance(save_every, bool) or save_every < 1:
        raise ValueError("save_every must be an int >= 1")
    if sigma <= 0:
        raise ValueError("sigma must be > 0")

    x = np.linspace(x_min, x_max, n_grid, endpoint=False)
    dx = x[1] - x[0]

    if potential_type == "harmonic":
        if omega < 0:
            raise ValueError("omega must be >= 0")
        potential = 0.5 * (omega**2) * (x**2)
    elif potential_type == "barrier":
        if barrier_width < 0:
            raise ValueError("barrier_width must be >= 0")
        potential = np.where(np.abs(x) < barrier_width / 2.0, barrier_height, 0.0)
    else:
        raise ValueError("potential_type must be 'harmonic' or 'barrier'")

    # Angular wavenumbers in numpy's FFT ordering. Both the propagator and
    # the energy diagnostic use them.
    k = 2.0 * np.pi * np.fft.fftfreq(n_grid, d=dx)

    def norm(wavefunc: np.ndarray) -> float:
        return float(np.sum(np.abs(wavefunc) ** 2) * dx)

    def expected_x(wavefunc: np.ndarray) -> float:
        density = np.abs(wavefunc) ** 2
        return float(np.sum(x * density) * dx)

    def total_energy(wavefunc: np.ndarray) -> float:
        # Compute the kinetic energy <p^2>/2 in k-space via Parseval
        # (sum|psi|^2 dx == sum|fft(psi)|^2 dx / n_grid). This is the same
        # periodic kinetic operator the propagator uses. A finite-difference
        # gradient underestimates it by ~(k*dx)^2/3 and uses one-sided
        # stencils at the grid edges.
        psi_k = np.fft.fft(wavefunc)
        kinetic = 0.5 * np.sum((k**2) * (np.abs(psi_k) ** 2)) * dx / n_grid
        potential_energy = np.sum(potential * (np.abs(wavefunc) ** 2)) * dx
        return float(kinetic + potential_energy)

    psi = np.exp(-((x - x0) ** 2) / (2.0 * sigma**2)) * np.exp(1j * k0 * x)
    psi /= np.sqrt(norm(psi))
    kinetic_phase = np.exp(-0.5j * (k**2) * dt)
    potential_half_phase = np.exp(-0.5j * potential * dt)

    snapshots: list[tuple[float, np.ndarray]] = [(0.0, np.abs(psi) ** 2)]
    norm_history = [norm(psi)]
    x_expect_history = [expected_x(psi)]
    energy_history = [total_energy(psi)]

    for step in range(1, steps + 1):
        # Strang splitting: a half potential kick, a full kinetic step in
        # k-space, then another half potential kick. All three factors have
        # modulus 1, so the step is unitary up to round-off. The state is
        # deliberately NOT renormalized here. Renormalizing would pin
        # norm_history to 1 and make it useless as a conservation check.
        psi = potential_half_phase * psi
        psi = np.fft.ifft(kinetic_phase * np.fft.fft(psi))
        psi = potential_half_phase * psi

        norm_history.append(norm(psi))
        x_expect_history.append(expected_x(psi))
        energy_history.append(total_energy(psi))

        if step % save_every == 0:
            snapshots.append((step * dt, np.abs(psi) ** 2))

    # Always record the final state, even if steps is not a multiple of
    # save_every.
    if steps % save_every != 0:
        snapshots.append((steps * dt, np.abs(psi) ** 2))

    return {
        "x": x,
        "V": potential,
        "psi": psi,
        "dx": dx,
        "dt": dt,
        "steps": steps,
        "snapshots": snapshots,
        "final_norm": norm(psi),
        "norm_history": np.array(norm_history),
        "x_expect_history": np.array(x_expect_history),
        "energy_history": np.array(energy_history),
    }
125
+
126
+
127
def validate_simulation(
    result: dict[str, object],
    norm_tolerance: float = 1e-6,
    energy_tolerance: float = 0.6,
) -> None:
    """
    Check the conservation diagnostics stored in a simulation result.

    Args:
        result: mapping with "norm_history" and "energy_history" sequences,
            as returned by ``simulate_schrodinger_1d``.
        norm_tolerance: maximum allowed absolute deviation of any norm from 1.
        energy_tolerance: maximum allowed spread (max - min) of the energies.

    Raises:
        AssertionError: if any of these holds:
            a norm is off by more than ``norm_tolerance``;
            an energy is NaN or infinite;
            the energy spread exceeds ``energy_tolerance``.
    """
    norms = np.asarray(result["norm_history"], dtype=float)
    # Pass rtol=0 so the tolerance really is `norm_tolerance`. numpy's
    # default rtol=1e-5 would otherwise dominate the 1e-6 default and let
    # norms off by up to ~1.1e-5 through.
    if not np.all(np.isclose(norms, 1.0, rtol=0.0, atol=norm_tolerance)):
        min_norm = float(np.min(norms))
        max_norm = float(np.max(norms))
        raise AssertionError(
            f"norm preservation failed: min={min_norm:.8f}, max={max_norm:.8f}"
        )
    energies = np.asarray(result["energy_history"], dtype=float)
    # `nan > tol` is False, so non-finite energies would otherwise slip
    # through the drift comparison below.
    if not np.all(np.isfinite(energies)):
        raise AssertionError("energy history contains non-finite values")
    if energies.size > 1:
        drift = float(np.max(energies) - np.min(energies))
        if drift > energy_tolerance:
            raise AssertionError(f"energy drift too large: {drift:.6f}")
144
+
145
+
146
def transmission_reflection_probabilities(
    result: dict[str, object],
    split_x: float = 0.0,
) -> tuple[float, float]:
    """
    Split the probability mass of the final wave function at ``split_x``.

    Reads ``result["x"]``, ``result["psi"]`` and ``result["dx"]``.

    Returns:
        ``(transmission, reflection)``. Each is a Riemann sum of
        ``|psi|**2 * dx``: transmission over grid points with
        ``x >= split_x``, reflection over points with ``x < split_x``.

    Points are classified purely by position. Any probability that has
    wrapped around the edges of a periodic grid is therefore counted on the
    wrong side.
    """
    positions = np.asarray(result["x"], dtype=float)
    amplitude = np.asarray(result["psi"], dtype=np.complex128)
    cell_width = float(result["dx"])
    prob_density = np.abs(amplitude) ** 2

    on_far_side = positions >= split_x
    on_near_side = positions < split_x
    transmitted = float(np.sum(prob_density[on_far_side]) * cell_width)
    reflected = float(np.sum(prob_density[on_near_side]) * cell_width)
    return transmitted, reflected
157
+
158
+
159
def plot_simulation(result: dict[str, object]) -> None:
    """
    Draw every saved density snapshot of ``result`` on one matplotlib figure.

    The potential is overlaid as a dashed black line. It is shifted so its
    minimum sits at 0 and rescaled so its maximum reaches 80% of the tallest
    snapshot peak; a flat potential is drawn as all zeros.

    matplotlib is imported inside the function, so it is only needed when
    this function is called. The figure is displayed with ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    grid = np.asarray(result["x"])
    pot = np.asarray(result["V"])
    frames = result["snapshots"]

    tallest_peak = max(float(np.max(rho)) for _, rho in frames)
    pot_from_floor = pot - float(np.min(pot))
    pot_span = float(np.max(pot_from_floor))
    scaled_potential = (
        (pot_from_floor / pot_span) * (0.8 * tallest_peak)
        if pot_span > 0
        else np.zeros_like(pot)
    )

    plt.figure(figsize=(10, 5))
    plt.plot(grid, scaled_potential, "k--", linewidth=1.2, label="Scaled potential")
    for t_value, rho in frames:
        plt.plot(grid, rho, label=f"t = {t_value:.2f}")
    plt.title("1D Time-Dependent Schrodinger Simulation")
    plt.xlabel("x")
    plt.ylabel(r"Probability density $|\psi(x,t)|^2$")
    plt.legend(ncol=2)
    plt.grid(alpha=0.25)
    plt.tight_layout()
    plt.show()
185
+
186
+
187
def run_demo(enable_plot: bool = False) -> None:
    """
    Run a harmonic-oscillator and a square-barrier simulation, validate both,
    and print a short summary. If ``enable_plot`` is true, also plot both runs.

    Why the barrier run uses a wider grid than simulate_schrodinger_1d's
    default [-20, 20): that solver propagates with FFTs, so its domain is
    periodic. With x0=-10 and k0=2.2 (group velocity k0 for hbar = m = 1),
    the run ends at t = 3000 * 0.005 = 15. By then the transmitted packet's
    centre would be near x = -10 + 2.2 * 15 = 23. It would wrap onto the
    left edge while the reflected part wraps onto the right edge, swapping
    the transmission/reflection split at x = 0. A [-60, 60) grid with 3072
    points keeps both parts on the grid at the same spacing (dx = 40/1024)
    as the default.

    Raises:
        AssertionError: if a validation or the probability-sum check fails.
    """
    harmonic_result = simulate_schrodinger_1d(
        potential_type="harmonic",
        omega=0.2,
        x0=-7.0,
        k0=2.0,
    )
    validate_simulation(harmonic_result)

    final_norm = float(harmonic_result["final_norm"])
    x_expect = np.asarray(harmonic_result["x_expect_history"], dtype=float)
    print("Harmonic oscillator")
    print(f"Final norm: {final_norm:.6f}")
    print(
        "Expectation x range:",
        f"[{float(np.min(x_expect)):.3f}, {float(np.max(x_expect)):.3f}]",
    )

    barrier_result = simulate_schrodinger_1d(
        n_grid=3072,
        x_min=-60.0,
        x_max=60.0,
        potential_type="barrier",
        barrier_height=1.5,
        barrier_width=1.5,
        x0=-10.0,
        k0=2.2,
        steps=3000,
    )
    validate_simulation(barrier_result)
    transmission, reflection = transmission_reflection_probabilities(
        barrier_result,
        split_x=0.0,
    )
    # Use an explicit check rather than `assert`, because asserts are
    # stripped when Python runs with -O.
    total_probability = transmission + reflection
    if not np.isclose(total_probability, 1.0, atol=1e-5):
        raise AssertionError(
            f"transmission + reflection = {total_probability:.8f}, expected 1"
        )
    print("Barrier tunneling")
    print(
        "Transmission / Reflection:",
        f"{transmission:.4f} / {reflection:.4f}",
    )
    print("Validation passed.")

    if enable_plot:
        plot_simulation(harmonic_result)
        plot_simulation(barrier_result)


run_demo(enable_plot=False)