pymc-extras 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -63,7 +63,7 @@ from rich.text import Text
63
63
  # TODO: change to typing.Self after Python versions greater than 3.10
64
64
  from typing_extensions import Self
65
65
 
66
- from pymc_extras.inference.laplace import add_data_to_inferencedata
66
+ from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
67
67
  from pymc_extras.inference.pathfinder.importance_sampling import (
68
68
  importance_sampling as _importance_sampling,
69
69
  )
@@ -1759,6 +1759,6 @@ def fit_pathfinder(
1759
1759
  importance_sampling=importance_sampling,
1760
1760
  )
1761
1761
 
1762
- idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
1762
+ idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs)
1763
1763
 
1764
1764
  return idata
@@ -2,10 +2,12 @@ import numpy as np
2
2
  import pandas as pd
3
3
  import pymc as pm
4
4
 
5
+ from sklearn.base import BaseEstimator
6
+
5
7
  from pymc_extras.model_builder import ModelBuilder
6
8
 
7
9
 
8
- class LinearModel(ModelBuilder):
10
+ class LinearModel(ModelBuilder, BaseEstimator):
9
11
  def __init__(
10
12
  self, model_config: dict | None = None, sampler_config: dict | None = None, nsamples=100
11
13
  ):
@@ -5,6 +5,7 @@ from itertools import zip_longest
5
5
 
6
6
  from pymc import SymbolicRandomVariable
7
7
  from pymc.model.fgraph import ModelVar
8
+ from pymc.variational.minibatch_rv import MinibatchRandomVariable
8
9
  from pytensor.graph import Variable, ancestors
9
10
  from pytensor.graph.basic import io_toposort
10
11
  from pytensor.tensor import TensorType, TensorVariable
@@ -313,6 +314,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
313
314
 
314
315
  var_dims[node.outputs[0]] = output_dims
315
316
 
317
+ elif isinstance(node.op, MinibatchRandomVariable):
318
+ var_dims[node.outputs[0]] = inputs_dims[0]
319
+
316
320
  else:
317
321
  raise NotImplementedError(f"Marginalization through operation {node} not supported.")
318
322