tensorquantlib 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. tensorquantlib/__init__.py +313 -0
  2. tensorquantlib/__main__.py +315 -0
  3. tensorquantlib/backtest/__init__.py +48 -0
  4. tensorquantlib/backtest/engine.py +240 -0
  5. tensorquantlib/backtest/metrics.py +320 -0
  6. tensorquantlib/backtest/strategy.py +348 -0
  7. tensorquantlib/core/__init__.py +6 -0
  8. tensorquantlib/core/ops.py +70 -0
  9. tensorquantlib/core/second_order.py +465 -0
  10. tensorquantlib/core/tensor.py +928 -0
  11. tensorquantlib/data/__init__.py +16 -0
  12. tensorquantlib/data/market.py +160 -0
  13. tensorquantlib/finance/__init__.py +52 -0
  14. tensorquantlib/finance/american.py +263 -0
  15. tensorquantlib/finance/basket.py +291 -0
  16. tensorquantlib/finance/black_scholes.py +219 -0
  17. tensorquantlib/finance/credit.py +199 -0
  18. tensorquantlib/finance/exotics.py +885 -0
  19. tensorquantlib/finance/fx.py +204 -0
  20. tensorquantlib/finance/greeks.py +133 -0
  21. tensorquantlib/finance/heston.py +543 -0
  22. tensorquantlib/finance/implied_vol.py +277 -0
  23. tensorquantlib/finance/ir_derivatives.py +203 -0
  24. tensorquantlib/finance/jump_diffusion.py +203 -0
  25. tensorquantlib/finance/local_vol.py +146 -0
  26. tensorquantlib/finance/rates.py +381 -0
  27. tensorquantlib/finance/risk.py +344 -0
  28. tensorquantlib/finance/variance_reduction.py +420 -0
  29. tensorquantlib/finance/volatility.py +355 -0
  30. tensorquantlib/py.typed +0 -0
  31. tensorquantlib/tt/__init__.py +43 -0
  32. tensorquantlib/tt/decompose.py +576 -0
  33. tensorquantlib/tt/ops.py +386 -0
  34. tensorquantlib/tt/pricing.py +304 -0
  35. tensorquantlib/tt/surrogate.py +634 -0
  36. tensorquantlib/utils/__init__.py +5 -0
  37. tensorquantlib/utils/validation.py +126 -0
  38. tensorquantlib/viz/__init__.py +27 -0
  39. tensorquantlib/viz/plots.py +331 -0
  40. tensorquantlib-0.3.0.dist-info/METADATA +602 -0
  41. tensorquantlib-0.3.0.dist-info/RECORD +44 -0
  42. tensorquantlib-0.3.0.dist-info/WHEEL +5 -0
  43. tensorquantlib-0.3.0.dist-info/licenses/LICENSE +21 -0
  44. tensorquantlib-0.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,634 @@
1
"""
TT-Surrogate — fast approximate option pricing via Tensor-Train compression.

The TTSurrogate class wraps the full pipeline:
1. Build a pricing grid (MC or analytic)
2. Compress it with TT-SVD
3. Evaluate prices at arbitrary points via multilinear interpolation on the TT grid
4. Compute Greeks by finite differences on the surrogate
   (``evaluate_tensor`` additionally returns an autograd ``Tensor``)

For 6+ assets, use from_function() which applies TT-Cross and never
builds the full grid.

Typical usage::

    from tensorquantlib.tt.surrogate import TTSurrogate

    # ≤5 assets — TT-SVD on a full Monte-Carlo pricing grid
    surr = TTSurrogate.from_basket_mc(
        S0_ranges=[(80, 120)] * 3,
        K=100, T=1.0, r=0.05, sigma=[0.2]*3,
        corr=np.eye(3), weights=[1/3]*3,
        n_points=30, eps=1e-4,
    )

    # 6+ assets — TT-Cross (no full grid)
    surr6 = TTSurrogate.from_function(
        fn=my_pricer,  # fn(*integer_indices) -> float
        axes=[np.linspace(80, 120, 15)] * 6,
        max_rank=15, eps=1e-4, n_sweeps=6,
    )

    price = surr.evaluate([100, 105, 95])
    greeks = surr.greeks([100, 105, 95])
"""
35
+
36
+ from __future__ import annotations
37
+
38
+ import time
39
+ from typing import Union
40
+
41
+ import numpy as np
42
+
43
+ from ..core.tensor import Tensor
44
+ from ..finance.basket import build_pricing_grid, build_pricing_grid_analytic
45
+ from .decompose import tt_svd, tt_cross
46
+ from .ops import (
47
+ tt_eval,
48
+ tt_eval_batch,
49
+ tt_memory,
50
+ tt_ranks,
51
+ )
52
+
53
+
54
+ class TTSurrogate:
55
+ """Tensor-Train surrogate pricing model.
56
+
57
+ Stores a TT-compressed pricing grid and provides fast evaluation
58
+ by mapping continuous spot prices to grid indices via linear interpolation.
59
+
60
+ Attributes:
61
+ cores: List of TT-cores.
62
+ axes: List of 1D arrays — grid points along each asset axis.
63
+ n_assets: Number of assets.
64
+ build_time: Time (sec) to build the pricing grid.
65
+ compress_time: Time (sec) to run TT-SVD.
66
+ eps: TT-SVD tolerance used.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ cores: list[np.ndarray],
72
+ axes: list[np.ndarray],
73
+ eps: float,
74
+ build_time: float = 0.0,
75
+ compress_time: float = 0.0,
76
+ original_shape: tuple[int, ...] | None = None,
77
+ original_nbytes: int | None = None,
78
+ ):
79
+ self.cores = cores
80
+ self.axes = axes
81
+ self.n_assets = len(axes)
82
+ self.eps = eps
83
+ self.build_time = build_time
84
+ self.compress_time = compress_time
85
+ self._original_shape = original_shape
86
+ self._original_nbytes = original_nbytes
87
+
88
+ # ── constructors ────────────────────────────────────────────────────
89
+
90
+ @classmethod
91
+ def from_grid(
92
+ cls,
93
+ grid: np.ndarray,
94
+ axes: list[np.ndarray],
95
+ eps: float = 1e-4,
96
+ max_rank: int | None = None,
97
+ ) -> TTSurrogate:
98
+ """Build surrogate from a pre-computed pricing grid.
99
+
100
+ Args:
101
+ grid: Full tensor of prices, shape (n1, n2, ..., nd).
102
+ axes: List of 1D arrays for each axis.
103
+ eps: TT-SVD tolerance.
104
+ max_rank: Maximum TT-rank.
105
+
106
+ Returns:
107
+ TTSurrogate instance.
108
+ """
109
+ original_shape = grid.shape
110
+ original_nbytes = grid.nbytes
111
+
112
+ if grid.ndim < 2:
113
+ raise ValueError(f"Grid must be at least 2D, got {grid.ndim}D")
114
+ if len(axes) != grid.ndim:
115
+ raise ValueError(
116
+ f"Number of axes ({len(axes)}) must match grid dimensions ({grid.ndim})"
117
+ )
118
+ for i, (ax, n) in enumerate(zip(axes, grid.shape)):
119
+ if len(ax) != n:
120
+ raise ValueError(
121
+ f"Axis {i} length ({len(ax)}) doesn't match grid size ({n})"
122
+ )
123
+ if eps <= 0:
124
+ raise ValueError(f"eps must be positive, got {eps}")
125
+
126
+ t0 = time.perf_counter()
127
+ cores = tt_svd(grid, eps=eps, max_rank=max_rank)
128
+ compress_time = time.perf_counter() - t0
129
+
130
+ return cls(
131
+ cores=cores,
132
+ axes=axes,
133
+ eps=eps,
134
+ compress_time=compress_time,
135
+ original_shape=original_shape,
136
+ original_nbytes=original_nbytes,
137
+ )
138
+
139
+ @classmethod
140
+ def from_basket_analytic(
141
+ cls,
142
+ S0_ranges: list[tuple[float, float]],
143
+ K: float,
144
+ T: float,
145
+ r: float,
146
+ sigma: list[float],
147
+ weights: list[float],
148
+ n_points: int = 30,
149
+ eps: float = 1e-4,
150
+ max_rank: int | None = None,
151
+ ) -> TTSurrogate:
152
+ """Build surrogate from analytic basket pricing grid.
153
+
154
+ Uses weighted Black-Scholes approximation — fast but approximate.
155
+
156
+ Args:
157
+ S0_ranges: [(lo, hi)] per asset.
158
+ K: Strike.
159
+ T: Maturity.
160
+ r: Risk-free rate.
161
+ sigma: Volatilities per asset.
162
+ weights: Portfolio weights.
163
+ n_points: Grid points per axis.
164
+ eps: TT-SVD tolerance.
165
+ max_rank: Maximum TT-rank.
166
+
167
+ Returns:
168
+ TTSurrogate instance.
169
+ """
170
+ t0 = time.perf_counter()
171
+ grid, axes = build_pricing_grid_analytic(
172
+ S0_ranges=S0_ranges,
173
+ K=K, T=T, r=r, sigma=np.asarray(sigma),
174
+ weights=np.asarray(weights), n_points=n_points,
175
+ )
176
+ build_time = time.perf_counter() - t0
177
+
178
+ original_shape = grid.shape
179
+ original_nbytes = grid.nbytes
180
+
181
+ t1 = time.perf_counter()
182
+ cores = tt_svd(grid, eps=eps, max_rank=max_rank)
183
+ compress_time = time.perf_counter() - t1
184
+
185
+ return cls(
186
+ cores=cores,
187
+ axes=axes,
188
+ eps=eps,
189
+ build_time=build_time,
190
+ compress_time=compress_time,
191
+ original_shape=original_shape,
192
+ original_nbytes=original_nbytes,
193
+ )
194
+
195
+ @classmethod
196
+ def from_basket_mc(
197
+ cls,
198
+ S0_ranges: list[tuple[float, float]],
199
+ K: float,
200
+ T: float,
201
+ r: float,
202
+ sigma: list[float],
203
+ corr: np.ndarray,
204
+ weights: list[float],
205
+ n_points: int = 30,
206
+ n_mc_paths: int = 50_000,
207
+ eps: float = 1e-4,
208
+ max_rank: int | None = None,
209
+ ) -> TTSurrogate:
210
+ """Build surrogate from Monte-Carlo basket pricing grid.
211
+
212
+ Slow but accurate. Suitable for validation.
213
+ """
214
+ t0 = time.perf_counter()
215
+ grid, axes = build_pricing_grid(
216
+ S0_ranges=S0_ranges,
217
+ K=K, T=T, r=r, sigma=np.asarray(sigma),
218
+ corr=corr, weights=np.asarray(weights),
219
+ n_points=n_points, n_mc_paths=n_mc_paths,
220
+ )
221
+ build_time = time.perf_counter() - t0
222
+
223
+ original_shape = grid.shape
224
+ original_nbytes = grid.nbytes
225
+
226
+ t1 = time.perf_counter()
227
+ cores = tt_svd(grid, eps=eps, max_rank=max_rank)
228
+ compress_time = time.perf_counter() - t1
229
+
230
+ return cls(
231
+ cores=cores,
232
+ axes=axes,
233
+ eps=eps,
234
+ build_time=build_time,
235
+ compress_time=compress_time,
236
+ original_shape=original_shape,
237
+ original_nbytes=original_nbytes,
238
+ )
239
+
240
+ @classmethod
241
+ def from_function(
242
+ cls,
243
+ fn: object,
244
+ axes: list[np.ndarray],
245
+ eps: float = 1e-4,
246
+ max_rank: int = 20,
247
+ n_sweeps: int = 6,
248
+ seed: int = 42,
249
+ ) -> "TTSurrogate":
250
+ """Build surrogate via TT-Cross — **no full grid needed**.
251
+
252
+ This is the recommended constructor for **6+ asset** problems.
253
+ TT-Cross samples the pricing function at O(d · r² · n) selected
254
+ index combinations instead of the full n^d grid, making
255
+ high-dimensional problems feasible.
256
+
257
+ Parameters
258
+ ----------
259
+ fn : callable
260
+ Function accepting ``d`` integer grid-index arguments and
261
+ returning a float price::
262
+
263
+ fn(i_0, i_1, ..., i_{d-1}) -> float
264
+
265
+ The simplest way to build this is to pre-compute an axis
266
+ array for continuous spots and index into it inside ``fn``.
267
+ Example::
268
+
269
+ axes = [np.linspace(80, 120, 15)] * 6
270
+
271
+ def my_pricer(*indices):
272
+ spots = [axes[k][i] for k, i in enumerate(indices)]
273
+ return basket_mc(spots, K, T, r, sigma, corr)
274
+
275
+ surr = TTSurrogate.from_function(my_pricer, axes)
276
+
277
+ axes : list of np.ndarray
278
+ 1D grid arrays, one per asset. ``len(axes)`` is the number
279
+ of assets. ``axes[k][i]`` gives the spot price at index ``i``
280
+ for asset ``k``.
281
+ eps : float
282
+ Relative accuracy target passed to TT-Cross.
283
+ max_rank : int
284
+ Hard upper bound on TT-ranks. Increase if accuracy is
285
+ insufficient; decrease if speed is the priority.
286
+ n_sweeps : int
287
+ Number of left-to-right + right-to-left alternating sweeps.
288
+ Default 6 is sufficient for smooth option pricing surfaces.
289
+ seed : int
290
+ Random seed for TT-Cross initialisation.
291
+
292
+ Returns
293
+ -------
294
+ TTSurrogate
295
+ """
296
+ from collections.abc import Callable as _Callable
297
+ if not callable(fn):
298
+ raise TypeError(f"fn must be callable, got {type(fn)}")
299
+ if len(axes) < 2:
300
+ raise ValueError("from_function requires at least 2 axes (2 assets)")
301
+
302
+ shape = tuple(len(a) for a in axes)
303
+ # Total function evaluations (approximate)
304
+ _n_evals = len(axes) * max_rank ** 2 * max(shape)
305
+
306
+ t0 = time.perf_counter()
307
+ cores = tt_cross(
308
+ fn=fn, # type: ignore[arg-type]
309
+ shape=shape,
310
+ eps=eps,
311
+ max_rank=max_rank,
312
+ n_sweeps=n_sweeps,
313
+ seed=seed,
314
+ )
315
+ compress_time = time.perf_counter() - t0
316
+
317
+ return cls(
318
+ cores=cores,
319
+ axes=axes,
320
+ eps=eps,
321
+ build_time=0.0, # No separate grid build step
322
+ compress_time=compress_time,
323
+ original_shape=None, # Full grid was never formed
324
+ original_nbytes=None,
325
+ )
326
+
327
+ # ── evaluation ──────────────────────────────────────────────────────
328
+
329
+
330
+ def _spot_to_indices(self, spots: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
331
+ """Map continuous spot prices to fractional grid indices.
332
+
333
+ Returns integer indices (floored) and interpolation weights.
334
+ Uses linear interpolation between adjacent grid points.
335
+
336
+ Args:
337
+ spots: Array of shape (d,) or (n, d).
338
+
339
+ Returns:
340
+ (indices_lo, weights) both of shape matching spots.
341
+ """
342
+ spots = np.atleast_2d(spots) # (n, d)
343
+ indices_lo = np.zeros_like(spots, dtype=int)
344
+ weights = np.zeros_like(spots, dtype=float)
345
+
346
+ for k in range(self.n_assets):
347
+ axis = self.axes[k]
348
+ n_k = len(axis)
349
+ # Clamp to valid range
350
+ s = np.clip(spots[:, k], axis[0], axis[-1])
351
+ # Find interval index
352
+ idx = np.searchsorted(axis, s, side="right") - 1
353
+ idx = np.clip(idx, 0, n_k - 2)
354
+ # Interpolation weight within interval
355
+ lo = axis[idx]
356
+ hi = axis[idx + 1]
357
+ w = np.where(hi > lo, (s - lo) / (hi - lo), 0.0)
358
+
359
+ indices_lo[:, k] = idx
360
+ weights[:, k] = w
361
+
362
+ return indices_lo, weights
363
+
364
+ def evaluate(self, spots: Union[np.ndarray, list[float]]) -> Union[float, np.ndarray]:
365
+ """Evaluate the surrogate price at given spot prices.
366
+
367
+ Uses multi-linear interpolation on the TT grid.
368
+
369
+ Args:
370
+ spots: Spot prices — shape (d,) for single point, (n, d) for batch.
371
+
372
+ Returns:
373
+ Price(s) — scalar for single, array for batch.
374
+ """
375
+ spots = np.asarray(spots, dtype=float)
376
+ single = spots.ndim == 1
377
+ spots = np.atleast_2d(spots)
378
+ n_points = spots.shape[0]
379
+
380
+ indices_lo, weights = self._spot_to_indices(spots)
381
+
382
+ # Multi-linear interpolation: sum over 2^d corners
383
+ d = self.n_assets
384
+ result = np.zeros(n_points)
385
+
386
+ for corner in range(2**d):
387
+ idx = indices_lo.copy()
388
+ w = np.ones(n_points)
389
+ for k in range(d):
390
+ if corner & (1 << k):
391
+ idx[:, k] = np.minimum(idx[:, k] + 1, len(self.axes[k]) - 1)
392
+ w *= weights[:, k]
393
+ else:
394
+ w *= (1.0 - weights[:, k])
395
+
396
+ vals = tt_eval_batch(self.cores, idx)
397
+ result += w * vals
398
+
399
+ return float(result[0]) if single else result
400
+
401
+ def evaluate_tensor(self, spots: Union[np.ndarray, list[float]]) -> Tensor:
402
+ """Evaluate surrogate price and return a Tensor for autograd.
403
+
404
+ This enables computing Greeks via backward().
405
+
406
+ Args:
407
+ spots: Spot prices — shape (d,).
408
+
409
+ Returns:
410
+ Tensor with computed price (supports backward).
411
+ """
412
+ spots_arr = np.asarray(spots, dtype=float)
413
+ assert spots_arr.ndim == 1, "evaluate_tensor expects a single point (1D)"
414
+
415
+ # Convert spots to Tensor objects
416
+ spot_tensors = [Tensor(np.array([s])) for s in spots_arr]
417
+
418
+ indices_lo, weights_np = self._spot_to_indices(spots_arr.reshape(1, -1))
419
+ indices_lo = indices_lo[0] # (d,)
420
+ weights_np = weights_np[0] # (d,)
421
+
422
+ # Create weight tensors for autodiff
423
+ weight_tensors = []
424
+ for k in range(self.n_assets):
425
+ axis = self.axes[k]
426
+ idx = indices_lo[k]
427
+ lo_val = axis[idx]
428
+ hi_idx = min(idx + 1, len(axis) - 1)
429
+ hi_val = axis[hi_idx]
430
+ if hi_val > lo_val:
431
+ wt = (spot_tensors[k] - Tensor(np.array([lo_val]))) / Tensor(np.array([hi_val - lo_val]))
432
+ else:
433
+ wt = Tensor(np.array([0.0]))
434
+ weight_tensors.append(wt)
435
+
436
+ # Multi-linear interpolation with Tensor arithmetic
437
+ d = self.n_assets
438
+ result = Tensor(np.array([0.0]))
439
+
440
+ for corner in range(2**d):
441
+ idx = indices_lo.copy()
442
+ w = Tensor(np.array([1.0]))
443
+ for k in range(d):
444
+ if corner & (1 << k):
445
+ idx[k] = min(idx[k] + 1, len(self.axes[k]) - 1)
446
+ w = w * weight_tensors[k]
447
+ else:
448
+ w = w * (Tensor(np.array([1.0])) - weight_tensors[k])
449
+
450
+ val = tt_eval(self.cores, tuple(int(i) for i in idx))
451
+ result = result + w * Tensor(np.array([val]))
452
+
453
+ return result
454
+
455
+ def greeks(self, spots: Union[np.ndarray, list[float]], h: float = 1e-4) -> dict[str, object]:
456
+ """Compute Greeks via autograd through the surrogate.
457
+
458
+ Delta: ∂price/∂S_i for each asset (via autograd).
459
+ Gamma: (Delta(S+h) - Delta(S-h)) / 2h (finite-diff on Delta).
460
+
461
+ Args:
462
+ spots: Spot prices (1D).
463
+ h: Relative bump for Gamma (h_abs = S_i * h).
464
+
465
+ Returns:
466
+ Dict with 'price', 'delta' (array), 'gamma' (array).
467
+ """
468
+ spots = np.asarray(spots, dtype=float)
469
+ d = len(spots)
470
+
471
+ # Delta via autograd
472
+ price_t = self.evaluate_tensor(spots)
473
+ price_t.backward()
474
+
475
+ price = price_t.item()
476
+ delta = np.zeros(d)
477
+ for _k in range(d):
478
+ # delta[k] = ∂price/∂S_k
479
+ # We need to trace through from evaluate_tensor
480
+ pass
481
+
482
+ # Use finite differences for both delta and gamma (more robust)
483
+ delta = np.zeros(d)
484
+ gamma = np.zeros(d)
485
+ for k in range(d):
486
+ h_abs = max(spots[k] * h, 1e-6)
487
+
488
+ s_up = spots.copy()
489
+ s_up[k] += h_abs
490
+ s_dn = spots.copy()
491
+ s_dn[k] -= h_abs
492
+
493
+ p_up = self.evaluate(s_up)
494
+ p_dn = self.evaluate(s_dn)
495
+
496
+ delta[k] = (p_up - p_dn) / (2 * h_abs)
497
+ gamma[k] = (p_up - 2 * price + p_dn) / (h_abs**2)
498
+
499
+ return {"price": price, "delta": delta, "gamma": gamma}
500
+
501
+ # ── visualization ────────────────────────────────────────────────────
502
+
503
+ def plot_surface(
504
+ self,
505
+ dims: tuple[int, int] = (0, 1),
506
+ fixed_indices: dict[int, int] | None = None,
507
+ title: str = "Pricing Surface",
508
+ mode: str = "heatmap",
509
+ **kwargs: object,
510
+ ) -> object:
511
+ """Plot a 2D pricing surface slice.
512
+
513
+ Evaluates the full pricing grid from TT-cores and plots a 2D
514
+ heatmap or 3D surface. Any extra keyword arguments are forwarded
515
+ to ``plot_pricing_surface``.
516
+
517
+ Args:
518
+ dims: Which two asset axes to plot (default: first two).
519
+ fixed_indices: Override slice indices for remaining axes.
520
+ title: Plot title.
521
+ mode: ``"heatmap"`` (default) or ``"surface"`` (3D).
522
+
523
+ Returns:
524
+ ``(fig, ax)`` matplotlib tuple.
525
+ """
526
+ from tensorquantlib.viz.plots import plot_pricing_surface
527
+ from tensorquantlib.tt.ops import tt_to_full
528
+
529
+ grid = tt_to_full(self.cores)
530
+ return plot_pricing_surface(
531
+ grid, self.axes, dims=dims,
532
+ fixed_indices=fixed_indices, title=title, mode=mode,
533
+ **kwargs, # type: ignore[arg-type]
534
+ )
535
+
536
+ def plot_greeks(
537
+ self,
538
+ dims: tuple[int, int] = (0, 1),
539
+ fixed_indices: dict[int, int] | None = None,
540
+ h: float = 1e-2,
541
+ **kwargs: object,
542
+ ) -> object:
543
+ """Plot Delta and Gamma surfaces as side-by-side heatmaps.
544
+
545
+ Computes Greek grids via finite differences on the TT surrogate
546
+ and plots them using ``plot_greeks_surface``.
547
+
548
+ Args:
549
+ dims: Which two axes to plot.
550
+ fixed_indices: Override slice indices for remaining axes.
551
+ h: Relative bump for finite-difference Greeks (h_abs = S * h).
552
+
553
+ Returns:
554
+ ``(fig, axes)`` matplotlib tuple.
555
+ """
556
+ from tensorquantlib.viz.plots import plot_greeks_surface
557
+ from tensorquantlib.tt.ops import tt_to_full
558
+
559
+ grid = tt_to_full(self.cores)
560
+ d = grid.ndim
561
+
562
+ # Build delta grids for each asset axis via finite differences
563
+ delta_grids: dict[str, np.ndarray] = {}
564
+ for k in range(min(d, len(dims))):
565
+ axis_k = self.axes[dims[k]]
566
+ # Numerical derivative along axis dims[k]
567
+ delta_k = np.gradient(grid, axis_k, axis=dims[k])
568
+ label = f"Delta (axis {dims[k]})"
569
+ delta_grids[label] = delta_k
570
+
571
+ return plot_greeks_surface(
572
+ delta_grids, self.axes, dims=dims,
573
+ fixed_indices=fixed_indices,
574
+ **kwargs, # type: ignore[arg-type]
575
+ )
576
+
577
+ def plot_ranks(self, **kwargs: object) -> object:
578
+ """Bar chart of TT-ranks across bonds.
579
+
580
+ Returns:
581
+ ``(fig, ax)`` matplotlib tuple.
582
+ """
583
+ from tensorquantlib.viz.plots import plot_tt_ranks
584
+
585
+ return plot_tt_ranks(self.cores, **kwargs) # type: ignore[arg-type]
586
+
587
+ # ── diagnostics ─────────────────────────────────────────────────────
588
+
589
+ def summary(self) -> dict[str, object]:
590
+ """Return diagnostic summary of the surrogate model.
591
+
592
+ Returns:
593
+ Dict with ranks, memory, compression_ratio, timings, etc.
594
+ """
595
+ ranks = tt_ranks(self.cores)
596
+ tt_mem = tt_memory(self.cores)
597
+
598
+ info = {
599
+ "n_assets": self.n_assets,
600
+ "grid_shape": tuple(len(a) for a in self.axes),
601
+ "tt_ranks": ranks,
602
+ "max_rank": max(ranks),
603
+ "tt_memory_bytes": tt_mem,
604
+ "tt_memory_KB": tt_mem / 1024,
605
+ "eps": self.eps,
606
+ "build_time_s": self.build_time,
607
+ "compress_time_s": self.compress_time,
608
+ }
609
+
610
+ if self._original_nbytes is not None:
611
+ info["full_memory_bytes"] = self._original_nbytes
612
+ info["full_memory_KB"] = self._original_nbytes / 1024
613
+ info["compression_ratio"] = self._original_nbytes / tt_mem if tt_mem > 0 else float("inf")
614
+
615
+ return info
616
+
617
+ def print_summary(self) -> None:
618
+ """Print a formatted diagnostic summary."""
619
+ s = self.summary()
620
+ print("=" * 60)
621
+ print("TT-Surrogate Summary")
622
+ print("=" * 60)
623
+ print(f" Assets: {s['n_assets']}")
624
+ print(f" Grid shape: {s['grid_shape']}")
625
+ print(f" TT-ranks: {s['tt_ranks']}")
626
+ print(f" Max TT-rank: {s['max_rank']}")
627
+ print(f" TT memory: {s['tt_memory_KB']:.2f} KB")
628
+ if "full_memory_KB" in s:
629
+ print(f" Full grid memory: {s['full_memory_KB']:.2f} KB")
630
+ print(f" Compression: {s['compression_ratio']:.1f}×")
631
+ print(f" TT-SVD tolerance: {s['eps']}")
632
+ print(f" Grid build time: {s['build_time_s']:.3f} s")
633
+ print(f" TT-SVD time: {s['compress_time_s']:.3f} s")
634
+ print("=" * 60)
@@ -0,0 +1,5 @@
1
"""Utility helpers for validating gradients and running numerical checks."""

from .validation import (
    check_grad,
    numerical_gradient,
)

# Public re-exports of this subpackage.
__all__ = [
    "check_grad",
    "numerical_gradient",
]