aimz 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimz/__init__.py +26 -0
- aimz/_exceptions.py +25 -0
- aimz/data/__init__.py +19 -0
- aimz/data/_sharding.py +230 -0
- aimz/data/array_loader.py +205 -0
- aimz/model/__init__.py +19 -0
- aimz/model/_core.py +91 -0
- aimz/model/impact_model.py +1127 -0
- aimz/sampling/__init__.py +15 -0
- aimz/sampling/_forward.py +92 -0
- aimz/utils/__init__.py +15 -0
- aimz/utils/_kwargs.py +48 -0
- aimz/utils/_output.py +144 -0
- aimz/utils/_validation.py +175 -0
- aimz-0.1.0.dist-info/METADATA +258 -0
- aimz-0.1.0.dist-info/RECORD +19 -0
- aimz-0.1.0.dist-info/WHEEL +5 -0
- aimz-0.1.0.dist-info/licenses/LICENSE +201 -0
- aimz-0.1.0.dist-info/top_level.txt +1 -0
aimz/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright 2025 Eli Lilly and Company
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Initialize package with global setup."""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import sys
|
|
19
|
+
|
|
20
|
+
# Package-level logger, named after this module.
logger = logging.getLogger(__name__)

# `hasHandlers()` also looks at ancestor loggers, so the default stdout handler
# below is installed only when no logging configuration is already in place.
if not logger.hasHandlers():
    handler = logging.StreamHandler(sys.stdout)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    # Share the same handler with the logger that the standard library uses for
    # warnings redirected through `logging.captureWarnings(True)`.
    logging.getLogger("py.warnings").addHandler(handler)
|
aimz/_exceptions.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright 2025 Eli Lilly and Company
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Custom warnings and errors."""
|
|
16
|
+
|
|
17
|
+
__all__ = ["KernelValidationError", "NotFittedError"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class NotFittedError(ValueError, AttributeError):
    """Raised when a model is used before it has been fitted.

    It derives from both `ValueError` and `AttributeError`, so handlers written
    for either of those exception types also catch it.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class KernelValidationError(Exception):
    """Raised when validation of a kernel fails."""
|
aimz/data/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2025 Eli Lilly and Company
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Utilities for data handling and processing."""
|
|
16
|
+
|
|
17
|
+
from aimz.data.array_loader import ArrayLoader
|
|
18
|
+
|
|
19
|
+
__all__ = ["ArrayLoader"]
|
aimz/data/_sharding.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# Copyright 2025 Eli Lilly and Company
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Module for creating functions for sharding.
|
|
16
|
+
|
|
17
|
+
NOTE: This module is experimental and subject to change. It utilizes JAX's `shard_map()`
|
|
18
|
+
to distribute computations across devices. Tested on CPU and GPU.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from functools import partial
|
|
22
|
+
from typing import TYPE_CHECKING
|
|
23
|
+
|
|
24
|
+
from jax import jit
|
|
25
|
+
from jax.experimental.shard_map import shard_map
|
|
26
|
+
from jax.sharding import PartitionSpec
|
|
27
|
+
from numpyro.infer import log_likelihood as log_lik
|
|
28
|
+
|
|
29
|
+
from aimz.sampling._forward import _sample_forward
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from collections.abc import Callable
|
|
33
|
+
|
|
34
|
+
import jax
|
|
35
|
+
from jax.sharding import Mesh
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _create_sharded_sampler(
|
|
39
|
+
mesh: "Mesh | None",
|
|
40
|
+
n_kwargs_array: int,
|
|
41
|
+
n_kwargs_extra: int,
|
|
42
|
+
) -> "Callable":
|
|
43
|
+
"""Create a sharded posterior predictive sampling function.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
mesh (Mesh): The JAX mesh object defining the device mesh for sharding.
|
|
47
|
+
n_kwargs_array (int): The number of arguments in the keyword arguments that
|
|
48
|
+
are array-like (sharded).
|
|
49
|
+
n_kwargs_extra (int): The number of extra keyword arguments that are not
|
|
50
|
+
array-like (not sharded).
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
Callable: A sharded function that takes the following arguments:
|
|
54
|
+
- rng_key (jax.Array): A pseudo-random number generator key.
|
|
55
|
+
- kernel (Callable): A probabilistic model with Pyro primitives.
|
|
56
|
+
- posterior_samples (dict): A dictionary of posterior samples.
|
|
57
|
+
- batch_shape (tuple[int]): The shape of the batch dimension, specifically
|
|
58
|
+
`(num_samples,)`.
|
|
59
|
+
- param_input (str): The name of the parameter in the `kernel` for the
|
|
60
|
+
input data.
|
|
61
|
+
- kwargs_key (tuple[str]): A tuple of keyword argument names.
|
|
62
|
+
- X (jax.Array): Input data.
|
|
63
|
+
- *args (tuple): Additional arguments constructed from the original keyword
|
|
64
|
+
arguments (both sharded and non-sharded).
|
|
65
|
+
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
def f(
|
|
69
|
+
kernel: "Callable",
|
|
70
|
+
num_samples: int,
|
|
71
|
+
rng_key: "jax.Array",
|
|
72
|
+
return_sites: tuple[str],
|
|
73
|
+
posterior_samples: dict[str, "jax.Array"],
|
|
74
|
+
param_input: str,
|
|
75
|
+
kwargs_key: tuple[str],
|
|
76
|
+
X: "jax.Array",
|
|
77
|
+
*args: tuple,
|
|
78
|
+
) -> dict[str, "jax.Array"]:
|
|
79
|
+
return _sample_forward(
|
|
80
|
+
model=kernel,
|
|
81
|
+
num_samples=num_samples,
|
|
82
|
+
rng_key=rng_key,
|
|
83
|
+
return_sites=return_sites,
|
|
84
|
+
posterior_samples=posterior_samples,
|
|
85
|
+
model_kwargs={
|
|
86
|
+
param_input: X,
|
|
87
|
+
**dict(zip(kwargs_key, args, strict=True)),
|
|
88
|
+
},
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
if mesh is None:
|
|
92
|
+
return partial(
|
|
93
|
+
jit,
|
|
94
|
+
static_argnames=[
|
|
95
|
+
"kernel",
|
|
96
|
+
"num_samples",
|
|
97
|
+
"return_sites",
|
|
98
|
+
"param_input",
|
|
99
|
+
"kwargs_key",
|
|
100
|
+
],
|
|
101
|
+
)(f)
|
|
102
|
+
|
|
103
|
+
(axis,) = mesh.axis_names
|
|
104
|
+
|
|
105
|
+
return partial(
|
|
106
|
+
jit,
|
|
107
|
+
static_argnames=[
|
|
108
|
+
"kernel",
|
|
109
|
+
"num_samples",
|
|
110
|
+
"return_sites",
|
|
111
|
+
"param_input",
|
|
112
|
+
"kwargs_key",
|
|
113
|
+
],
|
|
114
|
+
)(
|
|
115
|
+
partial(
|
|
116
|
+
shard_map,
|
|
117
|
+
mesh=mesh,
|
|
118
|
+
in_specs=(
|
|
119
|
+
None, # kernel
|
|
120
|
+
None, # posterior_samples
|
|
121
|
+
None, # rng_key
|
|
122
|
+
None, # num_samples
|
|
123
|
+
None, # return_sites
|
|
124
|
+
None, # param_input
|
|
125
|
+
None, # kwargs_key
|
|
126
|
+
PartitionSpec(axis), # X
|
|
127
|
+
*(
|
|
128
|
+
[PartitionSpec(axis)] * n_kwargs_array # kwargs_array
|
|
129
|
+
+ [None] * n_kwargs_extra # kwargs_extra
|
|
130
|
+
),
|
|
131
|
+
),
|
|
132
|
+
out_specs=PartitionSpec(None, axis),
|
|
133
|
+
check_rep=False,
|
|
134
|
+
)(f),
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _create_sharded_log_likelihood(
|
|
139
|
+
mesh: "Mesh | None",
|
|
140
|
+
n_kwargs_array: int,
|
|
141
|
+
n_kwargs_extra: int,
|
|
142
|
+
) -> "Callable":
|
|
143
|
+
"""Create a sharded log-likelihood function.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
mesh (Mesh): The JAX mesh object defining the device mesh for sharding.
|
|
147
|
+
n_kwargs_array (int): The number of arguments in the keyword arguments that are
|
|
148
|
+
array-like (sharded).
|
|
149
|
+
n_kwargs_extra (int): The number of extra keyword arguments that are not
|
|
150
|
+
array-like (not sharded).
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
Callable: A sharded function that takes the following arguments:
|
|
154
|
+
- kernel (Callable): A probabilistic model with Pyro primitives optimized
|
|
155
|
+
with variational inference.
|
|
156
|
+
- posterior_samples (dict): A dictionary of posterior samples.
|
|
157
|
+
- param_input (str): The name of the parameter in the `kernel` for the input
|
|
158
|
+
data.
|
|
159
|
+
- param_output (str): The name of the parameter in the `kernel` for the
|
|
160
|
+
output data.
|
|
161
|
+
- kwargs_key (tuple[str]): A tuple of keyword argument names.
|
|
162
|
+
- X (jax.Array): Input data.
|
|
163
|
+
- y (jax.Array): Output data.
|
|
164
|
+
- *args (tuple): Additional arguments constructed from the original keyword
|
|
165
|
+
arguments (both sharded and non-sharded).
|
|
166
|
+
|
|
167
|
+
"""
|
|
168
|
+
|
|
169
|
+
def f(
|
|
170
|
+
kernel: "Callable",
|
|
171
|
+
posterior_samples: dict,
|
|
172
|
+
param_input: str,
|
|
173
|
+
param_output: str,
|
|
174
|
+
kwargs_key: tuple[str],
|
|
175
|
+
X: "jax.Array",
|
|
176
|
+
y: "jax.Array",
|
|
177
|
+
*args: tuple,
|
|
178
|
+
) -> "jax.Array":
|
|
179
|
+
return log_lik(
|
|
180
|
+
kernel,
|
|
181
|
+
posterior_samples=posterior_samples,
|
|
182
|
+
**{
|
|
183
|
+
param_input: X,
|
|
184
|
+
param_output: y,
|
|
185
|
+
**dict(zip(kwargs_key, args, strict=True)),
|
|
186
|
+
},
|
|
187
|
+
).get(param_output)
|
|
188
|
+
|
|
189
|
+
if mesh is None:
|
|
190
|
+
return partial(
|
|
191
|
+
jit,
|
|
192
|
+
static_argnames=[
|
|
193
|
+
"kernel",
|
|
194
|
+
"param_input",
|
|
195
|
+
"param_output",
|
|
196
|
+
"kwargs_key",
|
|
197
|
+
],
|
|
198
|
+
)(f)
|
|
199
|
+
|
|
200
|
+
(axis,) = mesh.axis_names
|
|
201
|
+
|
|
202
|
+
return partial(
|
|
203
|
+
jit,
|
|
204
|
+
static_argnames=[
|
|
205
|
+
"kernel",
|
|
206
|
+
"param_input",
|
|
207
|
+
"param_output",
|
|
208
|
+
"kwargs_key",
|
|
209
|
+
],
|
|
210
|
+
)(
|
|
211
|
+
partial(
|
|
212
|
+
shard_map,
|
|
213
|
+
mesh=mesh,
|
|
214
|
+
in_specs=(
|
|
215
|
+
None, # kernel
|
|
216
|
+
None, # posterior_samples
|
|
217
|
+
None, # param_input
|
|
218
|
+
None, # param_output
|
|
219
|
+
None, # kwargs_key
|
|
220
|
+
PartitionSpec(axis), # X
|
|
221
|
+
PartitionSpec(axis), # y
|
|
222
|
+
*(
|
|
223
|
+
[PartitionSpec(axis)] * n_kwargs_array # kwargs_array
|
|
224
|
+
+ [None] * n_kwargs_extra # kwargs_extra
|
|
225
|
+
),
|
|
226
|
+
),
|
|
227
|
+
out_specs=PartitionSpec(None, axis),
|
|
228
|
+
check_rep=False,
|
|
229
|
+
)(f),
|
|
230
|
+
)
|
aimz/data/array_loader.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# Copyright 2025 Eli Lilly and Company
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Module for custom data loader with padding logic for JAX arrays.
|
|
16
|
+
|
|
17
|
+
This module defines a custom `ArrayLoader` that processes batches of data and applies
|
|
18
|
+
padding to ensure the batch size is compatible with sharding across multiple XLA
|
|
19
|
+
devices.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from typing import TYPE_CHECKING
|
|
23
|
+
|
|
24
|
+
import jax.numpy as jnp
|
|
25
|
+
from jax import device_put
|
|
26
|
+
from jax_dataloader import ArrayDataset
|
|
27
|
+
from torch.utils.data import DataLoader
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from collections.abc import Callable
|
|
31
|
+
|
|
32
|
+
import jax
|
|
33
|
+
from jax.sharding import NamedSharding
|
|
34
|
+
from torch.utils.data import Sampler
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ArrayLoader(DataLoader):
|
|
38
|
+
"""A custom DataLoader class that extends PyTorch's DataLoader."""
|
|
39
|
+
|
|
40
|
+
def __init__(
|
|
41
|
+
self,
|
|
42
|
+
dataset: ArrayDataset,
|
|
43
|
+
*,
|
|
44
|
+
batch_size: int = 1,
|
|
45
|
+
sampler: "Sampler | None" = None,
|
|
46
|
+
num_workers: int = 0,
|
|
47
|
+
collate_fn: "Callable | None" = None,
|
|
48
|
+
pin_memory: bool = False,
|
|
49
|
+
) -> None:
|
|
50
|
+
"""Initializes an ArrayLoader instance."""
|
|
51
|
+
super().__init__(
|
|
52
|
+
dataset,
|
|
53
|
+
batch_size=batch_size,
|
|
54
|
+
sampler=sampler,
|
|
55
|
+
num_workers=num_workers,
|
|
56
|
+
collate_fn=collate_fn,
|
|
57
|
+
pin_memory=pin_memory,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
@staticmethod
|
|
61
|
+
def calculate_padding(batch_size: int, num_devices: int) -> int:
|
|
62
|
+
"""Calculate the number of padding needed.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
batch_size (int): The size of the batch.
|
|
66
|
+
num_devices (int): The number of devices.
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
int: The number of padding rows (or elements) needed to make the batch size
|
|
70
|
+
divisible by the number of devices.
|
|
71
|
+
"""
|
|
72
|
+
remainder = batch_size % num_devices
|
|
73
|
+
return 0 if remainder == 0 else num_devices - remainder
|
|
74
|
+
|
|
75
|
+
@staticmethod
|
|
76
|
+
def pad_array(x: "jax.Array", n_pad: int, axis: int) -> "jax.Array":
|
|
77
|
+
"""Pad an array to ensure compatibility with sharding.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
x (jax.Array): The input array to be padded.
|
|
81
|
+
n_pad (int): The number of padding elements to add.
|
|
82
|
+
axis (int): The axis along which to apply the padding.
|
|
83
|
+
|
|
84
|
+
Returns:
|
|
85
|
+
jax.Array: The padded array with padding applied along the specified axis.
|
|
86
|
+
|
|
87
|
+
Raises:
|
|
88
|
+
ValueError: If padding is requested along an unsupported axis for a 1D
|
|
89
|
+
array.
|
|
90
|
+
"""
|
|
91
|
+
if x.ndim == 1:
|
|
92
|
+
if axis == 0:
|
|
93
|
+
return jnp.pad(x, pad_width=(0, n_pad), mode="edge")
|
|
94
|
+
msg = "Padding 1D arrays is only supported along axis 0."
|
|
95
|
+
raise ValueError(msg)
|
|
96
|
+
|
|
97
|
+
# Initialize all axes with no padding
|
|
98
|
+
pad_width: list[tuple[int, int]] = [(0, 0)] * x.ndim
|
|
99
|
+
# Apply padding to the specified axis
|
|
100
|
+
pad_width[axis] = (0, n_pad)
|
|
101
|
+
|
|
102
|
+
return jnp.pad(x, pad_width=pad_width, mode="edge")
|
|
103
|
+
|
|
104
|
+
@staticmethod
|
|
105
|
+
def collate_without_output(
|
|
106
|
+
batch: list[tuple],
|
|
107
|
+
device: "NamedSharding | None" = None,
|
|
108
|
+
) -> tuple:
|
|
109
|
+
"""Collate function to process batches with sharding and padding.
|
|
110
|
+
|
|
111
|
+
This function unpacks the batch of data, converts it into JAX arrays, and
|
|
112
|
+
applies padding to ensure the batch size is compatible with the number of
|
|
113
|
+
devices, if sharding is necessary. When a device is provided, the data is
|
|
114
|
+
automatically distributed across the available devices.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
batch (list[tuple]): A list of tuples, where each tuple contains the input
|
|
118
|
+
data, optional target data, and array-like keyword arguments.
|
|
119
|
+
device (NamedSharding | None, optional): Sharding using named axes for
|
|
120
|
+
parallel data distribution across devices. Defaults to `None`, meaning
|
|
121
|
+
no sharding is applied.
|
|
122
|
+
|
|
123
|
+
Returns:
|
|
124
|
+
tuple: A tuple containing:
|
|
125
|
+
- n_pad (int): The number of padding rows/elements added (0 if no
|
|
126
|
+
padding was required).
|
|
127
|
+
- x_batch (jax.Array): The input batch with padding applied if
|
|
128
|
+
necessary.
|
|
129
|
+
- kwargs_batch (list[jax.Array]): A list of keyword arguments with
|
|
130
|
+
padding applied if necessary.
|
|
131
|
+
"""
|
|
132
|
+
x_batch, *kwargs_batch = map(jnp.asarray, zip(*batch, strict=True))
|
|
133
|
+
|
|
134
|
+
n_pad = (
|
|
135
|
+
ArrayLoader.calculate_padding(
|
|
136
|
+
len(x_batch),
|
|
137
|
+
num_devices=device.num_devices,
|
|
138
|
+
)
|
|
139
|
+
if device
|
|
140
|
+
else 0
|
|
141
|
+
)
|
|
142
|
+
if n_pad:
|
|
143
|
+
x_batch = ArrayLoader.pad_array(x_batch, n_pad=n_pad, axis=0)
|
|
144
|
+
kwargs_batch = [
|
|
145
|
+
ArrayLoader.pad_array(x, n_pad=n_pad, axis=0) for x in kwargs_batch
|
|
146
|
+
]
|
|
147
|
+
|
|
148
|
+
if device:
|
|
149
|
+
x_batch = device_put(x_batch, device=device)
|
|
150
|
+
kwargs_batch = [device_put(x, device=device) for x in kwargs_batch]
|
|
151
|
+
|
|
152
|
+
return n_pad, x_batch, *kwargs_batch
|
|
153
|
+
|
|
154
|
+
@staticmethod
|
|
155
|
+
def collate_with_sharding(
|
|
156
|
+
batch: list[tuple],
|
|
157
|
+
device: "NamedSharding | None" = None,
|
|
158
|
+
) -> tuple:
|
|
159
|
+
"""Collate function to process batches with sharding and padding.
|
|
160
|
+
|
|
161
|
+
This function unpacks the batch of data, converts it into JAX arrays, and
|
|
162
|
+
applies padding to ensure the batch size is compatible with the number of
|
|
163
|
+
devices, if sharding is necessary. When a device is provided, the data is
|
|
164
|
+
automatically distributed across the available devices.
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
batch (list[tuple]): A list of tuples, where each tuple contains the input
|
|
168
|
+
data, optional target data, and array-like keyword arguments.
|
|
169
|
+
device (NamedSharding | None, optional): Sharding using named axes for
|
|
170
|
+
parallel data distribution across devices. Defaults to `None`, meaning
|
|
171
|
+
no sharding is applied.
|
|
172
|
+
|
|
173
|
+
Returns:
|
|
174
|
+
tuple: A tuple containing:
|
|
175
|
+
- n_pad (int): The number of padding rows/elements added (0 if no
|
|
176
|
+
padding was required).
|
|
177
|
+
- x_batch (jax.Array): The input batch with padding applied if
|
|
178
|
+
necessary.
|
|
179
|
+
- y_batch (jax.Array): The target batch with padding applied.
|
|
180
|
+
- kwargs_batch (list[jax.Array]): A list of keyword arguments with
|
|
181
|
+
padding applied if necessary.
|
|
182
|
+
"""
|
|
183
|
+
x_batch, y_batch, *kwargs_batch = map(jnp.asarray, zip(*batch, strict=True))
|
|
184
|
+
|
|
185
|
+
n_pad = (
|
|
186
|
+
ArrayLoader.calculate_padding(
|
|
187
|
+
len(x_batch),
|
|
188
|
+
num_devices=device.num_devices,
|
|
189
|
+
)
|
|
190
|
+
if device
|
|
191
|
+
else 0
|
|
192
|
+
)
|
|
193
|
+
if n_pad:
|
|
194
|
+
x_batch = ArrayLoader.pad_array(x_batch, n_pad=n_pad, axis=0)
|
|
195
|
+
y_batch = ArrayLoader.pad_array(y_batch, n_pad=n_pad, axis=0)
|
|
196
|
+
kwargs_batch = [
|
|
197
|
+
ArrayLoader.pad_array(x, n_pad=n_pad, axis=0) for x in kwargs_batch
|
|
198
|
+
]
|
|
199
|
+
|
|
200
|
+
if device:
|
|
201
|
+
x_batch = device_put(x_batch, device=device)
|
|
202
|
+
y_batch = device_put(y_batch, device=device)
|
|
203
|
+
kwargs_batch = [device_put(x, device=device) for x in kwargs_batch]
|
|
204
|
+
|
|
205
|
+
return n_pad, x_batch, y_batch, *kwargs_batch
|
aimz/model/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2025 Eli Lilly and Company
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Model object."""
|
|
16
|
+
|
|
17
|
+
from aimz.model.impact_model import ImpactModel
|
|
18
|
+
|
|
19
|
+
__all__ = ["ImpactModel"]
|
aimz/model/_core.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# Copyright 2025 Eli Lilly and Company
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Base class for impact model."""
|
|
16
|
+
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
from typing import TYPE_CHECKING, Self
|
|
19
|
+
|
|
20
|
+
from aimz.utils._validation import _validate_kernel_signature
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from collections.abc import Callable
|
|
24
|
+
|
|
25
|
+
import jax
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BaseModel(ABC):
|
|
29
|
+
"""Abstract base class for the impact model.
|
|
30
|
+
|
|
31
|
+
Attributes:
|
|
32
|
+
kernel (Callable): A probabilistic model with Pyro primitives.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
kernel: "Callable",
|
|
38
|
+
param_input: str = "X",
|
|
39
|
+
param_output: str = "y",
|
|
40
|
+
) -> None:
|
|
41
|
+
"""Initialize the BaseModel with a callable model.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
kernel (Callable): A probabilistic model with Pyro primitives.
|
|
45
|
+
param_input (str, optional): Name of the parameter in the `kernel` for the
|
|
46
|
+
main input data. Defaults to `"X"`.
|
|
47
|
+
param_output (str, optional): Name of the parameter in the `kernel` for the
|
|
48
|
+
output data. Defaults to `"y"`.
|
|
49
|
+
"""
|
|
50
|
+
self.kernel = kernel
|
|
51
|
+
self.param_input = param_input
|
|
52
|
+
self.param_output = param_output
|
|
53
|
+
_validate_kernel_signature(self.kernel, self.param_input, self.param_output)
|
|
54
|
+
|
|
55
|
+
@abstractmethod
|
|
56
|
+
def fit(
|
|
57
|
+
self,
|
|
58
|
+
X: "jax.Array",
|
|
59
|
+
y: "jax.Array",
|
|
60
|
+
**kwargs: object,
|
|
61
|
+
) -> Self:
|
|
62
|
+
"""Fit the model to the input data `X` and output data `y`.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
X (jax.Array): Input data with shape `(n_samples_X, n_features)`.
|
|
66
|
+
y (jax.Array): Output data with shape `(n_samples_Y,)`.
|
|
67
|
+
**kwargs (object): Additional arguments passed to the model, except for `X`
|
|
68
|
+
and `y`. All array-like objects in `**kwargs` are expected to be JAX
|
|
69
|
+
arrays.
|
|
70
|
+
|
|
71
|
+
Returns:
|
|
72
|
+
BaseModel: The fitted model instance, enabling method chaining.
|
|
73
|
+
"""
|
|
74
|
+
return self
|
|
75
|
+
|
|
76
|
+
@abstractmethod
|
|
77
|
+
def predict(self, X: "jax.Array", **kwargs: object) -> None:
|
|
78
|
+
"""Predict the output based on the fitted model.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
X (jax.Array): Input data with shape `(n_samples_X, n_features)`.
|
|
82
|
+
**kwargs (object): Additional arguments passed to the model, except for `X`
|
|
83
|
+
and `y`. All array-like objects in `**kwargs` are expected to be JAX
|
|
84
|
+
arrays.
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
@abstractmethod
|
|
88
|
+
def estimate_effect(
|
|
89
|
+
self,
|
|
90
|
+
) -> None:
|
|
91
|
+
"""Estimate the effect of an intervention."""
|