aimz 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimz/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright 2025 Eli Lilly and Company
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Initialize package with global setup."""
16
+
17
+ import logging
18
+ import sys
19
+
20
# Package-level logger, named after the package's import name.
logger = logging.getLogger(__name__)

# `hasHandlers()` also looks at ancestor loggers, so the default setup below is
# skipped whenever the host application has already configured logging.
if not logger.hasHandlers():
    handler = logging.StreamHandler(stream=sys.stdout)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    # "py.warnings" is the logger that `logging.captureWarnings(True)` routes
    # warnings to; sharing the handler sends them to the same stream.
    logging.getLogger("py.warnings").addHandler(handler)
aimz/_exceptions.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2025 Eli Lilly and Company
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Custom warnings and errors."""
16
+
17
+ __all__ = ["KernelValidationError", "NotFittedError"]
18
+
19
+
20
class NotFittedError(ValueError, AttributeError):
    """Raised when a model is used before it has been fitted.

    Derives from both `ValueError` and `AttributeError`, so handlers written for
    either of those exception types also catch it.
    """
22
+
23
+
24
class KernelValidationError(Exception):
    """Raised when validation of a kernel fails."""
aimz/data/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 Eli Lilly and Company
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for data handling and processing."""
16
+
17
+ from aimz.data.array_loader import ArrayLoader
18
+
19
+ __all__ = ["ArrayLoader"]
aimz/data/_sharding.py ADDED
@@ -0,0 +1,230 @@
1
+ # Copyright 2025 Eli Lilly and Company
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module for creating functions for sharding.
16
+
17
+ NOTE: This module is experimental and subject to change. It utilizes JAX's `shard_map()`
18
+ to distribute computations across devices. Tested on CPU and GPU.
19
+ """
20
+
21
+ from functools import partial
22
+ from typing import TYPE_CHECKING
23
+
24
+ from jax import jit
25
+ from jax.experimental.shard_map import shard_map
26
+ from jax.sharding import PartitionSpec
27
+ from numpyro.infer import log_likelihood as log_lik
28
+
29
+ from aimz.sampling._forward import _sample_forward
30
+
31
+ if TYPE_CHECKING:
32
+ from collections.abc import Callable
33
+
34
+ import jax
35
+ from jax.sharding import Mesh
36
+
37
+
38
+ def _create_sharded_sampler(
39
+ mesh: "Mesh | None",
40
+ n_kwargs_array: int,
41
+ n_kwargs_extra: int,
42
+ ) -> "Callable":
43
+ """Create a sharded posterior predictive sampling function.
44
+
45
+ Args:
46
+ mesh (Mesh): The JAX mesh object defining the device mesh for sharding.
47
+ n_kwargs_array (int): The number of arguments in the keyword arguments that
48
+ are array-like (sharded).
49
+ n_kwargs_extra (int): The number of extra keyword arguments that are not
50
+ array-like (not sharded).
51
+
52
+ Returns:
53
+ Callable: A sharded function that takes the following arguments:
54
+ - rng_key (jax.Array): A pseudo-random number generator key.
55
+ - kernel (Callable): A probabilistic model with Pyro primitives.
56
+ - posterior_samples (dict): A dictionary of posterior samples.
57
+ - batch_shape (tuple[int]): The shape of the batch dimension, specifically
58
+ `(num_samples,)`.
59
+ - param_input (str): The name of the parameter in the `kernel` for the
60
+ input data.
61
+ - kwargs_key (tuple[str]): A tuple of keyword argument names.
62
+ - X (jax.Array): Input data.
63
+ - *args (tuple): Additional arguments constructed from the original keyword
64
+ arguments (both sharded and non-sharded).
65
+
66
+ """
67
+
68
+ def f(
69
+ kernel: "Callable",
70
+ num_samples: int,
71
+ rng_key: "jax.Array",
72
+ return_sites: tuple[str],
73
+ posterior_samples: dict[str, "jax.Array"],
74
+ param_input: str,
75
+ kwargs_key: tuple[str],
76
+ X: "jax.Array",
77
+ *args: tuple,
78
+ ) -> dict[str, "jax.Array"]:
79
+ return _sample_forward(
80
+ model=kernel,
81
+ num_samples=num_samples,
82
+ rng_key=rng_key,
83
+ return_sites=return_sites,
84
+ posterior_samples=posterior_samples,
85
+ model_kwargs={
86
+ param_input: X,
87
+ **dict(zip(kwargs_key, args, strict=True)),
88
+ },
89
+ )
90
+
91
+ if mesh is None:
92
+ return partial(
93
+ jit,
94
+ static_argnames=[
95
+ "kernel",
96
+ "num_samples",
97
+ "return_sites",
98
+ "param_input",
99
+ "kwargs_key",
100
+ ],
101
+ )(f)
102
+
103
+ (axis,) = mesh.axis_names
104
+
105
+ return partial(
106
+ jit,
107
+ static_argnames=[
108
+ "kernel",
109
+ "num_samples",
110
+ "return_sites",
111
+ "param_input",
112
+ "kwargs_key",
113
+ ],
114
+ )(
115
+ partial(
116
+ shard_map,
117
+ mesh=mesh,
118
+ in_specs=(
119
+ None, # kernel
120
+ None, # posterior_samples
121
+ None, # rng_key
122
+ None, # num_samples
123
+ None, # return_sites
124
+ None, # param_input
125
+ None, # kwargs_key
126
+ PartitionSpec(axis), # X
127
+ *(
128
+ [PartitionSpec(axis)] * n_kwargs_array # kwargs_array
129
+ + [None] * n_kwargs_extra # kwargs_extra
130
+ ),
131
+ ),
132
+ out_specs=PartitionSpec(None, axis),
133
+ check_rep=False,
134
+ )(f),
135
+ )
136
+
137
+
138
+ def _create_sharded_log_likelihood(
139
+ mesh: "Mesh | None",
140
+ n_kwargs_array: int,
141
+ n_kwargs_extra: int,
142
+ ) -> "Callable":
143
+ """Create a sharded log-likelihood function.
144
+
145
+ Args:
146
+ mesh (Mesh): The JAX mesh object defining the device mesh for sharding.
147
+ n_kwargs_array (int): The number of arguments in the keyword arguments that are
148
+ array-like (sharded).
149
+ n_kwargs_extra (int): The number of extra keyword arguments that are not
150
+ array-like (not sharded).
151
+
152
+ Returns:
153
+ Callable: A sharded function that takes the following arguments:
154
+ - kernel (Callable): A probabilistic model with Pyro primitives optimized
155
+ with variational inference.
156
+ - posterior_samples (dict): A dictionary of posterior samples.
157
+ - param_input (str): The name of the parameter in the `kernel` for the input
158
+ data.
159
+ - param_output (str): The name of the parameter in the `kernel` for the
160
+ output data.
161
+ - kwargs_key (tuple[str]): A tuple of keyword argument names.
162
+ - X (jax.Array): Input data.
163
+ - y (jax.Array): Output data.
164
+ - *args (tuple): Additional arguments constructed from the original keyword
165
+ arguments (both sharded and non-sharded).
166
+
167
+ """
168
+
169
+ def f(
170
+ kernel: "Callable",
171
+ posterior_samples: dict,
172
+ param_input: str,
173
+ param_output: str,
174
+ kwargs_key: tuple[str],
175
+ X: "jax.Array",
176
+ y: "jax.Array",
177
+ *args: tuple,
178
+ ) -> "jax.Array":
179
+ return log_lik(
180
+ kernel,
181
+ posterior_samples=posterior_samples,
182
+ **{
183
+ param_input: X,
184
+ param_output: y,
185
+ **dict(zip(kwargs_key, args, strict=True)),
186
+ },
187
+ ).get(param_output)
188
+
189
+ if mesh is None:
190
+ return partial(
191
+ jit,
192
+ static_argnames=[
193
+ "kernel",
194
+ "param_input",
195
+ "param_output",
196
+ "kwargs_key",
197
+ ],
198
+ )(f)
199
+
200
+ (axis,) = mesh.axis_names
201
+
202
+ return partial(
203
+ jit,
204
+ static_argnames=[
205
+ "kernel",
206
+ "param_input",
207
+ "param_output",
208
+ "kwargs_key",
209
+ ],
210
+ )(
211
+ partial(
212
+ shard_map,
213
+ mesh=mesh,
214
+ in_specs=(
215
+ None, # kernel
216
+ None, # posterior_samples
217
+ None, # param_input
218
+ None, # param_output
219
+ None, # kwargs_key
220
+ PartitionSpec(axis), # X
221
+ PartitionSpec(axis), # y
222
+ *(
223
+ [PartitionSpec(axis)] * n_kwargs_array # kwargs_array
224
+ + [None] * n_kwargs_extra # kwargs_extra
225
+ ),
226
+ ),
227
+ out_specs=PartitionSpec(None, axis),
228
+ check_rep=False,
229
+ )(f),
230
+ )
aimz/data/array_loader.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright 2025 Eli Lilly and Company
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module for custom data loader with padding logic for JAX arrays.
16
+
17
+ This module defines a custom `ArrayLoader` that processes batches of data and applies
18
+ padding to ensure the batch size is compatible with sharding across multiple XLA
19
+ devices.
20
+ """
21
+
22
+ from typing import TYPE_CHECKING
23
+
24
+ import jax.numpy as jnp
25
+ from jax import device_put
26
+ from jax_dataloader import ArrayDataset
27
+ from torch.utils.data import DataLoader
28
+
29
+ if TYPE_CHECKING:
30
+ from collections.abc import Callable
31
+
32
+ import jax
33
+ from jax.sharding import NamedSharding
34
+ from torch.utils.data import Sampler
35
+
36
+
37
+ class ArrayLoader(DataLoader):
38
+ """A custom DataLoader class that extends PyTorch's DataLoader."""
39
+
40
+ def __init__(
41
+ self,
42
+ dataset: ArrayDataset,
43
+ *,
44
+ batch_size: int = 1,
45
+ sampler: "Sampler | None" = None,
46
+ num_workers: int = 0,
47
+ collate_fn: "Callable | None" = None,
48
+ pin_memory: bool = False,
49
+ ) -> None:
50
+ """Initializes an ArrayLoader instance."""
51
+ super().__init__(
52
+ dataset,
53
+ batch_size=batch_size,
54
+ sampler=sampler,
55
+ num_workers=num_workers,
56
+ collate_fn=collate_fn,
57
+ pin_memory=pin_memory,
58
+ )
59
+
60
+ @staticmethod
61
+ def calculate_padding(batch_size: int, num_devices: int) -> int:
62
+ """Calculate the number of padding needed.
63
+
64
+ Args:
65
+ batch_size (int): The size of the batch.
66
+ num_devices (int): The number of devices.
67
+
68
+ Returns:
69
+ int: The number of padding rows (or elements) needed to make the batch size
70
+ divisible by the number of devices.
71
+ """
72
+ remainder = batch_size % num_devices
73
+ return 0 if remainder == 0 else num_devices - remainder
74
+
75
+ @staticmethod
76
+ def pad_array(x: "jax.Array", n_pad: int, axis: int) -> "jax.Array":
77
+ """Pad an array to ensure compatibility with sharding.
78
+
79
+ Args:
80
+ x (jax.Array): The input array to be padded.
81
+ n_pad (int): The number of padding elements to add.
82
+ axis (int): The axis along which to apply the padding.
83
+
84
+ Returns:
85
+ jax.Array: The padded array with padding applied along the specified axis.
86
+
87
+ Raises:
88
+ ValueError: If padding is requested along an unsupported axis for a 1D
89
+ array.
90
+ """
91
+ if x.ndim == 1:
92
+ if axis == 0:
93
+ return jnp.pad(x, pad_width=(0, n_pad), mode="edge")
94
+ msg = "Padding 1D arrays is only supported along axis 0."
95
+ raise ValueError(msg)
96
+
97
+ # Initialize all axes with no padding
98
+ pad_width: list[tuple[int, int]] = [(0, 0)] * x.ndim
99
+ # Apply padding to the specified axis
100
+ pad_width[axis] = (0, n_pad)
101
+
102
+ return jnp.pad(x, pad_width=pad_width, mode="edge")
103
+
104
+ @staticmethod
105
+ def collate_without_output(
106
+ batch: list[tuple],
107
+ device: "NamedSharding | None" = None,
108
+ ) -> tuple:
109
+ """Collate function to process batches with sharding and padding.
110
+
111
+ This function unpacks the batch of data, converts it into JAX arrays, and
112
+ applies padding to ensure the batch size is compatible with the number of
113
+ devices, if sharding is necessary. When a device is provided, the data is
114
+ automatically distributed across the available devices.
115
+
116
+ Args:
117
+ batch (list[tuple]): A list of tuples, where each tuple contains the input
118
+ data, optional target data, and array-like keyword arguments.
119
+ device (NamedSharding | None, optional): Sharding using named axes for
120
+ parallel data distribution across devices. Defaults to `None`, meaning
121
+ no sharding is applied.
122
+
123
+ Returns:
124
+ tuple: A tuple containing:
125
+ - n_pad (int): The number of padding rows/elements added (0 if no
126
+ padding was required).
127
+ - x_batch (jax.Array): The input batch with padding applied if
128
+ necessary.
129
+ - kwargs_batch (list[jax.Array]): A list of keyword arguments with
130
+ padding applied if necessary.
131
+ """
132
+ x_batch, *kwargs_batch = map(jnp.asarray, zip(*batch, strict=True))
133
+
134
+ n_pad = (
135
+ ArrayLoader.calculate_padding(
136
+ len(x_batch),
137
+ num_devices=device.num_devices,
138
+ )
139
+ if device
140
+ else 0
141
+ )
142
+ if n_pad:
143
+ x_batch = ArrayLoader.pad_array(x_batch, n_pad=n_pad, axis=0)
144
+ kwargs_batch = [
145
+ ArrayLoader.pad_array(x, n_pad=n_pad, axis=0) for x in kwargs_batch
146
+ ]
147
+
148
+ if device:
149
+ x_batch = device_put(x_batch, device=device)
150
+ kwargs_batch = [device_put(x, device=device) for x in kwargs_batch]
151
+
152
+ return n_pad, x_batch, *kwargs_batch
153
+
154
+ @staticmethod
155
+ def collate_with_sharding(
156
+ batch: list[tuple],
157
+ device: "NamedSharding | None" = None,
158
+ ) -> tuple:
159
+ """Collate function to process batches with sharding and padding.
160
+
161
+ This function unpacks the batch of data, converts it into JAX arrays, and
162
+ applies padding to ensure the batch size is compatible with the number of
163
+ devices, if sharding is necessary. When a device is provided, the data is
164
+ automatically distributed across the available devices.
165
+
166
+ Args:
167
+ batch (list[tuple]): A list of tuples, where each tuple contains the input
168
+ data, optional target data, and array-like keyword arguments.
169
+ device (NamedSharding | None, optional): Sharding using named axes for
170
+ parallel data distribution across devices. Defaults to `None`, meaning
171
+ no sharding is applied.
172
+
173
+ Returns:
174
+ tuple: A tuple containing:
175
+ - n_pad (int): The number of padding rows/elements added (0 if no
176
+ padding was required).
177
+ - x_batch (jax.Array): The input batch with padding applied if
178
+ necessary.
179
+ - y_batch (jax.Array): The target batch with padding applied.
180
+ - kwargs_batch (list[jax.Array]): A list of keyword arguments with
181
+ padding applied if necessary.
182
+ """
183
+ x_batch, y_batch, *kwargs_batch = map(jnp.asarray, zip(*batch, strict=True))
184
+
185
+ n_pad = (
186
+ ArrayLoader.calculate_padding(
187
+ len(x_batch),
188
+ num_devices=device.num_devices,
189
+ )
190
+ if device
191
+ else 0
192
+ )
193
+ if n_pad:
194
+ x_batch = ArrayLoader.pad_array(x_batch, n_pad=n_pad, axis=0)
195
+ y_batch = ArrayLoader.pad_array(y_batch, n_pad=n_pad, axis=0)
196
+ kwargs_batch = [
197
+ ArrayLoader.pad_array(x, n_pad=n_pad, axis=0) for x in kwargs_batch
198
+ ]
199
+
200
+ if device:
201
+ x_batch = device_put(x_batch, device=device)
202
+ y_batch = device_put(y_batch, device=device)
203
+ kwargs_batch = [device_put(x, device=device) for x in kwargs_batch]
204
+
205
+ return n_pad, x_batch, y_batch, *kwargs_batch
aimz/model/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2025 Eli Lilly and Company
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Model object."""
16
+
17
+ from aimz.model.impact_model import ImpactModel
18
+
19
+ __all__ = ["ImpactModel"]
aimz/model/_core.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright 2025 Eli Lilly and Company
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Base class for impact model."""
16
+
17
+ from abc import ABC, abstractmethod
18
+ from typing import TYPE_CHECKING, Self
19
+
20
+ from aimz.utils._validation import _validate_kernel_signature
21
+
22
+ if TYPE_CHECKING:
23
+ from collections.abc import Callable
24
+
25
+ import jax
26
+
27
+
28
+ class BaseModel(ABC):
29
+ """Abstract base class for the impact model.
30
+
31
+ Attributes:
32
+ kernel (Callable): A probabilistic model with Pyro primitives.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ kernel: "Callable",
38
+ param_input: str = "X",
39
+ param_output: str = "y",
40
+ ) -> None:
41
+ """Initialize the BaseModel with a callable model.
42
+
43
+ Args:
44
+ kernel (Callable): A probabilistic model with Pyro primitives.
45
+ param_input (str, optional): Name of the parameter in the `kernel` for the
46
+ main input data. Defaults to `"X"`.
47
+ param_output (str, optional): Name of the parameter in the `kernel` for the
48
+ output data. Defaults to `"y"`.
49
+ """
50
+ self.kernel = kernel
51
+ self.param_input = param_input
52
+ self.param_output = param_output
53
+ _validate_kernel_signature(self.kernel, self.param_input, self.param_output)
54
+
55
+ @abstractmethod
56
+ def fit(
57
+ self,
58
+ X: "jax.Array",
59
+ y: "jax.Array",
60
+ **kwargs: object,
61
+ ) -> Self:
62
+ """Fit the model to the input data `X` and output data `y`.
63
+
64
+ Args:
65
+ X (jax.Array): Input data with shape `(n_samples_X, n_features)`.
66
+ y (jax.Array): Output data with shape `(n_samples_Y,)`.
67
+ **kwargs (object): Additional arguments passed to the model, except for `X`
68
+ and `y`. All array-like objects in `**kwargs` are expected to be JAX
69
+ arrays.
70
+
71
+ Returns:
72
+ BaseModel: The fitted model instance, enabling method chaining.
73
+ """
74
+ return self
75
+
76
+ @abstractmethod
77
+ def predict(self, X: "jax.Array", **kwargs: object) -> None:
78
+ """Predict the output based on the fitted model.
79
+
80
+ Args:
81
+ X (jax.Array): Input data with shape `(n_samples_X, n_features)`.
82
+ **kwargs (object): Additional arguments passed to the model, except for `X`
83
+ and `y`. All array-like objects in `**kwargs` are expected to be JAX
84
+ arrays.
85
+ """
86
+
87
+ @abstractmethod
88
+ def estimate_effect(
89
+ self,
90
+ ) -> None:
91
+ """Estimate the effect of an intervention."""