ipc-module 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
1
+ Copyright (c) 2025 Katsuma Inoue.
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4
+
5
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
+
7
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@@ -0,0 +1,22 @@
1
+ Metadata-Version: 2.4
2
+ Name: ipc-module
3
+ Version: 0.1.2
4
+ Summary: GPGPU-accelerated information processing capacity (IPC) module
5
+ Requires-Python: >=3.12
6
+ Description-Content-Type: text/markdown
7
+ License-File: LICENSE.txt
8
+ Requires-Dist: ipython>=9.7.0
9
+ Requires-Dist: joblib>=1.5.2
10
+ Requires-Dist: matplotlib>=3.10.7
11
+ Requires-Dist: numpy>=2.3.4
12
+ Requires-Dist: polars>=1.35.2
13
+ Requires-Dist: tqdm>=4.67.1
14
+ Provides-Extra: gpu
15
+ Requires-Dist: cupy>=13.6.0; extra == "gpu"
16
+ Requires-Dist: torch>=2.9.1; extra == "gpu"
17
+ Dynamic: license-file
18
+
19
+ # `ipc-module`
20
+
21
+ `ipc-module` is a Python library for measuring and analyzing information processing capacity (IPC), an indicator proposed in reservoir computing.
22
+ See the [documentation](https://rc-bootcamp.github.io/ipc-module/) for more details.
@@ -0,0 +1,4 @@
1
+ # `ipc-module`
2
+
3
+ `ipc-module` is a Python library for measuring and analyzing information processing capacity (IPC), an indicator proposed in reservoir computing.
4
+ See the [documentation](https://rc-bootcamp.github.io/ipc-module/) for more details.
@@ -0,0 +1,45 @@
1
+ [project]
2
+ name = "ipc-module"
3
+ version = "0.1.2"
4
+ description = "GPGPU-accelerated information processing capacity (IPC) module"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "ipython>=9.7.0",
9
+ "joblib>=1.5.2",
10
+ "matplotlib>=3.10.7",
11
+ "numpy>=2.3.4",
12
+ "polars>=1.35.2",
13
+ "tqdm>=4.67.1",
14
+ ]
15
+
16
+ [project.optional-dependencies]
17
+ gpu = [
18
+ "cupy>=13.6.0",
19
+ "torch>=2.9.1",
20
+ ]
21
+
22
+ [dependency-groups]
23
+ dev = [
24
+ "ipykernel>=7.1.0",
25
+ "ipywidgets>=8.1.8",
26
+ "matplotlib>=3.10.7",
27
+ "nbconvert>=7.16.6",
28
+ "nbqa>=1.9.1",
29
+ "ruff>=0.14.5",
30
+ "twine>=6.2.0",
31
+ ]
32
+ mkdocs = [
33
+ "mkdocs>=1.6.1",
34
+ "mkdocs-material>=9.7.0",
35
+ "mkdocstrings>=0.30.1",
36
+ "mkdocstrings-python>=1.19.0",
37
+ ]
38
+
39
+ [tool.ruff]
40
+ line-length = 100
41
+
42
+ [tool.ruff.lint]
43
+ preview = true
44
+ ignore = ["B018", "E402", "E501"]
45
+ select = ["B", "E", "W", "I", "F", "NPY"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
File without changes
@@ -0,0 +1,19 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright (c) 2025, Katsuma Inoue. All rights reserved.
5
+ # This code is licensed under the MIT License.
6
+
7
+
8
# Module-level default; toggled at runtime via `set_progress_bar()`.
SHOW_PROGRESS_BAR = True


def set_progress_bar(show: bool) -> None:
    """
    Set the global flag to show or hide the progress bar.

    Parameters:
        show (bool): If True, show the progress bar; otherwise, hide it.
    """
    global SHOW_PROGRESS_BAR
    SHOW_PROGRESS_BAR = show
@@ -0,0 +1,372 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright (c) 2025, Katsuma Inoue. All rights reserved.
5
+ # This code is licensed under the MIT License.
6
+
7
+ import functools
8
+ import importlib
9
+ import inspect
10
+ import itertools
11
+ import math
12
+ import sys
13
+ from collections import Counter, defaultdict
14
+
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+ import polars as pl
18
+ from matplotlib.axes import Axes
19
+
20
+
21
def get_backend_name(mat):
    """Return the module name of `mat`'s class (e.g. 'numpy', 'torch')."""
    return type(mat).__module__
23
+
24
+
25
def import_backend(mat):
    """
    Import and return the array-backend module that `mat`'s class belongs to
    (e.g. numpy, cupy, torch).

    Raises:
        ModuleNotFoundError: If the backend module cannot be imported.
    """
    # Bind `name` outside the try block: only the import itself can raise
    # ImportError, and `name` is then always available for the error message.
    name = get_backend_name(mat)
    try:
        return importlib.import_module(name)
    except ImportError:
        raise ModuleNotFoundError(f"{name} is not found") from None
32
+
33
+
34
def backend_max(mat, *args, **kwargs):
    """
    Backend-agnostic `max`: for torch, unwrap the (values, indices) named
    tuple so the result matches numpy/cupy semantics.
    """
    res = mat.max(*args, **kwargs)
    if type(mat).__module__ == "torch":
        torch = importlib.import_module("torch")
        # torch.Tensor.max(dim=...) returns a (values, indices) pair;
        # see https://pytorch.org/docs/stable/generated/torch.max.html
        if type(res) is torch.return_types.max:
            res = res.values
    return res
42
+
43
+
44
def backend_std(mat, *args, **kwargs):
    """
    Backend-agnostic population standard deviation: torch defaults to the
    unbiased estimator, so force `unbiased=False` there.
    """
    if type(mat).__module__ == "torch":
        return mat.std(*args, unbiased=False, **kwargs)
    return mat.std(*args, **kwargs)
49
+
50
+
51
def zeros_like(mat, shape=None):
    """
    Return a zero array on `mat`'s backend: same shape as `mat` by default,
    or the requested `shape` (dtype inherited from `mat`).
    """
    backend = importlib.import_module(type(mat).__module__)
    if shape is not None:
        return backend.zeros_like(backend.broadcast_to(mat.flatten()[0], shape))
    return backend.zeros_like(backend.broadcast_to(mat, mat.shape))
57
+
58
+
59
@functools.cache
def make_degree_list(s, m=None):
    """
    Return every integer partition of `s` as a non-increasing tuple, with each
    part bounded above by `m` (defaults to `s`, i.e. unbounded).
    """
    cap = s if m is None else m
    if s == 0:
        return [()]
    partitions = []
    for head in range(1, min(s, cap) + 1):
        partitions.extend((head, *tail) for tail in make_degree_list(s - head, head))
    return partitions
70
+
71
+
72
def multi_combination(arr, *num):
    """
    Yield tuples of disjoint combinations drawn from `arr`: the first entry
    takes `num[0]` elements, the next takes `num[1]` from the remainder, etc.
    """
    if len(num) == 0:
        yield ()
        return
    size, take = len(arr), num[0]
    if take < 0 or take > size:
        return
    # Enumerate by index so the leftover pool keeps original order (and
    # duplicates) intact for the recursive groups.
    for picked in itertools.combinations(range(size), take):
        chosen = tuple(arr[i] for i in picked)
        left_over = tuple(arr[i] for i in range(size) if i not in picked)
        for tail in multi_combination(left_over, *num[1:]):
            yield chosen, *tail
84
+
85
+
86
@functools.cache
def multi_combination_length(length, *num):
    """
    Number of tuples `multi_combination` would yield for a pool of `length`
    elements and the given group sizes, computed without enumeration.
    """
    if not num:
        return 1
    head, *rest = num
    if length < head:
        return 0
    return math.comb(length, head) * multi_combination_length(length - head, *rest)
94
+
95
+
96
@functools.cache
def _flattened_delay_lists(delay_range, degree_counts):
    # Memoized worker: `degree_counts` must be a hashable tuple for the cache.
    return list(map(lambda v: sum(v, ()), multi_combination(delay_range, *degree_counts)))


def make_delay_list(delay_range, degree_tuple):
    """
    Enumerate all delay assignments for `degree_tuple`, flattening each
    multi-combination of delays into a single tuple.

    Parameters:
        delay_range (range): Hashable iterable of candidate delays.
        degree_tuple (tuple): Degrees; equal degrees share one combination group.

    Notes:
        The previous implementation re-created a `functools.cache`-decorated
        closure on every call and keyed it on `Counter(...).values()` (a fresh,
        identity-hashed view object), so memoization never took effect. Caching
        a module-level helper keyed on a tuple of counts fixes that; results
        are unchanged.
    """
    return _flattened_delay_lists(delay_range, tuple(Counter(degree_tuple).values()))
102
+
103
+
104
@functools.cache
def make_permutation(length, seed=None, rnd=None, tolist=True):
    """
    Return a shuffled permutation of ``range(length)`` as int32 values.

    Parameters:
        length (int): Number of elements to permute.
        seed (int | None, optional): Seed for the default RNG when `rnd` is not given.
        rnd (np.random.Generator | None, optional): RNG to draw from; advances its state.
        tolist (bool, optional): If True, return a Python list; otherwise a numpy array.

    Notes:
        Results are memoized by ``functools.cache`` on the full argument tuple,
        so repeated calls with identical arguments return the SAME object —
        callers must not mutate the result. A passed-in `rnd` is cache-keyed by
        identity, so a second call with the same generator returns the cached
        permutation even though the generator's state has advanced — presumably
        intentional for reproducibility, but confirm before relying on it.
    """
    perm = np.arange(length, dtype=np.int32)
    if rnd is None:
        rnd = np.random.default_rng(seed)
    rnd.shuffle(perm)
    if tolist:
        return perm.tolist()
    else:
        return perm
114
+
115
+
116
def count_ipc_candidates(
    degree_sum: int | list = 1,
    delay_max: int | list = 100,
    maximum_component: int | None = None,
    zero_offset: bool = True,
    surrogate_num: int = 1000,
):
    """
    Count the total number of IPC candidate evaluations (regressor calls).

    Parameters:
        degree_sum (int | list): Degree sum(s) to enumerate; a list is deduplicated and sorted.
        delay_max (int | list): Maximum delay(s), paired element-wise with `degree_sum`.
        maximum_component (int | None, optional): If given, drop degree tuples
            with more than this many components.
        zero_offset (bool, optional): If True, delays start from 0; otherwise from 1.
        surrogate_num (int, optional): Surrogate evaluations added per non-empty degree family.

    Returns:
        int: Total number of candidate evaluations.
    """
    # Sum of degrees to be calculated.
    if isinstance(degree_sum, list):
        degree_sum = np.sort(np.unique(degree_sum)).tolist()
    else:
        degree_sum = [degree_sum]
    # Range of delays to be calculated.
    if isinstance(delay_max, int):
        delay_max = [delay_max]
    degree_tuples, delay_ranges = [], []
    for degree, delay in zip(degree_sum, delay_max, strict=False):
        degree_tuple = make_degree_list(degree)
        if maximum_component is not None:
            degree_tuple = list(filter(lambda t: len(t) <= maximum_component, degree_tuple))
        degree_tuples += degree_tuple
        delay_ranges += [range(0 if zero_offset else 1, delay + 1)] * len(degree_tuple)
    # Calculate the number of iterations (# of regressor calls).
    total_length = [
        multi_combination_length(len(delay_range), *Counter(degree_tuple).values())
        for degree_tuple, delay_range in zip(degree_tuples, delay_ranges, strict=False)
    ]
    # Families with zero combinations contribute neither candidates nor surrogates.
    total_length = [v for v in total_length if v > 0]
    total_length = sum(total_length) + surrogate_num * len(total_length)
    return total_length
146
+
147
+
148
def truncate_tuple(tup: tuple[int], max_length=5):
    """Format a tuple as a string, abbreviating with '...' past `max_length` items."""
    if len(tup) > max_length:
        head = ", ".join(str(v) for v in tup[:max_length])
        return f"({head}, ...)"
    return str(tup)
153
+
154
+
155
def truncate_dataframe(df: pl.DataFrame, key="ipc", rank=None):
    """
    Drop rows whose `key` value is non-positive; when the surviving sum exceeds
    `rank`, keep only the largest rows whose cumulative sum stays below `rank`.
    Returns the non-ipc columns plus `key`.
    """
    keep_cols = [c for c in df.columns if "ipc" not in c]
    out = df.filter(pl.col(key) > 0)
    if rank is not None and out[key].sum() > rank:
        out = out.sort(key, descending=True)
        out = out.with_columns(pl.col(key).cum_sum().alias("cum"))
        out = out.filter(out["cum"] < rank)
    return out[[*keep_cols, key]]
163
+
164
+
165
def visualize_dataframe(
    ax: Axes,
    df: pl.DataFrame,
    xticks: list | np.ndarray | None = None,
    group_by: str = "degree",
    threshold: float = 0.5,
    sort_by: callable = np.nanmax,
    cmap: str | plt.Colormap = "tab10",
    x_offset: float = 0,
    min_color_coef: float = 0.5,
    fontsize: int = 12,
    step_linewidth: float = 0.5,
    bottom_min: np.ndarray | None = None,
    zero_offset: bool = True,
):
    """
    Visualizes IPC results stored in a DataFrame using a bar plot.

    Parameters:
        ax (Axes): Matplotlib Axes object to plot on.
        df (pl.DataFrame): `polars.DataFrame` containing IPC results.
        xticks (list | np.ndarray | None, optional): X-axis tick positions. If `None`, default positions are used.
        group_by (str, optional): Grouping method for IPC components. Choose from 'degree', 'component', or 'detail'.
        threshold (float, optional): Threshold for displaying IPC components. Components with values below this threshold are grouped into 'rest'.
        sort_by (callable, optional): Function to sort IPC components.
        cmap (str | plt.Colormap, optional): Colormap for coloring IPC components.
        x_offset (float, optional): Horizontal offset for the x-axis.
        min_color_coef (float, optional): Minimum color coefficient for coloring. Only used when `group_by` is 'component' or 'detail'.
        fontsize (int, optional): Font size for labels.
        step_linewidth (float, optional): Line width for step lines. If 0, no lines are drawn.
        bottom_min (np.ndarray | None, optional): Minimum bottom values for bars.
        zero_offset (bool, optional): Whether the delay offset starts from zero.

    Returns:
        dict: Mapping of segment tuple -> list of aggregated IPC values (one per ipc column).

    Notes:
        `group_by` determines how IPC components are grouped and colored:

        - `'degree'`: Groups by sum of degrees (e.g., `3` for `(3,)`, `(2, 1)`, `(1, 1, 1)`).
        - `'component'`: Groups by tuple of degrees (e.g., `(3,)`, `(2, 1)`, `(1, 1, 1)` are distinct).
        - `'detail'`: Groups by tuple of degrees and delays.

        Since the number of unique components can grow rapidly, using `'component'` or `'detail'` may result in many distinct colors, making it time-consuming to render.
        Especially for `'detail'`, consider setting a higher threshold to limit the number of displayed components (e.g., `threshold=1.0`).
        Use a positive `threshold` value to group less significant components into a `rest` category.
    """

    ipc_columns = [column for column in df.columns if "ipc" in column]
    # Fixed typo in the assertion message ("argments" -> "arguments").
    assert group_by in ["degree", "component", "detail"], "invalid `group_by` arguments"
    col_cmp = sorted([column for column in df.columns if column.startswith("cmp")])
    col_del = sorted([column for column in df.columns if column.startswith("del")])
    group_by_columns = dict(degree=["degree"], component=col_cmp, detail=col_cmp + col_del)
    if type(cmap) is str:
        cmap = plt.get_cmap(cmap)

    def shape_segment(segment, get_delay=False):
        # Normalize a raw segment row into degree (and optionally delay) tuples;
        # negative entries are padding and are dropped.
        if group_by == "degree":
            return tuple(segment)
        elif group_by == "component":
            return tuple(val for val in segment if val >= 0)
        elif group_by == "detail":
            degrees = tuple(val for val in segment[: len(segment) // 2] if val >= 0)
            if get_delay:
                delays = tuple(val for val in segment[len(segment) // 2 :] if val >= 0)
                return degrees, delays
            else:
                return degrees

    def get_color_index(segment):
        # Returns (degree, index within the degree family, family size) used to
        # shade components of the same total degree.
        if group_by == "degree":
            return segment[0], 0, 1
        elif group_by == "component":
            degrees = shape_segment(segment)
            degree = sum(degrees)
            degree_list = make_degree_list(degree)
            index = dict(zip(degree_list[::-1], range(len(degree_list)), strict=False))[degrees]
            max_index = len(degree_list)
            return degree, index, max_index
        elif group_by == "detail":
            degrees, delays = shape_segment(segment, get_delay=True)
            degree = sum(degrees)
            degree_list = make_degree_list(degree)
            index = dict(zip(degree_list[::-1], range(len(degree_list)), strict=False))[degrees]
            max_index = len(degree_list)
            return degree, index + max(0, 1 - 0.9 ** (max(delays) - (not zero_offset))), max_index

    def color_func(segment):
        # Blend the base colormap color toward white by the in-family index.
        white = np.ones(4)
        degree, index, max_index = get_color_index(segment)
        coef = (index / max_index) * min_color_coef
        out = np.array(cmap(degree - 1))
        out = (1 - coef) * out + coef * white
        return out

    def label_func(segment):
        if group_by == "degree":
            return str(segment[0])
        elif group_by == "component":
            out_str = str(shape_segment(segment))
            return out_str.replace("(", "{").replace(",)", "}").replace(")", "}")
        elif group_by == "detail":
            degrees, delays = shape_segment(segment, get_delay=True)
            out_str = str(tuple(zip(degrees, delays, strict=False)))
            return out_str.replace("(", "{").replace(",)", "}").replace(")", "}")

    def hatch_func(segment):
        hatches = ["//", "\\\\", "||", "--", "++", "xx", "oo", "OO", "..", "**"]
        if group_by == "degree":
            return None
        elif group_by == "component":
            return None
        elif group_by == "detail":
            _degrees, delays = shape_segment(segment, get_delay=True)
            return hatches[(max(delays) - (not zero_offset)) % len(hatches)]

    def sort_func(arg):
        # Segments below threshold sort last (np.inf) so they accumulate into 'rest'.
        segment, val = arg
        if sort_by(val) > threshold:
            if group_by == "degree":
                return segment
            elif group_by == "component":
                degrees = shape_segment(segment)
                return (sum(degrees), *(-d for d in degrees))
            elif group_by == "detail":
                degrees = shape_segment(segment)
                return (
                    sum(degrees),
                    *(-s for s in segment[: (len(segment) // 2)]),
                    *segment[(len(segment) // 2) :],
                )
        else:
            return (np.inf,)

    # Aggregation process.
    out = defaultdict(list)
    segments = df[group_by_columns[group_by]].unique()
    for column in ipc_columns:
        df_agg = df.group_by(group_by_columns[group_by]).agg(pl.col(column).sum())
        for segment in segments.iter_rows():
            out[segment].append(0)
        for *segment, val in df_agg.iter_rows():
            out[tuple(segment)][-1] = val

    # Visualization process.
    # (Removed a dead scalar initialization of bottom/rest/legend_cnt that was
    # immediately overwritten by the array versions below.)
    if xticks is None:
        pos = x_offset + np.arange(-1, len(ipc_columns) + 1)
        width = 1.0
    else:
        # Pad one phantom position on each side so the step outline closes.
        pos = np.zeros(len(ipc_columns) + 2)
        pos[1:-1] = xticks
        width = pos[1] - pos[0]
        pos[0] = pos[1] - width
        pos[-1] = pos[-2] + width

    legend_cnt = 1
    bottom = np.zeros_like(pos, dtype=float)
    rest = np.zeros_like(pos, dtype=float)
    for segment, val in sorted(out.items(), key=sort_func):
        ipc = np.zeros_like(bottom)
        ipc[1:-1] = val
        if sort_by(ipc) > threshold:
            ax.bar(
                pos[1:-1],
                ipc[1:-1],
                bottom=bottom[1:-1] if bottom_min is None else np.maximum(bottom[1:-1], bottom_min),
                width=width,
                linewidth=0.0,
                label=label_func(segment),
                color=color_func(segment),
                hatch=hatch_func(segment),
            )
            if step_linewidth > 0:
                ax.step(
                    pos,
                    ipc + bottom if bottom_min is None else np.maximum(ipc + bottom, bottom_min),
                    "#333333",
                    where="mid",
                    linewidth=step_linewidth,
                )
            legend_cnt += 1
            bottom += ipc
        else:
            rest += ipc
    if threshold > 0:
        ax.bar(
            pos[1:-1],
            rest[1:-1],
            bottom=bottom[1:-1],
            width=width,
            label="rest",
            color="darkgray",
            hatch="/",
            linewidth=0.0,
        )
        if step_linewidth > 0:
            ax.step(pos, rest + bottom, "#333333", where="mid", linewidth=step_linewidth)
    ax.tick_params(axis="both", which="major", labelsize=fontsize)
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1.05, 1.0),
        borderaxespad=0,
        ncol=math.ceil(legend_cnt / 18),
        fontsize=fontsize,
    )
    return out
370
+
371
+
372
# Export every routine visible at module scope.
__all__ = [attr for attr, value in inspect.getmembers(sys.modules[__name__]) if inspect.isroutine(value)]
@@ -0,0 +1,215 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright (c) 2025, Katsuma Inoue. All rights reserved.
5
+ # This code is licensed under the MIT License.
6
+
7
+ import inspect
8
+ import math
9
+ import sys
10
+ import warnings
11
+
12
+
13
class BasePolynomial(object):
    """
    Base class for polynomial families evaluated through recurrence relations.

    Indexing an instance with `poly[n]` returns the degree-`n` values, computed
    on demand via `calc` and memoized in `self.caches`; extra indices, as in
    `poly[n, i]`, are applied to the cached value.

    Notes:
        When implementing a new polynomial type, follow these guidelines:

        - Inherit `BasePolynomial`.
        - `calc` method should be overridden in subclasses.
        - `super().__init__` should be called in the subclass constructor.
    """

    def __init__(self, xs, **_kwargs):
        self.xs = xs      # input values the polynomials are evaluated at
        self.caches = {}  # degree -> evaluated polynomial values

    def __getitem__(self, args):
        # Split `poly[deg]` / `poly[deg, i, j, ...]` into degree + trailing indices.
        if type(args) is tuple:
            degree, trailing = args[0], args[1:]
        else:
            degree, trailing = args, ()
        assert degree >= 0
        if degree not in self.caches:
            self.caches[degree] = self.calc(degree)
        value = self.caches[degree]
        return value[trailing] if trailing else value

    def calc(self, *args, **kwargs):
        # Subclasses compute the degree-`n` values here.
        raise NotImplementedError
44
+
45
+
46
class Jacobi(BasePolynomial):
    """
    Jacobi polynomial family evaluated via the standard three-term recurrence
    (Cf. [Wikipedia](https://en.wikipedia.org/wiki/Jacobi_polynomials#Recurrence_relation)).
    """

    def __init__(self, xs, a: float, b: float, **_kwargs):
        """
        Parameters:
            xs (Any): Input values.
            a (float): Parameter a of the Jacobi polynomial.
            b (float): Parameter b of the Jacobi polynomial.
        """
        super().__init__(xs)
        self.a, self.b = a, b
        # Closed forms for degrees 0 and 1 seed the recurrence.
        self.caches[0] = 1
        self.caches[1] = self.xs * (a + b + 2) / 2 + (a - b) / 2

    def calc(self, n: int):
        # Three-term recurrence coefficients for degree n in terms of n-1 and n-2.
        an = n + self.a
        bn = n + self.b
        cn = an + bn
        denom = 2 * n * (cn - n) * (cn - 2)
        coef_x = (cn - 1) * (cn - 2) * cn
        coef_prev = (cn - 1) * (cn - 2 * n) * (an - bn)
        coef_prev2 = -2 * (an - 1) * (bn - 1) * cn
        prev, prev2 = self[n - 1], self[n - 2]
        return (coef_x / denom) * self.xs * prev + (coef_prev / denom) * prev + (coef_prev2 / denom) * prev2
77
+
78
+
79
class Legendre(Jacobi):
    """
    Legendre polynomial is a special case of Jacobi polynomial with `a = b = 0`
    (Cf. [Wikipedia](https://en.wikipedia.org/wiki/Legendre_polynomials)).
    """

    def __init__(self, xs, **_kwargs):
        """
        Parameters:
            xs (Any): Input values.
        """
        # Delegate to Jacobi with both parameters fixed to zero.
        super().__init__(xs, 0, 0)
91
+
92
+
93
class Hermite(BasePolynomial):
    """
    Hermite polynomial class using recurrence relation
    (Cf. [Wikipedia](https://en.wikipedia.org/wiki/Hermite_polynomials#Recurrence_relation)).
    """

    def __init__(self, xs, normalize: bool = False, **_kwargs):
        """
        Parameters:
            xs (Any): Input values.
            normalize (bool, optional): Whether to use normalized Hermite polynomials.
        """
        super(Hermite, self).__init__(xs)
        if normalize:
            # Gaussian weight (2*pi)^(-1/4) * exp(-x^2/4) folded into the
            # degree-0/1 seeds of the recurrence.
            exp = math.e ** (-0.25 * (self.xs * self.xs))
            exp *= (2 * math.pi) ** -0.25
            self.caches[0] = exp
            self.caches[1] = exp * self.xs
        else:
            self.caches[0] = 1
            self.caches[1] = self.xs

    def calc(self, n: int):
        # Plain (unscaled) probabilists' recurrence, kept for reference:
        # res = self.xs * self[n - 1]
        # res -= (n - 1) * self[n - 2]
        # NOTE(review): the sqrt-scaled (orthonormalized) recurrence below is
        # applied even when `normalize=False` — confirm this is intended for
        # the unnormalized case.
        res = math.sqrt(1 / n) * self.xs * self[n - 1]
        res -= math.sqrt((n - 1) / n) * self[n - 2]
        return res
121
+
122
+
123
class Laguerre(BasePolynomial):
    """
    Laguerre polynomial class using recurrence relation
    (Cf. [Wikipedia](https://en.wikipedia.org/wiki/Laguerre_polynomials#Generalized_Laguerre_polynomials)).
    """

    def __init__(self, xs, a=0, **_kwargs):
        """
        Parameters:
            xs (Any): Input values.
            a (float, optional): Parameter a of the (generalized) Laguerre polynomial.
        """
        super(Laguerre, self).__init__(xs)
        self.a = a
        # Closed forms for degrees 0 and 1 seed the recurrence.
        self.caches[0] = 1
        self.caches[1] = (a + 1) - self.xs

    def calc(self, n: int):
        """Evaluate degree `n` from cached degrees `n - 1` and `n - 2`."""
        # BUG FIX: this method was previously defined as `__getitem__`, which
        # overrode `BasePolynomial.__getitem__`, bypassed the cache entirely,
        # and recursed without a base case (RecursionError for any degree).
        # Naming it `calc` restores the base class's cached-recurrence contract.
        a = self.a
        res = ((2 * n - 1 + a) / n) * self[n - 1]
        res -= (1 / n) * self.xs * self[n - 1]
        res -= ((n - 1 + a) / n) * self[n - 2]
        return res
146
+
147
+
148
class Krawtchouk(BasePolynomial):
    """
    Krawtchouk polynomial family via the three-term recurrence
    (Cf. [Wikipedia](https://en.wikipedia.org/wiki/Kravchuk_polynomials#Three_term_recurrence)).
    """

    def __init__(self, xs, N=2, p=0.5, **_kwargs):
        """
        Parameters:
            xs (Any): Input values.
            N (int, optional): Parameter N of the Krawtchouk polynomial.
            p (float, optional): Parameter p of the Krawtchouk polynomial.
        """
        super().__init__(xs)
        self.N, self.p = N, p
        # Degree-0 and degree-1 seeds of the recurrence.
        self.caches[0] = 1
        self.caches[1] = 1 - self.xs * (1 / (self.N * self.p))

    def calc(self, n: int):
        """Evaluate degree `n` from degrees `n - 1` and `n - 2`; requires n <= N."""
        assert n <= self.N, f"argument should be equal or less than N={self.N}, but {n} was given."
        scale = self.p * (self.N - n + 1)
        out = (scale + (n - 1) * (1 - self.p) - self.xs) * self[n - 1]
        out -= (n - 1) * (1 - self.p) * self[n - 2]
        out /= scale
        return out
172
+
173
+
174
class GramSchmidt(BasePolynomial):
    """
    Gram-Schmidt polynomial class using the Gram-Schmidt process.
    """

    def __init__(self, xs, axis=None, depth=None, **_kwargs):
        """
        Parameters:
            xs (Any): Input values.
            axis (int | None, optional): Axis along which to perform the Gram-Schmidt process.
            depth (int | None, optional): Depth of orthogonalization. If `None`, full depth is used.

        Notes:
            If `axis` is `None`, be cautious when `xs` is multidimensional, as it might cause unexpected behavior.
            `axis=-2` is specified by the `BatchRegressor` class since the time dimension is the second-to-last dimension.
        """
        super(GramSchmidt, self).__init__(xs)
        if axis is None:
            warnings.warn(
                "Note that axis is set to `None`. Be careful when xs is multidimensional.",
                stacklevel=2,
            )
        self.axis = axis
        self.depth = depth
        # Degree 0 is the constant basis; higher degrees are built in `calc`.
        self.caches[0] = 1

    def offset(self, n: int):
        # First basis index to orthogonalize against: with a finite `depth`,
        # only the most recent `depth` bases are used (never below index 1).
        return max(1, 1 if self.depth is None else n - self.depth)

    def calc(self, n: int):
        # Candidate for the next basis: x times the previous basis (n > 1),
        # or the raw input itself (n == 1).
        if n > 1:
            base = self[1] * self[n - 1]
        else:
            base = self.xs
        # Center the candidate, subtract its components along the earlier
        # bases (projections of `base`, computed against self[i]), then
        # normalize to unit norm along `axis`.
        out = base - base.mean(axis=self.axis, keepdims=True)
        for i in range(self.offset(n), n):
            out -= (base * self[i]).sum(axis=self.axis, keepdims=True) * self[i]
        out *= (out * out).sum(axis=self.axis, keepdims=True) ** (-0.5)
        return out
213
+
214
+
215
# Export every class visible at module scope (the polynomial hierarchy above).
__all__ = [attr for attr, value in inspect.getmembers(sys.modules[__name__]) if inspect.isclass(value)]