pysiglib 3.0.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. pysiglib/README.md +6 -0
  2. pysiglib/__init__.py +48 -0
  3. pysiglib/_config.py +7 -0
  4. pysiglib/_version.py +22 -0
  5. pysiglib/branched_sig.py +304 -0
  6. pysiglib/branched_sig_backprop.py +178 -0
  7. pysiglib/cpsig.dll +0 -0
  8. pysiglib/data_handlers.py +301 -0
  9. pysiglib/dtypes.py +335 -0
  10. pysiglib/error_codes.py +35 -0
  11. pysiglib/jax_api/__init__.py +109 -0
  12. pysiglib/jax_api/_ffi.py +584 -0
  13. pysiglib/jax_api/jax_api.py +1242 -0
  14. pysiglib/jax_api/static_kernels_jax.py +145 -0
  15. pysiglib/linear_sig.py +104 -0
  16. pysiglib/load_siglib.py +671 -0
  17. pysiglib/log_sig.py +400 -0
  18. pysiglib/log_sig_backprop.py +171 -0
  19. pysiglib/log_sig_combine.py +167 -0
  20. pysiglib/log_sig_join.py +124 -0
  21. pysiglib/log_sig_join_backprop.py +134 -0
  22. pysiglib/logsig_to_sig.py +124 -0
  23. pysiglib/logsig_to_sig_backprop.py +112 -0
  24. pysiglib/param_checks.py +151 -0
  25. pysiglib/pysiglib_jax_ffi_cpu.dll +0 -0
  26. pysiglib/pysiglib_jax_ffi_cpu.lib +0 -0
  27. pysiglib/sig.py +251 -0
  28. pysiglib/sig_backprop.py +241 -0
  29. pysiglib/sig_coef.py +258 -0
  30. pysiglib/sig_coef_backprop.py +187 -0
  31. pysiglib/sig_join.py +119 -0
  32. pysiglib/sig_join_backprop.py +129 -0
  33. pysiglib/sig_kernel.py +468 -0
  34. pysiglib/sig_kernel_backprop.py +463 -0
  35. pysiglib/sig_length.py +182 -0
  36. pysiglib/sig_metrics.py +369 -0
  37. pysiglib/static_kernels.py +437 -0
  38. pysiglib/streams.py +803 -0
  39. pysiglib/torch_api/__init__.py +69 -0
  40. pysiglib/torch_api/torch_api.py +766 -0
  41. pysiglib/transform_path.py +151 -0
  42. pysiglib/transform_path_backprop.py +105 -0
  43. pysiglib/trees.py +356 -0
  44. pysiglib/words.py +337 -0
  45. pysiglib-3.0.0.dist-info/METADATA +234 -0
  46. pysiglib-3.0.0.dist-info/RECORD +48 -0
  47. pysiglib-3.0.0.dist-info/WHEEL +5 -0
  48. pysiglib-3.0.0.dist-info/licenses/LICENSE +201 -0
pysiglib/README.md ADDED
@@ -0,0 +1,6 @@
1
+ <h1 align='center'>pySigLib</h1>
2
+
3
+ <h2 align='center'>A Python Wrapper for cpSIG and cuSIG</h2>
4
+
5
+ This directory contains the source code for `pysiglib`, a Python wrapper
6
+ for the C++ libraries `cpSIG` and `cuSIG`.
pysiglib/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright 2025 Daniil Shmelev
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # =========================================================================
15
+
16
+ from .load_siglib import SYSTEM, BUILT_WITH_CUDA, BUILT_WITH_AVX, BUILT_WITH_JAX_FFI
17
+ from .words import words_of_length, words, lyndon_words_of_length, lyndon_words, is_lyndon, word_to_idx, idx_to_word
18
+ from .trees import trees, trees_of_order, tree_to_idx, idx_to_tree
19
+ from .sig_length import sig_length, log_sig_length
20
+ from .sig import sig_combine, sig
21
+ from .sig_backprop import sig_backprop, sig_combine_backprop
22
+ from .linear_sig import linear_sig
23
+ from .sig_join import sig_join
24
+ from .sig_join_backprop import sig_join_backprop
25
+ from .log_sig_join import log_sig_join
26
+ from .log_sig_join_backprop import log_sig_join_backprop
27
+ from .sig_coef import extract_sig_coef, sig_coef
28
+ from .sig_coef_backprop import sig_coef_backprop
29
+ from .log_sig import set_cache_dir, prepare_log_sig, clear_cache, sig_to_log_sig, log_sig
30
+ from .log_sig_backprop import sig_to_log_sig_backprop
31
+ from .logsig_to_sig import logsig_to_sig
32
+ from .logsig_to_sig_backprop import logsig_to_sig_backprop
33
+ from .log_sig_combine import log_sig_combine, log_sig_combine_backprop
34
+ from .sig_kernel import sig_kernel, sig_kernel_gram
35
+ from .sig_kernel_backprop import sig_kernel_backprop, sig_kernel_gram_backprop
36
+ from .sig_metrics import sig_score, expected_sig_score, sig_mmd
37
+ from .transform_path import transform_path
38
+ from .transform_path_backprop import transform_path_backprop
39
+ from .static_kernels import Context, StaticKernel, LinearKernel, ScaledLinearKernel, RBFKernel, PolynomialKernel, Matern12Kernel, Matern32Kernel, Matern52Kernel, RationalQuadraticKernel
40
+ from .branched_sig import prepare_branched_sig, branched_sig, branched_sig_combine, branched_sig_length
41
+ from .branched_sig_backprop import branched_sig_backprop, branched_sig_combine_backprop
42
+ from .streams import SigStream, LogSigStream, SigWindowStream, LogSigWindowStream
43
+
44
+ signature = sig
45
+
46
+ import pysiglib.torch_api
47
+
48
+ from ._version import __version__
pysiglib/_config.py ADDED
@@ -0,0 +1,7 @@
1
+ # This file is automatically generated by CMake and should not be edited.
2
+ # It contains information about how the package was built.
3
+ SYSTEM = 'Windows'
4
+ BUILT_WITH_JAX_FFI = True
5
+ BUILT_WITH_AVX = True
6
+ CXX_COMPILER = 'MSVC 19.44.35225.0'
7
+ CMAKE_GENERATOR = 'Visual Studio 17 2022'
pysiglib/_version.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright 2026 Daniil Shmelev
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # =========================================================================
15
+
16
+ # Single source of truth for pysiglib's version. Read by:
17
+ # - pysiglib/__init__.py (runtime __version__)
18
+ # - pyproject.toml [tool.scikit-build.metadata.version] (wheel metadata)
19
+ # - plugins/cuda/pyproject.toml (plugin wheel metadata; same version for ABI lockstep)
20
+ # - docs/conf.py (Sphinx release tag)
21
+ # - plugins/cuda/pysiglib_cuda/__init__.py (runtime ABI guard)
22
+ __version__ = "3.0.0"
@@ -0,0 +1,304 @@
1
+ # Copyright 2026 Daniil Shmelev
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # =========================================================================
15
+
16
+ from typing import Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from .param_checks import check_type, check_non_neg, check_n_jobs
22
+ from .error_codes import err_msg
23
+ from .sig_length import aug_dim
24
+ from .dtypes import (CPSIG_BRANCHED_SIG, CPSIG_BRANCHED_SIG_COMBINE,
25
+ CUSIG_BRANCHED_SIG_CUDA, CUSIG_BRANCHED_SIG_COMBINE_CUDA)
26
+ from .data_handlers import PathInputHandler, SigOutputHandler, MultipleSigInputHandler
27
+ from .load_siglib import CPSIG
28
+ import kauri
29
+
30
+
31
def _permute_bsig(data, dimension, degree, planar=False, scalar_term=True):
    """Reorder a branched signature from recursive to canonical tree order, in place.

    Only the coefficients of non-empty trees are reordered. When ``scalar_term``
    is True, the leading scalar at index 0 is left where it is.

    :param data: Array/tensor whose last axis holds the branched-signature coefficients.
        It is overwritten.
    :param dimension: Path dimension used to enumerate the trees.
    :param degree: Truncation degree (maximum tree order).
    :param planar: If True, use the planar (ordered) tree enumeration.
    :param scalar_term: Whether ``data`` carries the leading scalar term.
    :return: ``data`` itself, after reordering.
    """
    get_perm = (kauri.planar_canonical_to_recursive_permutation if planar
                else kauri.canonical_to_recursive_permutation)
    perm = get_perm(dimension, degree)
    # Slice covering the tree coefficients, skipping the scalar if present.
    trees = slice(1 if scalar_term else 0, None)
    # Fancy indexing on the right-hand side yields a copy, so writing back into
    # the same slice is safe.
    data[..., trees] = data[..., trees][..., perm]
    return data
44
+
45
+
46
def _inv_permute_bsig(data, dimension, degree, planar=False, scalar_term=True):
    """Return a copy of a branched signature reordered from canonical to recursive tree order.

    Only the coefficients of non-empty trees are reordered. When ``scalar_term``
    is True, the leading scalar at index 0 is copied across unchanged. ``data``
    itself is not modified.

    :param data: NumPy array or torch tensor whose last axis holds the coefficients.
    :param dimension: Path dimension used to enumerate the trees.
    :param degree: Truncation degree (maximum tree order).
    :param planar: If True, use the planar (ordered) tree enumeration.
    :param scalar_term: Whether ``data`` carries the leading scalar term.
    :return: A new array/tensor of the same type and shape as ``data``.
    """
    if planar:
        lookup = kauri.planar_recursive_to_canonical_permutation
    else:
        lookup = kauri.recursive_to_canonical_permutation
    inv_perm = lookup(dimension, degree)

    # Allocate the output with the same library (and therefore dtype/device) as the input.
    allocate = np.empty_like if isinstance(data, np.ndarray) else torch.empty_like
    reordered = allocate(data)

    trees = slice(1 if scalar_term else 0, None)
    if scalar_term:
        reordered[..., :1] = data[..., :1]
    reordered[..., trees] = data[..., trees][..., inv_perm]
    return reordered
65
+
66
+
67
def prepare_branched_sig(
    dimension: int,
    degree: int,
    *,
    use_disk: bool = False,
    time_aug: bool = False,
    lead_lag: bool = False,
    planar: bool = False
):
    """
    Precomputes the tree enumeration and Connes-Kreimer coproduct tables
    required by the branched signature routines.

    Call this before ``branched_sig()`` for a given ``(dimension, degree)``
    pair. When ``time_aug`` or ``lead_lag`` is set, the tables are built for
    the correspondingly augmented dimension.

    :param dimension: Dimension of the underlying (un-augmented) path.
    :param degree: Maximum tree order (number of nodes).
    :param use_disk: If True, persist the precomputed tables to disk so that
        later sessions can load them. This uses the same cache directory as
        ``set_cache_dir()`` / ``prepare_log_sig()``.
    :param time_aug: If True, prepare for time-augmented paths (dim + 1).
    :param lead_lag: If True, prepare for lead-lag transformed paths (2 * dim).
    :param planar: If True, prepare for planar (ordered) branched signatures.
    :raises Exception: If the native library reports a non-zero error code.
    """
    check_type(dimension, "dimension", int)
    check_type(degree, "degree", int)
    boolean_flags = (
        (use_disk, "use_disk"),
        (time_aug, "time_aug"),
        (lead_lag, "lead_lag"),
        (planar, "planar"),
    )
    for flag_value, flag_name in boolean_flags:
        check_type(flag_value, flag_name, bool)
    for count_value, count_name in ((dimension, "dimension"), (degree, "degree")):
        check_non_neg(count_value, count_name)

    target_dim = aug_dim(dimension, time_aug, lead_lag)
    status = CPSIG.prepare_branched_sig(target_dim, degree, use_disk, planar)
    if status:
        raise Exception("Error in pysiglib.prepare_branched_sig: " + err_msg(status))
105
+
106
+
107
def branched_sig_length(dimension: int, degree: int, *, planar: bool = False, scalar_term: bool = False) -> int:
    """
    Returns the length of a truncated branched signature.

    :param dimension: Dimension of the underlying path.
    :param degree: Maximum tree order (number of nodes).
    :param planar: If True, return the length for planar (ordered) branched signatures.
    :param scalar_term: If True, includes the empty-tree scalar term at index 0 in the length.
        If False (default), the returned length is one less (matching ``branched_sig``
        output with ``scalar_term=False``).
    :return: Length of the branched signature array.
    :raises ValueError: If the native library reports invalid parameters or integer overflow.
    """
    check_type(dimension, "dimension", int)
    check_type(degree, "degree", int)
    # planar and scalar_term were previously forwarded unchecked. Validate them
    # like every other boolean flag in this module.
    check_type(planar, "planar", bool)
    check_type(scalar_term, "scalar_term", bool)
    check_non_neg(dimension, "dimension")
    check_non_neg(degree, "degree")
    out = CPSIG.branched_sig_length(dimension, degree, planar)
    # The native routine returns 0 to signal failure. A valid result includes the
    # scalar term and so is at least 1.
    if out == 0:
        raise ValueError("Invalid parameters or integer overflow in branched_sig_length")
    return out - (0 if scalar_term else 1)
127
+
128
+
129
_CUDA_MAX_NUM_TREES = 1024  # CUDA kernel hardcoded thread-block size limit


def _check_cuda_num_trees(dimension: int, degree: int, planar: bool, fn_name: str) -> None:
    """Fail fast if the number of trees exceeds what the CUDA kernels can handle.

    The branched_sig CUDA kernel launches one thread per tree within a single
    block, so the tree count must not exceed ``_CUDA_MAX_NUM_TREES``. Beyond
    that the kernel aborts with an opaque ``Invalid argument (2)`` error; this
    check raises a descriptive Python error instead.

    :param dimension: (Augmented) path dimension.
    :param degree: Maximum tree order.
    :param planar: Whether the planar tree enumeration is used.
    :param fn_name: Public function name to prefix the error message with.
    :raises RuntimeError: If the tree count is above the CUDA limit.
    """
    num_trees = branched_sig_length(dimension, degree, planar=planar, scalar_term=False)
    if num_trees <= _CUDA_MAX_NUM_TREES:
        return
    planar_note = f", planar={planar}" if planar else ""
    raise RuntimeError(
        f"{fn_name}: num_trees={num_trees} exceeds CUDA kernel limit of "
        f"{_CUDA_MAX_NUM_TREES} for (dim={dimension}, degree={degree}{planar_note}). "
        "Use CPU or reduce degree."
    )
147
+
148
+
149
def _infer_branched_scalar_term(bsig, dimension: int, degree: int, planar: bool = False) -> bool:
    """Detect whether ``bsig`` carries the leading scalar term.

    The decision is made purely from the size of the trailing axis. It is
    compared against the branched-signature length for ``(dimension, degree,
    planar)`` with and without the scalar term. Consumer-side branched-sig
    functions use this so that their output format mirrors their input.

    :param bsig: Array/tensor whose last axis holds branched-signature coefficients.
    :param dimension: Path dimension.
    :param degree: Maximum tree order.
    :param planar: Whether the planar tree enumeration is used.
    :return: True if the scalar term is present, False if it is absent.
    :raises ValueError: If the trailing length matches neither format.
    """
    with_scalar = branched_sig_length(dimension, degree, planar=planar, scalar_term=True)
    length = bsig.shape[-1]
    if length in (with_scalar, with_scalar - 1):
        return bool(length == with_scalar)
    raise ValueError(
        f"bsig has incompatible length {length} for dimension={dimension}, "
        f"degree={degree}, planar={planar} "
        f"(expected {with_scalar} or {with_scalar - 1})."
    )
169
+
170
+
171
def branched_sig(
    path: Union[np.ndarray, torch.Tensor],
    degree: int,
    *,
    time_aug: bool = False,
    lead_lag: bool = False,
    end_time: float = 1.0,
    tree_order: str = "recursive",
    planar: bool = False,
    scalar_term: bool = False,
    n_jobs: int = 1,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Computes the truncated branched signature of a path or batch of paths.

    The branched signature extends the standard path signature to iterated
    integrals indexed by decorated rooted trees, following Gubinelli (2010).

    Must call ``prepare_branched_sig(dimension, degree, planar=planar)``
    before first use, where ``dimension`` is the augmented dimension
    (accounting for ``time_aug`` and ``lead_lag``).

    :param path: Path of shape ``(length, dimension)`` or ``(..., length, dimension)``.
    :param degree: Maximum tree order (number of nodes).
    :param time_aug: If True, prepend a time channel to the path.
    :param lead_lag: If True, apply the lead-lag transformation.
    :param end_time: End time for the time augmentation channel.
    :param tree_order: Tree ordering convention for the output coefficients.
        ``"recursive"`` (default) uses the recursive construction order.
        ``"canonical"`` uses the shape-first order matching :func:`tree_to_idx`.
    :param planar: If True, compute the planar (ordered) branched signature.
    :param scalar_term: If True, the output includes the leading constant 1 at index 0
        (the empty-word term). If False (default), this leading element is stripped from the output.
    :type scalar_term: bool
    :param n_jobs: Number of parallel threads for batch processing.
    :return: Branched signature array of shape ``(bsig_len,)`` or ``(..., bsig_len)``.
    :raises ValueError: If ``tree_order`` is not ``"recursive"`` or ``"canonical"``.
    :raises Exception: If the native library reports a non-zero error code.
    """
    if tree_order not in ("recursive", "canonical"):
        raise ValueError(f"tree_order must be 'recursive' or 'canonical', got {tree_order!r}")
    check_type(degree, "degree", int)
    check_type(time_aug, "time_aug", bool)
    check_type(lead_lag, "lead_lag", bool)
    check_type(end_time, "end_time", float)
    check_type(planar, "planar", bool)
    # scalar_term is forwarded to branched_sig_length and the native kernels.
    # Validate it like the other boolean flags instead of passing arbitrary
    # objects through.
    check_type(scalar_term, "scalar_term", bool)
    check_non_neg(degree, "degree")
    check_n_jobs(n_jobs)

    data = PathInputHandler(path, time_aug, lead_lag, end_time, "path")
    # The native routines take the raw path dimension and apply augmentation
    # themselves. Output sizing and CUDA limits use the augmented one.
    dimension = data.data_dimension
    aug_dimension = data.dimension
    bsig_len = branched_sig_length(aug_dimension, degree, planar=planar, scalar_term=scalar_term)
    result = SigOutputHandler(data, bsig_len)

    if data.batch_size == 0:
        return result.data

    if data.device == "cpu":
        err_code = CPSIG_BRANCHED_SIG[data.dtype](
            data.data_ptr, result.data_ptr, data.batch_size,
            dimension, data.data_length, degree, n_jobs,
            data.time_aug, data.lead_lag, data.end_time, planar, scalar_term)
    else:
        _check_cuda_num_trees(aug_dimension, degree, planar, "branched_sig")
        err_code = CUSIG_BRANCHED_SIG_CUDA[data.dtype](
            data.data_ptr, result.data_ptr, data.batch_size,
            dimension, data.data_length, degree,
            data.time_aug, data.lead_lag, data.end_time, planar, scalar_term)
    if err_code:
        raise Exception("Error in pysiglib.branched_sig: " + err_msg(err_code))
    if tree_order != "recursive":
        # The kernels emit recursive order; reorder in place for canonical output.
        _permute_bsig(result.data, aug_dimension, degree, planar=planar, scalar_term=scalar_term)
    return result.data
243
+
244
+
245
def branched_sig_combine(
    bsig1: Union[np.ndarray, torch.Tensor],
    bsig2: Union[np.ndarray, torch.Tensor],
    dimension: int,
    degree: int,
    *,
    tree_order: str = "recursive",
    planar: bool = False,
    n_jobs: int = 1,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Combines two truncated branched signatures via the Butcher product
    (the analogue of Chen's identity for branched rough paths).

    Whether the inputs include the leading scalar term is inferred from the
    trailing length of ``bsig1``. ``bsig2`` must use the same format.

    :param bsig1: First branched signature, in the ordering specified by ``tree_order``.
    :param bsig2: Second branched signature, in the ordering specified by ``tree_order``.
    :param dimension: Dimension of the underlying path.
    :param degree: Maximum tree order (number of nodes).
    :param tree_order: Tree ordering convention for inputs and output.
        ``"recursive"`` (default) uses the recursive construction order.
        ``"canonical"`` uses the shape-first order matching :func:`tree_to_idx`.
    :param planar: If True, combine planar (ordered) branched signatures.
    :param n_jobs: Number of parallel threads for batch processing.
    :return: Combined branched signature, in the same ordering and scalar-term format as the inputs.
    :raises ValueError: If ``tree_order`` is invalid, if ``bsig1`` matches neither
        branched-signature length, or (for ``tree_order="canonical"``) if ``bsig1``
        and ``bsig2`` differ in trailing length.
    :raises Exception: If the native library reports a non-zero error code.
    """
    if tree_order not in ("recursive", "canonical"):
        raise ValueError(f"tree_order must be 'recursive' or 'canonical', got {tree_order!r}")
    check_type(dimension, "dimension", int)
    check_type(degree, "degree", int)
    check_type(planar, "planar", bool)
    check_non_neg(dimension, "dimension")
    check_non_neg(degree, "degree")
    check_n_jobs(n_jobs)

    scalar_term = _infer_branched_scalar_term(bsig1, dimension, degree, planar=planar)
    if tree_order != "recursive":
        # The permutation indexes the trailing axis using bsig1's format. A bsig2
        # in a different format would otherwise fail with a backend-specific
        # broadcasting/indexing error before the input handler could report it.
        if bsig2.shape[-1] != bsig1.shape[-1]:
            raise ValueError(
                "bsig1 and bsig2 must have the same length along the last axis, "
                f"got {bsig1.shape[-1]} and {bsig2.shape[-1]}."
            )
        bsig1 = _inv_permute_bsig(bsig1, dimension, degree, planar=planar, scalar_term=scalar_term)
        bsig2 = _inv_permute_bsig(bsig2, dimension, degree, planar=planar, scalar_term=scalar_term)

    bsig_len = branched_sig_length(dimension, degree, planar=planar, scalar_term=scalar_term)
    data = MultipleSigInputHandler([bsig1, bsig2], bsig_len, ["bsig1", "bsig2"])
    result = SigOutputHandler(data, bsig_len)

    if data.batch_size == 0:
        return result.data

    if data.device == "cpu":
        err_code = CPSIG_BRANCHED_SIG_COMBINE[data.dtype](
            data.sig_ptr[0], data.sig_ptr[1], result.data_ptr,
            data.batch_size, dimension, degree, n_jobs, planar, scalar_term)
    else:
        _check_cuda_num_trees(dimension, degree, planar, "branched_sig_combine")
        err_code = CUSIG_BRANCHED_SIG_COMBINE_CUDA[data.dtype](
            data.sig_ptr[0], data.sig_ptr[1], result.data_ptr,
            data.batch_size, dimension, degree, planar, scalar_term)
    if err_code:
        raise Exception("Error in pysiglib.branched_sig_combine: " + err_msg(err_code))
    if tree_order != "recursive":
        # Convert the recursively-ordered result back to canonical order, in place.
        _permute_bsig(result.data, dimension, degree, planar=planar, scalar_term=scalar_term)
    return result.data
@@ -0,0 +1,178 @@
1
+ # Copyright 2026 Daniil Shmelev
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # =========================================================================
15
+
16
+ from typing import Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from .param_checks import check_type, check_non_neg, check_n_jobs
22
+ from .error_codes import err_msg
23
+ from .dtypes import (CPSIG_BRANCHED_SIG_BACKPROP,
24
+ CPSIG_BRANCHED_SIG_COMBINE_BACKPROP,
25
+ CUSIG_BRANCHED_SIG_BACKPROP_CUDA,
26
+ CUSIG_BRANCHED_SIG_COMBINE_BACKPROP_CUDA)
27
+ from .data_handlers import PathInputHandler, PathOutputHandler, MultipleSigInputHandler, SigOutputHandler
28
+ from .branched_sig import _inv_permute_bsig, _permute_bsig, branched_sig_length, _infer_branched_scalar_term, _check_cuda_num_trees
29
+
30
+
31
def branched_sig_backprop(
    path: Union[np.ndarray, torch.Tensor],
    bsig: Union[np.ndarray, torch.Tensor],
    bsig_derivs: Union[np.ndarray, torch.Tensor],
    degree: int,
    *,
    time_aug: bool = False,
    lead_lag: bool = False,
    end_time: float = 1.0,
    tree_order: str = "recursive",
    planar: bool = False,
    n_jobs: int = 1,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Backpropagates through the branched signature computation.

    Given the forward branched signature ``bsig = branched_sig(path, degree)``
    and upstream derivatives ``bsig_derivs = dF/d(bsig)``, computes
    ``dF/d(path)``.

    Whether ``bsig`` includes the leading scalar term is inferred from its
    trailing length. ``bsig_derivs`` must use the same format.

    :param path: Input path, shape ``(length, dimension)`` or ``(batch, length, dimension)``.
    :param bsig: Forward branched signature output.
    :param bsig_derivs: Upstream derivatives w.r.t. the branched signature.
    :param degree: Maximum tree order (must match forward call).
    :param time_aug: Whether time augmentation was used in the forward pass.
    :param lead_lag: Whether lead-lag was used in the forward pass.
    :param end_time: End time for time augmentation.
    :param tree_order: Tree ordering convention of ``bsig`` and ``bsig_derivs``.
        ``"recursive"`` (default) uses the recursive construction order.
        ``"canonical"`` uses the shape-first order matching :func:`tree_to_idx`.
    :param planar: If True, backpropagate through planar branched signature.
    :param n_jobs: Number of parallel threads for batch processing.
    :return: Path derivatives, same shape as ``path``.
    :raises ValueError: If ``tree_order`` is invalid, if ``bsig`` matches neither
        branched-signature length, or (for ``tree_order="canonical"``) if ``bsig``
        and ``bsig_derivs`` differ in trailing length.
    :raises Exception: If the native library reports a non-zero error code.
    """
    if tree_order not in ("recursive", "canonical"):
        raise ValueError(f"tree_order must be 'recursive' or 'canonical', got {tree_order!r}")
    check_type(degree, "degree", int)
    check_type(time_aug, "time_aug", bool)
    check_type(lead_lag, "lead_lag", bool)
    check_type(end_time, "end_time", float)
    check_type(planar, "planar", bool)
    check_non_neg(degree, "degree")
    check_n_jobs(n_jobs)

    path_data = PathInputHandler(path, time_aug, lead_lag, end_time, "path")
    dimension = path_data.data_dimension
    aug_dimension = path_data.dimension

    scalar_term = _infer_branched_scalar_term(bsig, aug_dimension, degree, planar=planar)
    if tree_order != "recursive":
        # Both arrays are permuted using bsig's format. Reject a mismatched
        # bsig_derivs up front instead of failing inside the permutation with
        # a backend-specific error.
        if bsig_derivs.shape[-1] != bsig.shape[-1]:
            raise ValueError(
                "bsig and bsig_derivs must have the same length along the last axis, "
                f"got {bsig.shape[-1]} and {bsig_derivs.shape[-1]}."
            )
        bsig = _inv_permute_bsig(bsig, aug_dimension, degree, planar=planar, scalar_term=scalar_term)
        bsig_derivs = _inv_permute_bsig(bsig_derivs, aug_dimension, degree, planar=planar, scalar_term=scalar_term)

    bsig_len = branched_sig_length(aug_dimension, degree, planar=planar, scalar_term=scalar_term)
    sig_data = MultipleSigInputHandler([bsig, bsig_derivs], bsig_len, ["bsig", "bsig_derivs"])
    result = PathOutputHandler(path_data.data_length, path_data.data_dimension, path_data)

    if path_data.batch_size == 0:
        return result.data

    # Note the native argument order: derivatives pointer before the signature pointer.
    if path_data.device == "cpu":
        err_code = CPSIG_BRANCHED_SIG_BACKPROP[path_data.dtype](
            path_data.data_ptr, result.data_ptr,
            sig_data.sig_ptr[1], sig_data.sig_ptr[0],
            path_data.batch_size, dimension, path_data.data_length, degree, n_jobs,
            path_data.time_aug, path_data.lead_lag, path_data.end_time, planar, scalar_term)
    else:
        _check_cuda_num_trees(aug_dimension, degree, planar, "branched_sig_backprop")
        err_code = CUSIG_BRANCHED_SIG_BACKPROP_CUDA[path_data.dtype](
            path_data.data_ptr, result.data_ptr,
            sig_data.sig_ptr[1], sig_data.sig_ptr[0],
            path_data.batch_size, dimension, path_data.data_length, degree,
            path_data.time_aug, path_data.lead_lag, path_data.end_time, planar, scalar_term)
    if err_code:
        raise Exception("Error in pysiglib.branched_sig_backprop: " + err_msg(err_code))
    return result.data
107
+
108
+
109
def branched_sig_combine_backprop(
    derivs: Union[np.ndarray, torch.Tensor],
    bsig1: Union[np.ndarray, torch.Tensor],
    bsig2: Union[np.ndarray, torch.Tensor],
    dimension: int,
    degree: int,
    *,
    tree_order: str = "recursive",
    planar: bool = False,
    n_jobs: int = 1,
) -> tuple:
    """
    Backpropagates through the branched signature combine (Butcher product).

    Given ``out = branched_sig_combine(bsig1, bsig2, dimension, degree)``
    and upstream derivatives ``derivs = dF/d(out)``, computes
    ``(dF/d(bsig1), dF/d(bsig2))``.

    Whether the inputs include the leading scalar term is inferred from the
    trailing length of ``bsig1``. ``derivs`` and ``bsig2`` must use the same format.

    :param derivs: Upstream derivatives, same shape as combine output.
    :param bsig1: First branched signature input to the forward combine.
    :param bsig2: Second branched signature input to the forward combine.
    :param dimension: Dimension of the underlying path.
    :param degree: Maximum tree order.
    :param tree_order: Tree ordering convention of ``derivs``, ``bsig1``, ``bsig2`` and the returned gradients.
        ``"recursive"`` (default) uses the recursive construction order.
        ``"canonical"`` uses the shape-first order matching :func:`tree_to_idx`.
    :param planar: If True, backpropagate through planar branched sig combine.
    :param n_jobs: Number of parallel threads for batch processing.
    :return: Tuple ``(dF/d(bsig1), dF/d(bsig2))``, in the same scalar-term format as the inputs.
    :raises ValueError: If ``tree_order`` is invalid, if ``bsig1`` matches neither
        branched-signature length, or (for ``tree_order="canonical"``) if ``derivs``,
        ``bsig1`` and ``bsig2`` differ in trailing length.
    :raises Exception: If the native library reports a non-zero error code.
    """
    if tree_order not in ("recursive", "canonical"):
        raise ValueError(f"tree_order must be 'recursive' or 'canonical', got {tree_order!r}")
    check_type(dimension, "dimension", int)
    check_type(degree, "degree", int)
    check_type(planar, "planar", bool)
    check_non_neg(dimension, "dimension")
    check_non_neg(degree, "degree")
    check_n_jobs(n_jobs)

    scalar_term = _infer_branched_scalar_term(bsig1, dimension, degree, planar=planar)
    if tree_order != "recursive":
        # All three arrays are permuted using bsig1's format. Reject mismatched
        # trailing lengths up front instead of failing inside the permutation
        # with a backend-specific error.
        expected_len = bsig1.shape[-1]
        if derivs.shape[-1] != expected_len or bsig2.shape[-1] != expected_len:
            raise ValueError(
                "derivs, bsig1 and bsig2 must have the same length along the last axis, "
                f"got {derivs.shape[-1]}, {expected_len} and {bsig2.shape[-1]}."
            )
        derivs = _inv_permute_bsig(derivs, dimension, degree, planar=planar, scalar_term=scalar_term)
        bsig1 = _inv_permute_bsig(bsig1, dimension, degree, planar=planar, scalar_term=scalar_term)
        bsig2 = _inv_permute_bsig(bsig2, dimension, degree, planar=planar, scalar_term=scalar_term)

    bsig_len = branched_sig_length(dimension, degree, planar=planar, scalar_term=scalar_term)
    data = MultipleSigInputHandler([derivs, bsig1, bsig2], bsig_len, ["derivs", "bsig1", "bsig2"])
    result1 = SigOutputHandler(data, bsig_len)
    result2 = SigOutputHandler(data, bsig_len)

    if data.batch_size == 0:
        return result1.data, result2.data

    # Native argument order: bsig1, bsig2, derivs, then the two gradient outputs.
    if data.device == "cpu":
        err_code = CPSIG_BRANCHED_SIG_COMBINE_BACKPROP[data.dtype](
            data.sig_ptr[1], data.sig_ptr[2], data.sig_ptr[0],
            result1.data_ptr, result2.data_ptr,
            data.batch_size, dimension, degree, n_jobs, planar, scalar_term)
    else:
        _check_cuda_num_trees(dimension, degree, planar, "branched_sig_combine_backprop")
        err_code = CUSIG_BRANCHED_SIG_COMBINE_BACKPROP_CUDA[data.dtype](
            data.sig_ptr[1], data.sig_ptr[2], data.sig_ptr[0],
            result1.data_ptr, result2.data_ptr,
            data.batch_size, dimension, degree, planar, scalar_term)
    if err_code:
        raise Exception("Error in pysiglib.branched_sig_combine_backprop: " + err_msg(err_code))
    if tree_order != "recursive":
        # Gradients come back in recursive order; convert them to canonical in place.
        _permute_bsig(result1.data, dimension, degree, planar=planar, scalar_term=scalar_term)
        _permute_bsig(result2.data, dimension, degree, planar=planar, scalar_term=scalar_term)
    return result1.data, result2.data
pysiglib/cpsig.dll ADDED
Binary file