pymc-extras 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/distributions/__init__.py +5 -5
- pymc_extras/distributions/histogram_utils.py +1 -1
- pymc_extras/inference/__init__.py +1 -1
- pymc_extras/inference/laplace_approx/find_map.py +12 -5
- pymc_extras/inference/laplace_approx/idata.py +4 -3
- pymc_extras/inference/laplace_approx/laplace.py +6 -4
- pymc_extras/inference/pathfinder/pathfinder.py +1 -2
- pymc_extras/printing.py +1 -1
- pymc_extras/statespace/__init__.py +4 -4
- pymc_extras/statespace/core/__init__.py +1 -1
- pymc_extras/statespace/core/representation.py +8 -8
- pymc_extras/statespace/core/statespace.py +94 -23
- pymc_extras/statespace/filters/__init__.py +3 -3
- pymc_extras/statespace/filters/kalman_filter.py +16 -11
- pymc_extras/statespace/models/SARIMAX.py +138 -74
- pymc_extras/statespace/models/VARMAX.py +248 -57
- pymc_extras/statespace/models/__init__.py +2 -2
- pymc_extras/statespace/models/structural/__init__.py +21 -0
- pymc_extras/statespace/models/structural/components/__init__.py +0 -0
- pymc_extras/statespace/models/structural/components/autoregressive.py +213 -0
- pymc_extras/statespace/models/structural/components/cycle.py +325 -0
- pymc_extras/statespace/models/structural/components/level_trend.py +289 -0
- pymc_extras/statespace/models/structural/components/measurement_error.py +154 -0
- pymc_extras/statespace/models/structural/components/regression.py +257 -0
- pymc_extras/statespace/models/structural/components/seasonality.py +628 -0
- pymc_extras/statespace/models/structural/core.py +919 -0
- pymc_extras/statespace/models/structural/utils.py +16 -0
- pymc_extras/statespace/models/utilities.py +285 -0
- pymc_extras/statespace/utils/constants.py +21 -18
- pymc_extras/statespace/utils/data_tools.py +4 -3
- {pymc_extras-0.3.1.dist-info → pymc_extras-0.4.1.dist-info}/METADATA +5 -4
- {pymc_extras-0.3.1.dist-info → pymc_extras-0.4.1.dist-info}/RECORD +34 -25
- pymc_extras/statespace/models/structural.py +0 -1679
- {pymc_extras-0.3.1.dist-info → pymc_extras-0.4.1.dist-info}/WHEEL +0 -0
- {pymc_extras-0.3.1.dist-info → pymc_extras-0.4.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from pytensor import tensor as pt
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def order_to_mask(order):
    """
    Convert an ``order`` specification into a 1d boolean mask.

    Parameters
    ----------
    order: int or sequence
        If an integer (Python ``int`` or NumPy integer scalar), a mask of that length with every entry ``True`` is
        returned. Otherwise ``order`` is converted to an array and cast to ``bool`` element-wise, so zero entries
        become ``False`` and non-zero entries become ``True``.

    Returns
    -------
    mask: np.ndarray
        Boolean array.
    """
    # NumPy integer scalars (e.g. ``np.int64``) are not instances of ``int``; without this check they would fall
    # through to the array branch and produce a 0-d boolean instead of a mask of the requested length.
    if isinstance(order, (int, np.integer)):
        return np.ones(order, dtype=bool)

    return np.asarray(order).astype(bool)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _frequency_transition_block(s, j):
    """
    Build the 2x2 block ``[[cos(lam), sin(lam)], [-sin(lam), cos(lam)]]`` with ``lam = 2 * pi * j / s``.

    Parameters
    ----------
    s:
        Denominator of the angle.
    j:
        Numerator multiplier of the angle.

    Returns
    -------
    block: TensorVariable
        The stacked 2x2 tensor.
    """
    angle = 2 * np.pi * j / s

    cos_term = pt.cos(angle)
    sin_term = pt.sin(angle)

    top_row = [cos_term, sin_term]
    bottom_row = [-sin_term, cos_term]

    return pt.stack([top_row, bottom_row])
|
|
@@ -1,6 +1,10 @@
|
|
|
1
|
+
from typing import cast as type_cast
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
import pytensor.tensor as pt
|
|
3
5
|
|
|
6
|
+
from pytensor.tensor import TensorVariable
|
|
7
|
+
|
|
4
8
|
from pymc_extras.statespace.utils.constants import (
|
|
5
9
|
ALL_STATE_AUX_DIM,
|
|
6
10
|
ALL_STATE_DIM,
|
|
@@ -374,6 +378,287 @@ def conform_time_varying_and_time_invariant_matrices(A, B):
|
|
|
374
378
|
return A, B
|
|
375
379
|
|
|
376
380
|
|
|
381
|
+
def normalize_axis(x, axis):
    """
    Convert negative axis values to positive axis values

    Parameters
    ----------
    x:
        Any object exposing an integer ``ndim`` attribute.
    axis: int or tuple/list of int
        Axis, or collection of axes, to normalize. Each value must satisfy ``-x.ndim <= axis < x.ndim``.

    Returns
    -------
    axis: int or tuple of int
        Non-negative axis. A tuple is returned whenever a tuple or list was given.

    Raises
    ------
    ValueError
        If an axis is out of bounds for ``x``.
    """
    if isinstance(axis, (tuple, list)):
        return tuple(normalize_axis(x, i) for i in axis)

    ndim = x.ndim

    # Without this check, an axis such as -4 on a 3d input would be "normalized" to -1 and silently refer to the
    # wrong dimension downstream.
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for an input with {ndim} dimensions")

    return axis + ndim if axis < 0 else axis
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def reorder_from_labels(
    x: TensorVariable,
    labels: list[str],
    ordered_labels: list[str],
    labeled_axis: int | tuple[int, int],
) -> TensorVariable:
    """
    Reorder an input tensor along the requested axis/axes based on lists of string labels

    Parameters
    ----------
    x: TensorVariable
        Input tensor
    labels: list of str
        Labels associated with values of the input tensor ``x``, along the ``labeled_axis``. At runtime, should have
        ``x.shape[labeled_axis] == len(labels)``
    ordered_labels: list of str
        Target ordering according to which ``x`` will be reordered.
    labeled_axis: int or tuple of int
        Axis along which ``x`` will be labeled. Negative values count from the last axis. If a tuple, each axis will
        be assumed to have identical labels, and reorganization will be done on all requested axes together (NOT
        fancy indexing!)

    Returns
    -------
    x_sorted: TensorVariable
        Output tensor sorted along ``labeled_axis`` according to ``ordered_labels``

    Raises
    ------
    KeyError
        If ``labels`` contains an entry that is not in ``ordered_labels``.
    """
    n_out = len(ordered_labels)
    label_to_index = {label: index for index, label in enumerate(ordered_labels)}

    # Labels that appear only in ``ordered_labels`` are appended after ``labels``; the positions they index are
    # expected to already exist in ``x`` (e.g. as padding added by the caller).
    missing_labels = [label for label in ordered_labels if label not in labels]
    indices = np.argsort([label_to_index[label] for label in [*labels, *missing_labels]])

    # Already in the target order: return the input untouched.
    if indices.tolist() == list(range(n_out)):
        return x

    # Resolve negative axes first. The index tuple below is built by comparing against ``range(x.ndim)``, so a
    # negative axis would never match and the reordering would be silently skipped.
    axes = normalize_axis(x, labeled_axis)
    if not isinstance(axes, tuple):
        axes = (axes,)

    for axis in axes:
        idx = tuple(indices if i == axis else slice(None) for i in range(x.ndim))
        # Indexing with an integer array drops static shape information; a permutation does not change the shape,
        # so restore it.
        shape = x.type.shape
        x = pt.specify_shape(x[idx], shape)

    return x
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def pad_and_reorder(
    x: TensorVariable, labels: list[str], ordered_labels: list[str], labeled_axis: int
) -> TensorVariable:
    """
    Pad input tensor ``x`` along the ``labeled_axis`` to match the length of ``ordered_labels``, then reorder the
    padded dimension to match the ordering in ``ordered_labels``.

    Parameters
    ----------
    x: TensorVariable
        Input tensor
    labels: list of str
        String labels associated with the ``x`` tensor at the ``labeled_axis`` dimension. At runtime, should have
        ``x.shape[labeled_axis] == len(labels)``. ``labels`` should be a subset of ``ordered_labels``.
    ordered_labels: list of str
        Target ordering according to which ``x`` will be reordered.
    labeled_axis: int
        Axis along which ``x`` will be labeled. Negative values count from the last axis.

    Returns
    -------
    x_padded: TensorVariable
        Output tensor padded along ``labeled_axis`` according to ``ordered_labels``, then reordered.

    """
    n_missing = len(ordered_labels) - len(labels)

    # Resolve a negative axis up front. The shape of the zero block is built by comparing against
    # ``range(x.ndim)``; with a negative axis no position would match and the block would get the full shape of
    # ``x`` instead of ``n_missing`` entries along the labeled axis.
    labeled_axis = normalize_axis(x, labeled_axis)

    if n_missing > 0:
        zeros_shape = tuple(
            n_missing if i == labeled_axis else x.shape[i] for i in range(x.ndim)
        )
        # Missing labels are implicit zeros, appended after the existing entries.
        x = pt.concatenate([x, pt.zeros(zeros_shape)], axis=labeled_axis)

    return reorder_from_labels(x, labels, ordered_labels, labeled_axis)
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def ndim_pad_and_reorder(
    x: TensorVariable,
    labels: list[str],
    ordered_labels: list[str],
    labeled_axis: int | tuple[int, int],
) -> TensorVariable:
    """
    Pad input tensor ``x`` along the ``labeled_axis`` to match the length of ``ordered_labels``, then reorder the
    padded dimension to match the ordering in ``ordered_labels``.

    Unlike ``pad_and_reorder``, this function allows padding and reordering to be done simultaneously on multiple
    axes. In this case, reordering is done jointly on all axes -- it does *not* use fancy indexing.

    Parameters
    ----------
    x: TensorVariable
        Input tensor
    labels: list of str
        Labels associated with values of the input tensor ``x``, along the ``labeled_axis``. At runtime, should have
        ``x.shape[labeled_axis] == len(labels)``. If ``labeled_axis`` is a tuple, all axes are assumed to have the
        same labels.
    ordered_labels: list of str
        Target ordering according to which ``x`` will be reordered. ``labels`` should be a subset of ``ordered_labels``.
    labeled_axis: int or tuple of int
        Axis along which ``x`` will be labeled. Negative values count from the last axis. If a tuple, each axis will
        be assumed to have identical labels, and reorganization will be done on all requested axes together (NOT
        fancy indexing!)

    Returns
    -------
    x_sorted: TensorVariable
        Output tensor. Each ``labeled_axis`` is padded to the length of ``ordered_labels``, then reordered.
    """
    n_missing = len(ordered_labels) - len(labels)

    # Resolve negative axes so they can be compared against positional indices; otherwise a negative axis would
    # never be padded.
    axes = normalize_axis(x, labeled_axis)
    if not isinstance(axes, tuple):
        axes = (axes,)

    if n_missing > 0:
        # Pad one axis at a time, in increasing axis order. The static shape is re-read on every pass so that later
        # zero blocks account for the axes that were already padded.
        for axis in sorted(set(axes)):
            shape = list(x.type.shape)
            shape[axis] = n_missing
            # Use the static size where it is known, and fall back to the symbolic size otherwise.
            zero_shape = [
                static_shape if static_shape is not None else x.shape[i]
                for i, static_shape in enumerate(shape)
            ]
            x = pt.join(axis, x, pt.zeros(zero_shape))

    return reorder_from_labels(x, labels, ordered_labels, axes)
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def add_tensors_by_dim_labels(
    tensor: TensorVariable,
    other_tensor: TensorVariable,
    labels: list[str],
    other_labels: list[str],
    labeled_axis: int | tuple[int, int] = -1,
) -> TensorVariable:
    """
    Add two tensors based on labels associated with one dimension.

    When combining statespace matrices associated with structural components with potentially different states, it is
    important to make sure that duplicated states are handled correctly. For bias vectors and covariance matrices,
    duplicated states should be summed.

    When a state appears in one component but not another, that state should be treated as an implicit zero in the
    components where the state does not appear. This amounts to padding the relevant matrices with zeros before
    performing the addition.

    When labeled_axis is a tuple, each provided label is assumed to be identically labeled in each input tensor. This
    is the case, for example, when working with a covariance matrix. In this case, padding and alignment will be
    done on each indicated index.

    Parameters
    ----------
    tensor: TensorVariable
        A statespace matrix to be summed with ``other_tensor``.
    other_tensor: TensorVariable
        A statespace matrix to be summed with ``tensor``.
    labels: list of str
        Dimension labels associated with ``tensor``, on the ``labeled_axis`` dimension.
    other_labels: list of str
        Dimension labels associated with ``other_tensor``, on the ``labeled_axis`` dimension.
    labeled_axis: int or tuple of int
        Dimension that is labeled by ``labels`` and ``other_labels``. ``tensor.shape[labeled_axis]`` must have the
        shape of ``len(labels)`` at runtime.

    Returns
    -------
    result: TensorVariable
        Result of addition of ``tensor`` and ``other_tensor``, along the ``labeled_axis`` dimension. The ordering of
        the output will be ``labels + [label for label in other_labels if label not in labels]``. That is, ``labels``
        come first, followed by any new labels introduced by ``other_labels``. The axis layout of the output matches
        that of ``tensor``.

    """
    labeled_axis = normalize_axis(tensor, labeled_axis)
    new_labels = [label for label in other_labels if label not in labels]
    combined_labels = type_cast(list[str], [*labels, *new_labels])

    # If there is no overlap at all, directly concatenate the two matrices -- there's no need to worry about the order
    # of things, or padding. This is equivalent to padding both out with zeros then adding them.
    if combined_labels == [*labels, *other_labels]:
        if isinstance(labeled_axis, int):
            return pt.concatenate([tensor, other_tensor], axis=labeled_axis)

        # In the case where we want to align multiple dimensions, use block_diag to accomplish padding. block_diag
        # works on the last two dimensions, so move the labeled axes there first.
        dims = [*[i for i in range(tensor.ndim) if i not in labeled_axis], *labeled_axis]

        if dims == list(range(tensor.ndim)):
            # Labeled axes are already the trailing two, in order: no transposition required.
            return pt.linalg.block_diag(tensor, other_tensor)

        stacked = pt.linalg.block_diag(
            type_cast(TensorVariable, tensor.transpose(*dims)),
            type_cast(TensorVariable, other_tensor.transpose(*dims)),
        )

        # Undo the permutation so the output has the same axis layout as the inputs, matching what the
        # pad-and-reorder path below returns. ``argsort`` of a permutation is its inverse.
        inverse_dims = np.argsort(dims).tolist()
        return stacked.transpose(*inverse_dims)

    # Otherwise, there are two possibilities. If all labels are the same, we might need to re-order one or both to get
    # them to agree. If *some* labels are the same, we will need to pad first, then potentially re-order. In any case,
    # the final step is just to add the padded and re-ordered tensors.
    fn = pad_and_reorder if isinstance(labeled_axis, int) else ndim_pad_and_reorder

    padded_tensor = fn(
        tensor,
        labels=type_cast(list[str], labels),
        ordered_labels=combined_labels,
        labeled_axis=labeled_axis,
    )
    padded_tensor.name = tensor.name

    padded_other_tensor = fn(
        other_tensor,
        labels=type_cast(list[str], other_labels),
        ordered_labels=combined_labels,
        labeled_axis=labeled_axis,
    )
    padded_other_tensor.name = other_tensor.name

    return padded_tensor + padded_other_tensor
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def _block_diag_static_shape(shape_1, shape_2):
    """Static output shape of a block-diagonal join: trailing two dims add, leading (batch) dims broadcast."""
    n_batch = max(len(shape_1) - 2, 0)
    out = []
    for i, (dim_1, dim_2) in enumerate(zip(shape_1, shape_2)):
        if i >= n_batch:
            out.append(dim_1 + dim_2 if (dim_1 is not None and dim_2 is not None) else None)
        elif dim_1 is None or dim_2 is None:
            # A known size other than 1 pins the broadcast result; a known size of 1 tells us nothing.
            known = dim_1 if dim_1 is not None else dim_2
            out.append(known if known not in (None, 1) else None)
        else:
            out.append(max(dim_1, dim_2))
    return out


def join_tensors_by_dim_labels(
    tensor: TensorVariable,
    other_tensor: TensorVariable,
    labels: list[str],
    other_labels: list[str],
    labeled_axis: int = -1,
    join_axis: int = -1,
    block_diag_join: bool = False,
) -> TensorVariable:
    """
    Join two tensors, aligning one dimension of each by string labels.

    If ``labels`` and ``other_labels`` share no entries, the two tensors are joined block-diagonally, regardless of
    ``join_axis`` and ``block_diag_join``. Otherwise both tensors are padded with zeros and reordered along
    ``labeled_axis`` to the combined label ordering, and then joined.

    Parameters
    ----------
    tensor: TensorVariable
        First tensor to join.
    other_tensor: TensorVariable
        Second tensor to join.
    labels: list of str
        Dimension labels associated with ``tensor``, on the ``labeled_axis`` dimension.
    other_labels: list of str
        Dimension labels associated with ``other_tensor``, on the ``labeled_axis`` dimension.
    labeled_axis: int
        Dimension that is labeled by ``labels`` and ``other_labels``. Default is -1.
    join_axis: int
        Axis along which the aligned tensors are concatenated when ``block_diag_join`` is False. Default is -1.
    block_diag_join: bool
        If True, the aligned tensors are joined with ``block_diag`` instead of being concatenated along
        ``join_axis``. Default is False.

    Returns
    -------
    result: TensorVariable
        The joined tensor, with static shape information attached where it can be determined from the inputs.
    """
    labeled_axis = normalize_axis(tensor, labeled_axis)
    new_labels = [label for label in other_labels if label not in labels]
    combined_labels = [*labels, *new_labels]

    # Check for no overlap first. In this case, do a block_diagonal join, which implicitly results in padding zeros
    # everywhere they are needed -- no other sorting or padding necessary
    if combined_labels == [*labels, *other_labels]:
        res = pt.linalg.block_diag(tensor, other_tensor)
        new_shape = _block_diag_static_shape(tensor.type.shape, other_tensor.type.shape)
        return pt.specify_shape(res, new_shape)

    # Otherwise there is either total overlap or partial overlap. Let the padding and reordering function figure it out.
    tensor = ndim_pad_and_reorder(tensor, labels, combined_labels, labeled_axis)
    other_tensor = ndim_pad_and_reorder(other_tensor, other_labels, combined_labels, labeled_axis)

    if block_diag_join:
        new_shape = _block_diag_static_shape(tensor.type.shape, other_tensor.type.shape)
        res = pt.linalg.block_diag(tensor, other_tensor)
    else:
        new_shape = []
        join_axis_norm = normalize_axis(tensor, join_axis)
        for i, (shape_1, shape_2) in enumerate(zip(tensor.type.shape, other_tensor.type.shape)):
            if i == join_axis_norm:
                # Sizes add along the concatenation axis; unknown if either side is unknown.
                new_shape.append(
                    shape_1 + shape_2 if (shape_1 is not None and shape_2 is not None) else None
                )
            else:
                # Non-joined axes must agree, so take whichever static size is available.
                new_shape.append(shape_1 if shape_1 is not None else shape_2)
        res = pt.concatenate([tensor, other_tensor], axis=join_axis)

    return pt.specify_shape(res, new_shape)
|
|
660
|
+
|
|
661
|
+
|
|
377
662
|
def get_exog_dims_from_idata(exog_name, idata):
|
|
378
663
|
if exog_name in idata.posterior.data_vars:
|
|
379
664
|
exog_dims = idata.posterior[exog_name].dims[2:]
|
|
@@ -7,11 +7,12 @@ OBS_STATE_AUX_DIM = "observed_state_aux"
|
|
|
7
7
|
SHOCK_DIM = "shock"
|
|
8
8
|
SHOCK_AUX_DIM = "shock_aux"
|
|
9
9
|
TIME_DIM = "time"
|
|
10
|
-
AR_PARAM_DIM = "
|
|
11
|
-
MA_PARAM_DIM = "
|
|
12
|
-
SEASONAL_AR_PARAM_DIM = "
|
|
13
|
-
SEASONAL_MA_PARAM_DIM = "
|
|
10
|
+
AR_PARAM_DIM = "lag_ar"
|
|
11
|
+
MA_PARAM_DIM = "lag_ma"
|
|
12
|
+
SEASONAL_AR_PARAM_DIM = "seasonal_lag_ar"
|
|
13
|
+
SEASONAL_MA_PARAM_DIM = "seasonal_lag_ma"
|
|
14
14
|
ETS_SEASONAL_DIM = "seasonal_lag"
|
|
15
|
+
EXOGENOUS_DIM = "exogenous"
|
|
15
16
|
|
|
16
17
|
NEVER_TIME_VARYING = ["initial_state", "initial_state_cov", "a0", "P0"]
|
|
17
18
|
VECTOR_VALUED = ["initial_state", "state_intercept", "obs_intercept", "a0", "c", "d"]
|
|
@@ -38,14 +39,16 @@ SHORT_NAME_TO_LONG = dict(zip(MATRIX_NAMES, LONG_MATRIX_NAMES))
|
|
|
38
39
|
LONG_NAME_TO_SHORT = dict(zip(LONG_MATRIX_NAMES, MATRIX_NAMES))
|
|
39
40
|
|
|
40
41
|
FILTER_OUTPUT_NAMES = [
|
|
41
|
-
"
|
|
42
|
-
"
|
|
43
|
-
"
|
|
44
|
-
"
|
|
42
|
+
"filtered_states",
|
|
43
|
+
"predicted_states",
|
|
44
|
+
"filtered_covariances",
|
|
45
|
+
"predicted_covariances",
|
|
46
|
+
"predicted_observed_states",
|
|
47
|
+
"predicted_observed_covariances",
|
|
45
48
|
]
|
|
46
49
|
|
|
47
|
-
SMOOTHER_OUTPUT_NAMES = ["
|
|
48
|
-
OBSERVED_OUTPUT_NAMES = ["
|
|
50
|
+
SMOOTHER_OUTPUT_NAMES = ["smoothed_states", "smoothed_covariances"]
|
|
51
|
+
OBSERVED_OUTPUT_NAMES = ["predicted_observed_states", "predicted_observed_covariances"]
|
|
49
52
|
|
|
50
53
|
MATRIX_DIMS = {
|
|
51
54
|
"x0": (ALL_STATE_DIM,),
|
|
@@ -60,14 +63,14 @@ MATRIX_DIMS = {
|
|
|
60
63
|
}
|
|
61
64
|
|
|
62
65
|
FILTER_OUTPUT_DIMS = {
|
|
63
|
-
"
|
|
64
|
-
"
|
|
65
|
-
"
|
|
66
|
-
"
|
|
67
|
-
"
|
|
68
|
-
"
|
|
69
|
-
"
|
|
70
|
-
"
|
|
66
|
+
"filtered_states": (TIME_DIM, ALL_STATE_DIM),
|
|
67
|
+
"smoothed_states": (TIME_DIM, ALL_STATE_DIM),
|
|
68
|
+
"predicted_states": (TIME_DIM, ALL_STATE_DIM),
|
|
69
|
+
"filtered_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
|
|
70
|
+
"smoothed_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
|
|
71
|
+
"predicted_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
|
|
72
|
+
"predicted_observed_states": (TIME_DIM, OBS_STATE_DIM),
|
|
73
|
+
"predicted_observed_covariances": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
|
|
71
74
|
}
|
|
72
75
|
|
|
73
76
|
POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"]
|
|
@@ -53,7 +53,7 @@ def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=Fal
|
|
|
53
53
|
if len(missing_cols) > 0:
|
|
54
54
|
raise ValueError(
|
|
55
55
|
"Columns of DataFrame provided as data do not match state names. The following states were"
|
|
56
|
-
f
|
|
56
|
+
f"not found: {', '.join(missing_cols)}. This may result in unexpected results in complex"
|
|
57
57
|
f"statespace models"
|
|
58
58
|
)
|
|
59
59
|
|
|
@@ -141,9 +141,10 @@ def add_data_to_active_model(values, index, data_dims=None):
|
|
|
141
141
|
|
|
142
142
|
# If the data has just one column, we need to specify the shape as (None, 1), or else the JAX backend will
|
|
143
143
|
# raise a broadcasting error.
|
|
144
|
-
|
|
145
|
-
if values.shape[-1] == 1:
|
|
144
|
+
if values.shape[-1] == 1 or values.ndim == 1:
|
|
146
145
|
data_shape = (None, 1)
|
|
146
|
+
else:
|
|
147
|
+
data_shape = (None, values.shape[-1])
|
|
147
148
|
|
|
148
149
|
data = pm.Data("data", values, dims=data_dims, shape=data_shape)
|
|
149
150
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Project-URL: Documentation, https://pymc-extras.readthedocs.io/
|
|
6
6
|
Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
|
|
@@ -232,9 +232,11 @@ Classifier: Programming Language :: Python :: 3.13
|
|
|
232
232
|
Classifier: Topic :: Scientific/Engineering
|
|
233
233
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
234
234
|
Requires-Python: >=3.11
|
|
235
|
-
Requires-Dist: better-optimize>=0.1.
|
|
235
|
+
Requires-Dist: better-optimize>=0.1.5
|
|
236
|
+
Requires-Dist: preliz>=0.20.0
|
|
236
237
|
Requires-Dist: pydantic>=2.0.0
|
|
237
|
-
Requires-Dist: pymc>=5.
|
|
238
|
+
Requires-Dist: pymc>=5.24.1
|
|
239
|
+
Requires-Dist: pytensor>=2.31.4
|
|
238
240
|
Requires-Dist: scikit-learn
|
|
239
241
|
Provides-Extra: complete
|
|
240
242
|
Requires-Dist: dask[complete]<2025.1.1; extra == 'complete'
|
|
@@ -245,7 +247,6 @@ Requires-Dist: xhistogram; extra == 'dask-histogram'
|
|
|
245
247
|
Provides-Extra: dev
|
|
246
248
|
Requires-Dist: blackjax; extra == 'dev'
|
|
247
249
|
Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
|
|
248
|
-
Requires-Dist: preliz>=0.5.0; extra == 'dev'
|
|
249
250
|
Requires-Dist: pytest-mock; extra == 'dev'
|
|
250
251
|
Requires-Dist: pytest>=6.0; extra == 'dev'
|
|
251
252
|
Requires-Dist: statsmodels; extra == 'dev'
|
|
@@ -2,12 +2,12 @@ pymc_extras/__init__.py,sha256=YsR6OG72aW73y6dGS7w3nGGMV-V-ImHkmUOXKMPfMRA,1230
|
|
|
2
2
|
pymc_extras/deserialize.py,sha256=dktK5gsR96X3zAUoRF5udrTiconknH3uupiAWqkZi0M,5937
|
|
3
3
|
pymc_extras/linearmodel.py,sha256=KkvZ_DBXOD6myPgVNzu742YV0OzDK449_pDqNC5yae4,3975
|
|
4
4
|
pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
|
|
5
|
-
pymc_extras/printing.py,sha256=
|
|
5
|
+
pymc_extras/printing.py,sha256=bFOANgsOWDk0vbRMvm2h_D5TsT7OiSojdG7tvyfCw28,6506
|
|
6
6
|
pymc_extras/prior.py,sha256=0XbyRRVuS7aKY5gmvJr_iq4fGyHrRDeI_OjWu_O7CTA,39449
|
|
7
|
-
pymc_extras/distributions/__init__.py,sha256=
|
|
7
|
+
pymc_extras/distributions/__init__.py,sha256=Cge3AP7gzD6qTJY7v2tYRtSgn-rlnIo7wQBgf3IfKQ8,1377
|
|
8
8
|
pymc_extras/distributions/continuous.py,sha256=530wvcO-QcYVdiVN-iQRveImWfyJzzmxiZLMVShP7w4,11251
|
|
9
9
|
pymc_extras/distributions/discrete.py,sha256=HNi-K0_hnNWTcfyBkWGh26sc71FwBgukQ_EjGAaAOjY,13036
|
|
10
|
-
pymc_extras/distributions/histogram_utils.py,sha256=
|
|
10
|
+
pymc_extras/distributions/histogram_utils.py,sha256=xvCc19nlOmeb9PLZDcsR5PRdmcr5sRefZlPlCvxmGfM,5814
|
|
11
11
|
pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
|
|
12
12
|
pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
|
|
13
13
|
pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
|
|
@@ -15,17 +15,17 @@ pymc_extras/distributions/transforms/__init__.py,sha256=FUp2vyRE6_2eUcQ_FVt5Dn0-
|
|
|
15
15
|
pymc_extras/distributions/transforms/partial_order.py,sha256=oEZlc9WgnGR46uFEjLzKEUxlhzIo2vrUUbBE3vYrsfQ,8404
|
|
16
16
|
pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
|
|
17
17
|
pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
|
|
18
|
-
pymc_extras/inference/__init__.py,sha256=
|
|
18
|
+
pymc_extras/inference/__init__.py,sha256=sy1JYQGNZNvPs-3jVFfbFQTW0iCIrbjH3aHBpx1HQi0,917
|
|
19
19
|
pymc_extras/inference/fit.py,sha256=U_jfzuyjk5bV6AvOxtOKzBg-q4z-_BOR06Hn38T0W6E,1328
|
|
20
20
|
pymc_extras/inference/laplace_approx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
21
|
-
pymc_extras/inference/laplace_approx/find_map.py,sha256=
|
|
22
|
-
pymc_extras/inference/laplace_approx/idata.py,sha256=
|
|
23
|
-
pymc_extras/inference/laplace_approx/laplace.py,sha256=
|
|
21
|
+
pymc_extras/inference/laplace_approx/find_map.py,sha256=fP8DQ21OZbkUiBaq-TXGe7CtH0umupFacRC3qReoiKU,14022
|
|
22
|
+
pymc_extras/inference/laplace_approx/idata.py,sha256=P_GyodNJy2yr6FBYBqSoMShW2CKKuljBTFY1jOAHEKE,13332
|
|
23
|
+
pymc_extras/inference/laplace_approx/laplace.py,sha256=V49TdsCYGxt7Evg7Ml2qtHW0xeZYP5YjCOBaewTvJog,18778
|
|
24
24
|
pymc_extras/inference/laplace_approx/scipy_interface.py,sha256=qMxYodmmxaUGsOp1jc7HxBJc6L8NnmFT2Fd4UNNXu2c,8835
|
|
25
25
|
pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
|
|
26
26
|
pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
|
|
27
27
|
pymc_extras/inference/pathfinder/lbfgs.py,sha256=GOoJBil5Kft_iFwGNUGKSeqzI5x_shA4KQWDwgGuQtQ,7110
|
|
28
|
-
pymc_extras/inference/pathfinder/pathfinder.py,sha256=
|
|
28
|
+
pymc_extras/inference/pathfinder/pathfinder.py,sha256=wVDbyvE97iqiYLDHLfnl1MFtDdmEmaI5XZS3Lr6f9sE,64475
|
|
29
29
|
pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
|
|
30
30
|
pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
|
|
31
31
|
pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -38,32 +38,41 @@ pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
|
|
|
38
38
|
pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
|
|
39
39
|
pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
40
40
|
pymc_extras/preprocessing/standard_scaler.py,sha256=Vajp33ma6OkwlU54JYtSS8urHbMJ3CRiRFxZpvFNuus,600
|
|
41
|
-
pymc_extras/statespace/__init__.py,sha256=
|
|
42
|
-
pymc_extras/statespace/core/__init__.py,sha256=
|
|
41
|
+
pymc_extras/statespace/__init__.py,sha256=PxV8i4aa2XJarRM6aKU14_bEY1AoLu4bNXIBy_E1rRw,431
|
|
42
|
+
pymc_extras/statespace/core/__init__.py,sha256=LEhkqdMZzzcTyzYml45IM4ykWoCdbWWj2c29IpM_ey8,309
|
|
43
43
|
pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
|
|
44
|
-
pymc_extras/statespace/core/representation.py,sha256=
|
|
45
|
-
pymc_extras/statespace/core/statespace.py,sha256=
|
|
46
|
-
pymc_extras/statespace/filters/__init__.py,sha256=
|
|
44
|
+
pymc_extras/statespace/core/representation.py,sha256=boY-jjlkd3KuuO2XiSuV-GwEAyEqRJ9267H72AmE3BU,18956
|
|
45
|
+
pymc_extras/statespace/core/statespace.py,sha256=yu7smA5w7l1LFNjTwuKLnGarGLx4HEPJKQ9ZMDbWhDY,108161
|
|
46
|
+
pymc_extras/statespace/filters/__init__.py,sha256=F0EtZUhArp23lj3upy6zB0mDTjLIjwGh0pKmMny0QfY,420
|
|
47
47
|
pymc_extras/statespace/filters/distributions.py,sha256=-s1c5s2zm6FMc0UqKSrWnJzIF4U5bvJT_3mMNTyV_ak,11927
|
|
48
|
-
pymc_extras/statespace/filters/kalman_filter.py,sha256=
|
|
48
|
+
pymc_extras/statespace/filters/kalman_filter.py,sha256=rgpgF4KZXX5M8yRwblrt2SEINKgoXgiKNfKkbl7ZU9Y,31464
|
|
49
49
|
pymc_extras/statespace/filters/kalman_smoother.py,sha256=5jlSZAPveJzD5Q8omnpn7Gb1jgElBMgixGR7H9zoH8U,4183
|
|
50
50
|
pymc_extras/statespace/filters/utilities.py,sha256=iwdaYnO1cO06t_XUjLLRmqb8vwzzVH6Nx1iyZcbJL2k,1584
|
|
51
51
|
pymc_extras/statespace/models/ETS.py,sha256=08sbiuNvKdxcgKzS7jWj-z4jf-su73WFkYc8sKkGdEs,28538
|
|
52
|
-
pymc_extras/statespace/models/SARIMAX.py,sha256=
|
|
53
|
-
pymc_extras/statespace/models/VARMAX.py,sha256=
|
|
54
|
-
pymc_extras/statespace/models/__init__.py,sha256=
|
|
55
|
-
pymc_extras/statespace/models/
|
|
56
|
-
pymc_extras/statespace/models/
|
|
52
|
+
pymc_extras/statespace/models/SARIMAX.py,sha256=Yppz_k1ZyZuKPC62WIye6K7luw44cP-dog73VVkw0L4,25096
|
|
53
|
+
pymc_extras/statespace/models/VARMAX.py,sha256=7obJFXES9t9NONlcUQoeJ9TCqyoDlVat9FkPviQhAq0,25947
|
|
54
|
+
pymc_extras/statespace/models/__init__.py,sha256=DUwPrwfnz9AUbgZOFvZeUpWEw5FiPAK5X9x7vZrRWqY,319
|
|
55
|
+
pymc_extras/statespace/models/utilities.py,sha256=jpUYByAy6rMFP7l56uST1SEYchRa-clsFQ-At_1NLSw,27123
|
|
56
|
+
pymc_extras/statespace/models/structural/__init__.py,sha256=jvbczE1IeNkhW7gMQ2vF2BhhKHeYyfD90mV-Awko-Vs,811
|
|
57
|
+
pymc_extras/statespace/models/structural/core.py,sha256=n0cbP8_-NFLmflFF4x37AyOOIHcY5iylRrgTzjyOAhM,35374
|
|
58
|
+
pymc_extras/statespace/models/structural/utils.py,sha256=Eze34Z0iXJzDC_gZEY2mHrp2VIYu8rHV915vM4U5Sn4,359
|
|
59
|
+
pymc_extras/statespace/models/structural/components/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
60
|
+
pymc_extras/statespace/models/structural/components/autoregressive.py,sha256=HkS5an5fuNOBGcjHFNMUVNJrF1BNnlpQxvmPq_5dD0s,8021
|
|
61
|
+
pymc_extras/statespace/models/structural/components/cycle.py,sha256=qEiGFGMEXKS2Tl_zgzKIp77ijGXCVq6UIHEZp_ErHSQ,13931
|
|
62
|
+
pymc_extras/statespace/models/structural/components/level_trend.py,sha256=7glYX_tKOJPq6uB1NBuPQFFZGkhcwK4GMZUBTcU0xIY,11357
|
|
63
|
+
pymc_extras/statespace/models/structural/components/measurement_error.py,sha256=5LHDx3IplNrWSGcsY3xJLywKPosTqr42jlrvm80ZApM,5316
|
|
64
|
+
pymc_extras/statespace/models/structural/components/regression.py,sha256=27PRV9I64_VXIyjUi7pRr_gbk7sSI5DfJ4FBAbq5WCM,9856
|
|
65
|
+
pymc_extras/statespace/models/structural/components/seasonality.py,sha256=soXJIZ2xewUhSUb5s2MGnxvnQCcir7ZgbgkSr94xEvc,26987
|
|
57
66
|
pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
58
|
-
pymc_extras/statespace/utils/constants.py,sha256
|
|
67
|
+
pymc_extras/statespace/utils/constants.py,sha256=-4vCXo7-X3IuzdcplWBrAV9m9tm8JngcgoE-8imGmj0,2518
|
|
59
68
|
pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
60
|
-
pymc_extras/statespace/utils/data_tools.py,sha256=
|
|
69
|
+
pymc_extras/statespace/utils/data_tools.py,sha256=Tomur7d8WCKlMXUCrPqufqVTKUe_nLLCHdipsM9pmaI,6620
|
|
61
70
|
pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
|
|
62
71
|
pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
|
|
63
72
|
pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
|
|
64
73
|
pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
|
|
65
74
|
pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
|
|
66
|
-
pymc_extras-0.
|
|
67
|
-
pymc_extras-0.
|
|
68
|
-
pymc_extras-0.
|
|
69
|
-
pymc_extras-0.
|
|
75
|
+
pymc_extras-0.4.1.dist-info/METADATA,sha256=TpuX_8nEFjQfPlC51u_2EvQV3XwHAvgYCQMKYzeVU_E,18898
|
|
76
|
+
pymc_extras-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
77
|
+
pymc_extras-0.4.1.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
|
|
78
|
+
pymc_extras-0.4.1.dist-info/RECORD,,
|