tensorquantlib 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. tensorquantlib/__init__.py +313 -0
  2. tensorquantlib/__main__.py +315 -0
  3. tensorquantlib/backtest/__init__.py +48 -0
  4. tensorquantlib/backtest/engine.py +240 -0
  5. tensorquantlib/backtest/metrics.py +320 -0
  6. tensorquantlib/backtest/strategy.py +348 -0
  7. tensorquantlib/core/__init__.py +6 -0
  8. tensorquantlib/core/ops.py +70 -0
  9. tensorquantlib/core/second_order.py +465 -0
  10. tensorquantlib/core/tensor.py +928 -0
  11. tensorquantlib/data/__init__.py +16 -0
  12. tensorquantlib/data/market.py +160 -0
  13. tensorquantlib/finance/__init__.py +52 -0
  14. tensorquantlib/finance/american.py +263 -0
  15. tensorquantlib/finance/basket.py +291 -0
  16. tensorquantlib/finance/black_scholes.py +219 -0
  17. tensorquantlib/finance/credit.py +199 -0
  18. tensorquantlib/finance/exotics.py +885 -0
  19. tensorquantlib/finance/fx.py +204 -0
  20. tensorquantlib/finance/greeks.py +133 -0
  21. tensorquantlib/finance/heston.py +543 -0
  22. tensorquantlib/finance/implied_vol.py +277 -0
  23. tensorquantlib/finance/ir_derivatives.py +203 -0
  24. tensorquantlib/finance/jump_diffusion.py +203 -0
  25. tensorquantlib/finance/local_vol.py +146 -0
  26. tensorquantlib/finance/rates.py +381 -0
  27. tensorquantlib/finance/risk.py +344 -0
  28. tensorquantlib/finance/variance_reduction.py +420 -0
  29. tensorquantlib/finance/volatility.py +355 -0
  30. tensorquantlib/py.typed +0 -0
  31. tensorquantlib/tt/__init__.py +43 -0
  32. tensorquantlib/tt/decompose.py +576 -0
  33. tensorquantlib/tt/ops.py +386 -0
  34. tensorquantlib/tt/pricing.py +304 -0
  35. tensorquantlib/tt/surrogate.py +634 -0
  36. tensorquantlib/utils/__init__.py +5 -0
  37. tensorquantlib/utils/validation.py +126 -0
  38. tensorquantlib/viz/__init__.py +27 -0
  39. tensorquantlib/viz/plots.py +331 -0
  40. tensorquantlib-0.3.0.dist-info/METADATA +602 -0
  41. tensorquantlib-0.3.0.dist-info/RECORD +44 -0
  42. tensorquantlib-0.3.0.dist-info/WHEEL +5 -0
  43. tensorquantlib-0.3.0.dist-info/licenses/LICENSE +21 -0
  44. tensorquantlib-0.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,576 @@
1
+ """
2
+ Tensor-Train decomposition algorithms.
3
+
4
+ Implements:
5
+ - tt_svd: TT-SVD decomposition (Oseledets, 2011)
6
+ - tt_round: TT-rounding via orthogonalization + truncated SVD
7
+ - tt_cross: Black-box TT-Cross approximation (Oseledets & Tyrtyshnikov, 2010)
8
+ Builds a TT decomposition without forming the full tensor,
9
+ making 6+ asset problems feasible.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from collections.abc import Callable
15
+
16
+ import numpy as np
17
+ from scipy.linalg import qr, solve
18
+
19
+
20
def tt_svd(
    tensor: np.ndarray,
    eps: float = 1e-6,
    max_rank: int | None = None,
) -> list[np.ndarray]:
    """Decompose a dense tensor into Tensor-Train format (Oseledets, 2011).

    Produces TT-cores ``[G1, ..., Gd]`` with ``G_k.shape == (r_{k-1}, n_k, r_k)``
    and boundary ranks ``r_0 = r_d = 1`` such that the reconstruction error
    satisfies ``||A - A_TT||_F <= eps * ||A||_F``.

    Works left to right: repeatedly unfold the remainder, take an economy
    SVD, keep the smallest rank whose discarded tail fits the per-step
    budget ``delta = eps * ||A||_F / sqrt(d - 1)``, and carry ``S @ Vt``
    forward into the next unfolding.

    Args:
        tensor: Input tensor of shape (n1, n2, ..., nd).
        eps: Relative truncation tolerance (Frobenius-norm bound).
        max_rank: Optional hard cap on every TT-rank.

    Returns:
        List of TT-cores; cores[k].shape == (r_{k-1}, n_k, r_k).

    Raises:
        ValueError: If the tensor has fewer than 2 dimensions, eps is
            negative, or max_rank is less than 1.
    """
    ndim = tensor.ndim
    if ndim < 2:
        raise ValueError(f"Tensor must have at least 2 dimensions, got {ndim}")
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps}")
    if max_rank is not None and max_rank < 1:
        raise ValueError(f"max_rank must be >= 1, got {max_rank}")

    dims = tensor.shape
    frob = np.linalg.norm(tensor)

    # A (numerically) zero tensor is represented exactly by all-zero rank-1 cores.
    if frob < 1e-15:
        return [np.zeros((1, n, 1)) for n in dims]

    # Splitting the error budget evenly over d-1 truncations bounds the
    # total error by eps * ||A||_F.
    delta_sq = (eps * frob / np.sqrt(ndim - 1)) ** 2

    cores: list[np.ndarray] = []
    remainder = tensor.astype(np.float64, copy=True)
    rank_in = 1

    for k in range(ndim - 1):
        n_k = dims[k]
        # Unfold: rows index (left rank, current mode); columns everything else.
        remainder = remainder.reshape(rank_in * n_k, -1)

        U, S, Vt = np.linalg.svd(remainder, full_matrices=False)

        # tail[i] = sum(S[i:]**2) via reverse cumsum (numerically stable);
        # pick the smallest rank >= 1 whose discarded tail fits the budget.
        tail = np.cumsum((S ** 2)[::-1])[::-1]
        hits = np.nonzero(tail[1:] <= delta_sq)[0]
        rank_out = int(hits[0]) + 1 if hits.size else len(S)

        if max_rank is not None:
            rank_out = min(rank_out, max_rank)
        rank_out = max(rank_out, 1)

        cores.append(U[:, :rank_out].reshape(rank_in, n_k, rank_out))

        # Carry the truncated weight rightward: remainder = diag(S) @ Vt.
        remainder = S[:rank_out, None] * Vt[:rank_out, :]
        rank_in = rank_out

    # Whatever is left becomes the last core, with right rank 1.
    cores.append(remainder.reshape(rank_in, dims[-1], 1))
    return cores
118
+
119
+
120
def tt_round(
    cores: list[np.ndarray],
    eps: float = 1e-6,
    max_rank: int | None = None,
) -> list[np.ndarray]:
    """Reduce TT-ranks via orthogonalization + truncated SVD sweep.

    Standard two-pass TT-rounding (Oseledets, 2011):
      1. Right-to-left QR sweep that right-orthogonalizes cores d-1..1,
         pushing all of the tensor's weight into core 0.
      2. Left-to-right SVD sweep that truncates each bond to the smallest
         rank whose discarded tail energy fits the per-step budget
         ``delta = eps * ||A||_F / sqrt(d - 1)``.

    Used after TT arithmetic (e.g. tt_add), which inflates ranks additively.

    Args:
        cores: List of TT-cores; cores[k].shape == (r_{k-1}, n_k, r_k).
        eps: Relative truncation tolerance (Frobenius-norm bound).
        max_rank: Optional hard cap on every truncated rank.

    Returns:
        New list of TT-cores with reduced ranks; inputs are not mutated.
    """
    d = len(cores)
    if d < 2:
        return [c.copy() for c in cores]

    # Work on copies so callers keep their original cores.
    cores = [c.copy() for c in cores]

    # Frobenius norm computed directly from the TT structure — no
    # reconstruction of the full tensor is needed.
    norm_est = _tt_norm(cores)
    if norm_est < 1e-15:
        return cores

    delta_sq = (eps * norm_est / np.sqrt(d - 1)) ** 2

    # ---- Pass 1: right-to-left QR sweep (right-orthogonalize) ----
    for k in range(d - 1, 0, -1):
        r_left, n_k, r_right = cores[k].shape
        # Unfold so the left bond is the column dimension, then QR.
        M = cores[k].reshape(r_left, n_k * r_right).T
        Q, R = np.linalg.qr(M)
        new_r = Q.shape[1]
        cores[k] = Q.T.reshape(new_r, n_k, r_right)
        # Absorb R into the previous core's right bond:
        # new_prev[i, j, l] = sum_m prev[i, j, m] * R[l, m]
        r_lp, n_p, _ = cores[k - 1].shape
        new_prev = cores[k - 1].reshape(r_lp * n_p, r_left) @ R.T
        cores[k - 1] = new_prev.reshape(r_lp, n_p, new_r)

    # ---- Pass 2: left-to-right truncated-SVD sweep ----
    for k in range(d - 1):
        r_left, n_k, r_right = cores[k].shape
        M = cores[k].reshape(r_left * n_k, r_right)

        U, S, Vt = np.linalg.svd(M, full_matrices=False)

        # tail[i] = sum(S[i:]**2); keep the smallest rank >= 1 whose
        # discarded tail fits the per-step budget.
        tail = np.cumsum((S ** 2)[::-1])[::-1]
        hits = np.nonzero(tail[1:] <= delta_sq)[0]
        r_new = int(hits[0]) + 1 if hits.size else len(S)

        if max_rank is not None:
            r_new = min(r_new, max_rank)
        r_new = max(r_new, 1)

        cores[k] = U[:, :r_new].reshape(r_left, n_k, r_new)

        # Push diag(S) @ Vt into the next core, contracting the shared bond.
        SV = S[:r_new, None] * Vt[:r_new, :]  # (r_new, r_right)
        cores[k + 1] = np.einsum("ij,jkl->ikl", SV, cores[k + 1])

    return cores
209
+
210
+
211
+ def _tt_norm(cores: list[np.ndarray]) -> float:
212
+ """Compute the Frobenius norm of a tensor in TT format.
213
+
214
+ Uses the transfer matrix approach: ||A||_F^2 = <A, A>_TT.
215
+ Complexity: O(d * n * r^4) where r is the max rank.
216
+ """
217
+ d = len(cores)
218
+ # Initialize: contract first core with itself
219
+ # cores[0] shape: (1, n_0, r_0)
220
+ G = cores[0]
221
+ # <G, G> along mode n_0: sum over n_0 of G[:, i, :] ⊗ G[:, i, :]
222
+ # Result shape: (r_0, r_0) — but since r_left=1 for first core, it's (r_0, r_0)
223
+ r_0 = G.shape[2]
224
+ Z = np.zeros((r_0, r_0))
225
+ for i in range(G.shape[1]):
226
+ Z += G[0, i, :].reshape(-1, 1) @ G[0, i, :].reshape(1, -1)
227
+
228
+ for k in range(1, d):
229
+ G = cores[k]
230
+ _r_left, n_k, r_right = G.shape
231
+ Z_new = np.zeros((r_right, r_right))
232
+ for i in range(n_k):
233
+ # G[:, i, :] is (r_left, r_right)
234
+ slice_k = G[:, i, :] # (r_left, r_right)
235
+ # Z is (r_left, r_left) from previous step
236
+ # Contribution: slice_k^T @ Z @ slice_k → (r_right, r_right)
237
+ Z_new += slice_k.T @ Z @ slice_k
238
+ Z = Z_new
239
+
240
+ # Z is now (1, 1) — the squared norm
241
+ return float(np.sqrt(float(Z.item())))
242
+
243
+
244
+ # ======================================================================
245
+ # TT-Cross (black-box approximation — no full tensor needed)
246
+ # ======================================================================
247
+
248
+ def _maxvol_greedy(A: np.ndarray, r: int, rng: np.random.Generator) -> np.ndarray:
249
+ """Approximate maximum-volume row subset of A (n × k, n ≥ k).
250
+
251
+ Returns r row indices forming an approximate maximum-volume
252
+ (r × k) submatrix of A. Uses greedy pivoting based on QR.
253
+
254
+ Algorithm
255
+ ---------
256
+ 1. Find first r pivots via QR with column pivoting on A^T.
257
+ 2. Iteratively swap rows to increase the determinant of the
258
+ selected submatrix until convergence (maxvol criterion).
259
+ """
260
+ n, k_cols = A.shape
261
+ r = min(r, n, k_cols)
262
+ if r == 0:
263
+ return np.array([], dtype=int)
264
+
265
+ # Initial pivot rows from QR
266
+ _, _, piv = qr(A.T, pivoting=True, mode="economic")
267
+ idx = piv[:r].copy()
268
+
269
+ # Iterative improvement: swap rows to increase abs(det)
270
+ # B = A @ inv(A[idx, :]) — each row B[i] represents how much
271
+ # row i is "outside" the current selection
272
+ sub = A[idx, :] # (r, k_cols)
273
+ try:
274
+ B = np.linalg.lstsq(sub.T, A.T, rcond=None)[0].T # (n, r)
275
+ except np.linalg.LinAlgError:
276
+ return idx
277
+
278
+ max_iter = min(100, n)
279
+ tol = 1.0 + 1e-4
280
+ for _ in range(max_iter):
281
+ i_best, j_best = np.unravel_index(np.argmax(np.abs(B)), B.shape)
282
+ if abs(B[i_best, j_best]) <= tol:
283
+ break
284
+ # Swap row i_best into position j_best
285
+ idx[j_best] = i_best
286
+ sub = A[idx, :]
287
+ try:
288
+ B = np.linalg.lstsq(sub.T, A.T, rcond=None)[0].T
289
+ except np.linalg.LinAlgError:
290
+ break
291
+
292
+ return idx
293
+
294
+
295
+ def _eval_fiber(
296
+ fn: Callable[..., float],
297
+ left_idx: np.ndarray, # shape (r_l, k) — left multi-indices
298
+ k: int, # current mode position (0-based)
299
+ n_k: int, # size of mode k
300
+ right_idx: np.ndarray, # shape (r_r, d-k-1) — right multi-indices
301
+ d: int,
302
+ ) -> np.ndarray:
303
+ """Evaluate fn on all (left × {0..n_k-1} × right) index combinations.
304
+
305
+ Returns
306
+ -------
307
+ np.ndarray of shape ``(r_l * n_k, r_r)``
308
+ C[il * n_k + ik, ir] = fn(*left_idx[il], ik, *right_idx[ir])
309
+ """
310
+ r_l = left_idx.shape[0]
311
+ r_r = right_idx.shape[0] if right_idx.ndim > 0 and right_idx.size > 0 else 1
312
+ C = np.zeros((r_l * n_k, r_r))
313
+ for il in range(r_l):
314
+ left_part = left_idx[il].tolist() if k > 0 else []
315
+ for ik in range(n_k):
316
+ row = il * n_k + ik
317
+ if k == d - 1:
318
+ # Last mode: no right indices
319
+ C[row, 0] = fn(*left_part, ik)
320
+ else:
321
+ for ir in range(r_r):
322
+ right_part = right_idx[ir].tolist() if (d - k - 1) > 0 else []
323
+ C[row, ir] = fn(*left_part, ik, *right_part)
324
+ return C
325
+
326
+
327
+ def _eval_interface(
328
+ fn: Callable[..., float],
329
+ left_idx: np.ndarray, # shape (r_l, k+1) — left pivots at next boundary
330
+ right_idx: np.ndarray, # shape (r_r, d-k-1) — right pivots at current boundary
331
+ d: int,
332
+ ) -> np.ndarray:
333
+ """Evaluate fn on all (left × right) combinations.
334
+
335
+ Returns
336
+ -------
337
+ np.ndarray of shape ``(r_l, r_r)``
338
+ Z[il, ir] = fn(*left_idx[il], *right_idx[ir])
339
+ """
340
+ r_l = left_idx.shape[0]
341
+ r_r = right_idx.shape[0] if right_idx.ndim > 0 and right_idx.size > 0 else 1
342
+ n_right_dims = right_idx.shape[1] if right_idx.ndim > 1 else 0
343
+ Z = np.zeros((r_l, r_r))
344
+ for il in range(r_l):
345
+ for ir in range(r_r):
346
+ idx = list(left_idx[il]) + (list(right_idx[ir]) if n_right_dims > 0 else [])
347
+ Z[il, ir] = fn(*idx)
348
+ return Z
349
+
350
+
351
def tt_cross(
    fn: Callable[..., float],
    shape: tuple[int, ...],
    eps: float = 1e-4,
    max_rank: int = 20,
    n_sweeps: int = 8,
    seed: int = 42,
) -> list[np.ndarray]:
    """TT-Cross black-box approximation (Oseledets & Tyrtyshnikov, 2010).

    Constructs a Tensor-Train decomposition of a *d*-dimensional function
    **without forming the full tensor**. Only queries ``fn`` at a
    carefully selected set of index combinations — O(d · r² · n) evaluations
    instead of O(n^d) for TT-SVD.

    This makes 6+ asset problems feasible:

    * 6 assets, 15 pts/axis, rank 10 → ~54,000 evaluations
    * vs. 15^6 = 11,390,625 for full-grid TT-SVD

    Algorithm
    ---------
    1. **Initialise** right index sets J_k randomly.
    2. **Left-to-right sweep**: for each core k, evaluate the cross
       C_k = f(I_k × {0..n_k-1} × J_k) and select new left pivots I_{k+1}
       via greedy maxvol on the QR factor of C_k.
    3. **Build TT-cores** using the cross-interpolation formula:
       Core_k = C_k @ pinv(Z_k) where Z_k = f(I_{k+1} ++ J_k) is the
       (r_k × r_k) interface matrix.
    4. **Alternating sweeps** refine accuracy.

    Parameters
    ----------
    fn : callable
        Function accepting ``d`` integer arguments (grid indices)
        and returning a float::

            fn(i_0, i_1, ..., i_{d-1}) -> float

        Use ``functools.partial`` or a lambda to curry other parameters.
    shape : tuple of int
        Mode sizes ``(n_0, n_1, ..., n_{d-1})``. These are *index* sizes.
        To convert continuous axes to indices, wrap ``fn`` accordingly.
    eps : float
        Target relative accuracy. Controls rank selection via the
        tolerance passed to the SVD truncation after each cross.
    max_rank : int
        Hard upper bound on TT-ranks.
    n_sweeps : int
        Number of left-to-right + right-to-left alternating sweeps.
        ``n_sweeps=1`` gives a single L→R pass (fast, lower accuracy).
        ``n_sweeps=4`` is sufficient for smooth pricing surfaces.
    seed : int
        Random seed for initialising right index sets.

    Returns
    -------
    list of np.ndarray
        TT-cores[k].shape = ``(r_{k-1}, n_k, r_k)`` with
        ``r_0 = r_d = 1``.

    Examples
    --------
    Compress a 6-asset basket payoff without forming the 15^6 grid::

        import numpy as np
        from tensorquantlib.tt.decompose import tt_cross

        # Suppose price_lookup(i0, i1, i2, i3, i4, i5) evaluates the
        # basket option price at the i-th point on each asset's price axis.
        axes = [np.linspace(80, 120, 15)] * 6
        def price_lookup(*indices):
            spots = [axes[k][i] for k, i in enumerate(indices)]
            return basket_mc(spots, ...)  # your existing pricer

        cores = tt_cross(price_lookup, shape=(15,)*6, max_rank=15, n_sweeps=6)

    Notes
    -----
    After calling ``tt_cross``, wrap the result in a ``TTSurrogate``::

        from tensorquantlib.tt.surrogate import TTSurrogate
        surr = TTSurrogate(cores=cores, axes=axes, eps=eps)
    """
    d = len(shape)
    if d < 2:
        raise ValueError(f"TT-Cross requires at least 2 dimensions, got {d}")
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps}")
    if max_rank < 1:
        raise ValueError(f"max_rank must be >= 1, got {max_rank}")

    rng = np.random.default_rng(seed)

    # ------------------------------------------------------------------
    # Step 1: Initialise right index sets J[k].
    # J[k] holds the right pivots for the interface k → k+1; each row is
    # a (d-k-1)-dimensional multi-index drawn uniformly from the grid.
    # ------------------------------------------------------------------
    r_init = min(2, max_rank)

    J: list[np.ndarray] = []
    for k in range(d - 1):
        n_right = d - k - 1
        if n_right > 0:
            # Column j of rows holds random indices into mode k+1+j.
            rows = np.stack(
                [rng.integers(0, shape[k + 1 + j], size=r_init) for j in range(n_right)],
                axis=1,
            )
        else:
            rows = np.zeros((r_init, 0), dtype=int)
        J.append(rows)

    # ------------------------------------------------------------------
    # Step 2: Alternating sweeps to build/refine pivot sets.
    # I[k] holds left pivots at interface k-1 → k, shape (r_k, k);
    # I[0] is the single empty multi-index (left boundary rank is 1).
    # ------------------------------------------------------------------
    I: list[np.ndarray] = [np.zeros((1, 0), dtype=int)]

    for sweep in range(n_sweeps):
        # ---- Left-to-right: refine left pivot sets I ----
        for k in range(d - 1):
            r_l = I[k].shape[0]
            r_r = J[k].shape[0]
            n_k = shape[k]

            # Evaluate the cross fiber C: shape (r_l * n_k, r_r).
            C = _eval_fiber(fn, I[k], k, n_k, J[k], d)

            # QR + greedy maxvol on the orthogonal factor picks pivot rows.
            # NOTE(review): ``piv`` from the pivoted QR is unused here —
            # pivoting only affects Q's column ordering; confirm intent.
            r_candidate = min(max_rank, r_l * n_k, max(r_r, 1))
            Q_mat, _, piv = qr(C, pivoting=True, mode="economic")
            Q_r = Q_mat[:, :r_candidate]
            pivot_rows = _maxvol_greedy(Q_r, r_candidate, rng)

            # Decode each flat row index back into its (il, ik) pair and
            # extend the left multi-index by the mode-k component.
            r_new = len(pivot_rows)
            new_I = np.zeros((r_new, k + 1), dtype=int)
            for j, row in enumerate(pivot_rows):
                il_dec = int(row) // n_k
                ik_dec = int(row) % n_k
                if k > 0 and il_dec < I[k].shape[0]:
                    new_I[j, :k] = I[k][il_dec, :]
                new_I[j, k] = ik_dec

            # First sweep grows the I list; later sweeps overwrite in place.
            if sweep == 0:
                I.append(new_I)
            else:
                I[k + 1] = new_I

        # ---- Right-to-left: refine right pivot sets J ----
        for k in range(d - 2, -1, -1):
            r_l = I[k].shape[0]
            r_r = J[k].shape[0]
            n_k = shape[k]

            C = _eval_fiber(fn, I[k], k, n_k, J[k], d)

            # Row pivots of C (column pivots of C^T) select new right pivots.
            r_candidate = min(max_rank, r_l * n_k, max(r_r, 1))
            _, _, piv_col = qr(C.T, pivoting=True, mode="economic")
            pivot_rows = piv_col[:r_candidate]

            r_new = len(pivot_rows)
            n_right = d - k - 1
            new_J = np.zeros((r_new, n_right), dtype=int)
            for j, row in enumerate(pivot_rows):
                il_dec = int(row) // n_k
                ik_dec = int(row) % n_k
                # The new right multi-index starts with the mode-k component;
                # the remaining entries are borrowed from J[k+1].
                # NOTE(review): ``il_dec`` decodes the LEFT factor of C's row
                # index, yet it is used (with a modulo guard) to index
                # I[k+1]/J[k+1]. This pairing looks heuristic — verify
                # against a reference TT-cross/maxvol implementation.
                if n_right == 1:
                    new_J[j, 0] = ik_dec
                elif n_right > 1 and il_dec < I[k + 1].shape[0]:
                    # current right pivot is the k+1 index combined with J[k+1]
                    new_J[j, 0] = ik_dec
                    if k + 1 < len(J) and il_dec < J[k + 1].shape[0]:
                        new_J[j, 1:] = J[k + 1][il_dec % J[k + 1].shape[0], :]

            J[k] = new_J

    # ------------------------------------------------------------------
    # Step 3: Build final TT-cores via the cross-interpolation formula:
    # Core_k = C_k @ pinv(Z_k), reshaped to (r_{k-1}, n_k, r_k).
    # ------------------------------------------------------------------
    cores: list[np.ndarray] = []

    for k in range(d):
        n_k = shape[k]
        r_l = I[k].shape[0]

        if k < d - 1:
            r_r = J[k].shape[0]
            # Fiber: (r_l * n_k, r_r)
            C = _eval_fiber(fn, I[k], k, n_k, J[k], d)

            # Interface matrix Z: (|I[k+1]|, r_r)
            # NOTE(review): r_next is computed but unused below.
            r_next = I[k + 1].shape[0]
            Z = _eval_interface(fn, I[k + 1], J[k], d)

            # Core = C @ pinv(Z): shape (r_l * n_k, r_next);
            # pinv handles rank-deficient Z gracefully.
            Z_pinv = np.linalg.pinv(Z)  # (r_r, r_next)
            core_mat = C @ Z_pinv  # (r_l * n_k, r_next)

            # Filter numerical noise by zeroing singular values below
            # eps * s_max, then recomposing.
            # NOTE(review): the recomposition keeps the full
            # (r_l * n_k, r_next) shape — the bond dimension is NOT
            # reduced here, only denoised. Confirm that is intended.
            U, s, Vt = np.linalg.svd(core_mat, full_matrices=False)
            thresh = eps * s[0] if s[0] > 0 else eps
            r_trunc = max(1, int(np.sum(s > thresh)))
            r_trunc = min(r_trunc, max_rank)
            core_mat = (U[:, :r_trunc] * s[:r_trunc]) @ Vt[:r_trunc, :]

            # r_out equals r_next (see NOTE above).
            r_out = core_mat.shape[1]
            cores.append(core_mat.reshape(r_l, n_k, r_out))

        else:
            # Last core: fiber only, no interface matrix; right rank is 1.
            right_dummy = np.zeros((1, 0), dtype=int)
            C = _eval_fiber(fn, I[k], k, n_k, right_dummy, d)  # (r_l * n_k, 1)
            cores.append(C.reshape(r_l, n_k, 1))

    return cores
576
+