opttx 0.1.0a1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. opttx-0.1.0a1/LICENSE +21 -0
  2. opttx-0.1.0a1/MANIFEST.in +10 -0
  3. opttx-0.1.0a1/PKG-INFO +220 -0
  4. opttx-0.1.0a1/README.md +172 -0
  5. opttx-0.1.0a1/opttx/__init__.py +65 -0
  6. opttx-0.1.0a1/opttx/apply.py +80 -0
  7. opttx-0.1.0a1/opttx/curvature.py +262 -0
  8. opttx-0.1.0a1/opttx/metrics.py +73 -0
  9. opttx-0.1.0a1/opttx/models/__init__.py +16 -0
  10. opttx-0.1.0a1/opttx/models/gat.py +177 -0
  11. opttx-0.1.0a1/opttx/models/gcn.py +151 -0
  12. opttx-0.1.0a1/opttx/objective.py +140 -0
  13. opttx-0.1.0a1/opttx/optimizers/__init__.py +35 -0
  14. opttx-0.1.0a1/opttx/optimizers/_utils.py +168 -0
  15. opttx-0.1.0a1/opttx/optimizers/aa_accelerator.py +346 -0
  16. opttx-0.1.0a1/opttx/optimizers/adam.py +167 -0
  17. opttx-0.1.0a1/opttx/optimizers/adamw.py +172 -0
  18. opttx-0.1.0a1/opttx/optimizers/cg.py +189 -0
  19. opttx-0.1.0a1/opttx/optimizers/cr.py +189 -0
  20. opttx-0.1.0a1/opttx/optimizers/lbfgs.py +165 -0
  21. opttx-0.1.0a1/opttx/optimizers/muon.py +342 -0
  22. opttx-0.1.0a1/opttx/optimizers/nltgcr.py +194 -0
  23. opttx-0.1.0a1/opttx/optimizers/nltgcr_crossbatch.py +386 -0
  24. opttx-0.1.0a1/opttx/optimizers/sgd.py +152 -0
  25. opttx-0.1.0a1/opttx/optimizers/shampoo.py +488 -0
  26. opttx-0.1.0a1/opttx/optimizers/soap.py +583 -0
  27. opttx-0.1.0a1/opttx/optimizers/tgs.py +171 -0
  28. opttx-0.1.0a1/opttx/optimizers/tgs_accelerator.py +427 -0
  29. opttx-0.1.0a1/opttx/optimizers/wrapper.py +104 -0
  30. opttx-0.1.0a1/opttx/py.typed +0 -0
  31. opttx-0.1.0a1/opttx/solvers/__init__.py +17 -0
  32. opttx-0.1.0a1/opttx/solvers/cg.py +223 -0
  33. opttx-0.1.0a1/opttx/solvers/cr.py +247 -0
  34. opttx-0.1.0a1/opttx/solvers/nltgcr.py +352 -0
  35. opttx-0.1.0a1/opttx/solvers/tgs.py +417 -0
  36. opttx-0.1.0a1/opttx/state.py +117 -0
  37. opttx-0.1.0a1/opttx/terms.py +50 -0
  38. opttx-0.1.0a1/opttx.egg-info/PKG-INFO +220 -0
  39. opttx-0.1.0a1/opttx.egg-info/SOURCES.txt +42 -0
  40. opttx-0.1.0a1/opttx.egg-info/dependency_links.txt +1 -0
  41. opttx-0.1.0a1/opttx.egg-info/requires.txt +7 -0
  42. opttx-0.1.0a1/opttx.egg-info/top_level.txt +1 -0
  43. opttx-0.1.0a1/pyproject.toml +41 -0
  44. opttx-0.1.0a1/setup.cfg +4 -0
opttx-0.1.0a1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Tianshi Xu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,10 @@
1
+ include LICENSE
2
+ include README.md
3
+ include pyproject.toml
4
+ recursive-include opttx *.py
5
+ recursive-include opttx py.typed
6
+ prune ref
7
+ prune tests
8
+ prune docs
9
+ prune examples
10
+ prune .pytest_cache
opttx-0.1.0a1/PKG-INFO ADDED
@@ -0,0 +1,220 @@
1
+ Metadata-Version: 2.4
2
+ Name: opttx
3
+ Version: 0.1.0a1
4
+ Summary: JAX/Flax/Optax optimizer manager
5
+ Author: Tianshi Xu
6
+ License: MIT License
7
+
8
+ Copyright (c) 2025 Tianshi Xu
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+
28
+ Project-URL: Homepage, https://github.com/Hitenze/opttx
29
+ Project-URL: Repository, https://github.com/Hitenze/opttx
30
+ Project-URL: Issues, https://github.com/Hitenze/opttx/issues
31
+ Keywords: jax,optax,flax,optimizer
32
+ Classifier: Development Status :: 3 - Alpha
33
+ Classifier: Programming Language :: Python :: 3
34
+ Classifier: License :: OSI Approved :: MIT License
35
+ Classifier: Operating System :: OS Independent
36
+ Classifier: Intended Audience :: Science/Research
37
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
38
+ Requires-Python: >=3.10
39
+ Description-Content-Type: text/markdown
40
+ License-File: LICENSE
41
+ Requires-Dist: jax>=0.4.20
42
+ Requires-Dist: jaxlib>=0.4.20
43
+ Requires-Dist: optax>=0.2.0
44
+ Requires-Dist: flax>=0.8.0
45
+ Provides-Extra: dev
46
+ Requires-Dist: pytest>=7.4; extra == "dev"
47
+ Dynamic: license-file
48
+
49
+ # OptTx
50
+
51
+ > **Research Code**: Co-developed with Claude Code, Gemini CLI, Codex CLI, and Cursor. No guarantees provided. Use at your own risk.
52
+
53
+ JAX/Flax/Optax optimizer library for PINNs and second-order methods.
54
+
55
+ ## Features
56
+
57
+ - **Multi-term objectives**: `Objective` with `TermSpec` for PINNs (PDE, BC, IC terms)
58
+ - **First-order and quasi-Newton optimizers**: Adam, SGD, AdamW, SOAP, MUON, Shampoo, L-BFGS
59
+ - **Second-order optimizers**: CGOptimizer (Fisher/GGN), CROptimizer (Hessian)
60
+ - **Acceleration methods**: TGS, NLTGCR, Anderson Acceleration (AA)
61
+ - **Graph neural networks**: GCN, GAT layers for node classification
62
+ - **Matrix-free curvature**: `build_hessian_matvec`, `build_fisher_matvec`
63
+ - **JIT-stable**: Works with `jax.jit` and `jax.lax.scan`
64
+
65
+ ## Install
66
+
67
+ ```bash
68
+ pip install opttx
69
+ ```
70
+
71
+ For development:
72
+ ```bash
73
+ pip install -e .[dev]
74
+ ```
75
+
76
+ ## Quickstart
77
+
78
+ ### First-order optimizer
79
+
80
+ ```python
81
+ import jax
82
+ import jax.numpy as jnp
83
+ from flax import linen as nn
84
+
85
+ from opttx import Adam, Objective, TermSpec, TrainState
86
+
87
+ # Define model
88
+ class MLP(nn.Module):
89
+ @nn.compact
90
+ def __call__(self, x):
91
+ x = nn.Dense(32)(x)
92
+ x = nn.relu(x)
93
+ x = nn.Dense(1)(x)
94
+ return x
95
+
96
+ # Define loss
97
+ def mse_loss(pred, batch):
98
+ x, y = batch
99
+ return jnp.mean((pred - y) ** 2)
100
+
101
+ # Create objective
102
+ term = TermSpec(name="mse", batch_key="data", loss_fn=mse_loss)
103
+ objective = Objective(terms=[term])
104
+
105
+ # Initialize
106
+ model = MLP()
107
+ params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))["params"]
108
+
109
+ state = TrainState(
110
+ step=jnp.array(0),
111
+ params=params,
112
+ opt_state=None,
113
+ apply_fn=lambda v, b: model.apply({"params": v["params"]}, b[0]),
114
+ )
115
+
116
+ # Create optimizer and train
117
+ optimizer = Adam(objective, learning_rate=1e-3)
118
+ state = optimizer.init(state)
119
+
120
+ batch = {"data": (jnp.ones((8, 3)), jnp.zeros((8, 1)))}
121
+ state, metrics = optimizer.step(state, batch)
122
+ print(f"Loss: {metrics['loss']}")
123
+ ```
124
+
125
+ ### Second-order optimizer (CR + Hessian)
126
+
127
+ ```python
128
+ from opttx import CROptimizer
129
+
130
+ optimizer = CROptimizer(
131
+ objective,
132
+ learning_rate=1.0,
133
+ damping=1e-3,
134
+ cr_iters=10,
135
+ curvature_type="hessian", # or "fisher"
136
+ )
137
+ state = optimizer.init(state)
138
+ state, metrics = optimizer.step(state, batch)
139
+ ```
140
+
141
+ ### Multi-term objective (PINNs)
142
+
143
+ ```python
144
+ def pde_loss(pred, batch):
145
+ return jnp.mean(pred ** 2)
146
+
147
+ def bc_loss(pred, batch):
148
+ return jnp.mean(pred ** 2)
149
+
150
+ pde_term = TermSpec(name="pde", batch_key="x_pde", loss_fn=pde_loss)
151
+ bc_term = TermSpec(name="bc", batch_key="x_bc", loss_fn=bc_loss)
152
+
153
+ objective = Objective(
154
+ terms=[pde_term, bc_term],
155
+ loss_weights={"pde": 1.0, "bc": 0.1},
156
+ )
157
+
158
+ batch = {
159
+ "x_pde": jnp.ones((100, 2)),
160
+ "x_bc": jnp.ones((20, 2)),
161
+ }
162
+ ```
163
+
164
+ ## API Reference
165
+
166
+ ### Optimizers
167
+
168
+ | Optimizer | Description |
169
+ |-----------|-------------|
170
+ | `Adam` | Adam optimizer |
171
+ | `SGD` | SGD with momentum |
172
+ | `AdamW` | Adam with weight decay |
173
+ | `SOAP` | Shampoo with Adam in the preconditioner's eigenbasis |
174
+ | `MUON` | Momentum with orthogonalization |
175
+ | `Shampoo` | Shampoo preconditioner |
176
+ | `LBFGSOptimizer` | L-BFGS quasi-Newton |
177
+ | `CGOptimizer` | Conjugate Gradient (Fisher/GGN) |
178
+ | `CROptimizer` | Conjugate Residual (Hessian) |
179
+ | `TGSOptimizer` | TGS acceleration |
180
+ | `TGSAccelerator` | TGS wrapper for any optimizer |
181
+ | `AAAccelerator` | Anderson Acceleration wrapper |
182
+ | `NLTGCROptimizer` | Nonlinear truncated GCR |
183
+
184
+ ### Curvature
185
+
186
+ | Function | Description |
187
+ |----------|-------------|
188
+ | `build_hessian_matvec` | Matrix-free Hessian-vector product |
189
+ | `build_fisher_matvec` | Matrix-free Fisher/GGN-vector product |
190
+ | `build_damped_matvec` | Add damping: (H + λI)v |
191
+
192
+ ### Solvers
193
+
194
+ | Function | Description |
195
+ |----------|-------------|
196
+ | `cg_solve` | Conjugate Gradient solver |
197
+ | `cr_solve` | Conjugate Residual solver |
198
+ | `tgs_solve_fori` | TGS solver (JIT-compatible) |
199
+ | `nltgcr_solve_fori` | NLTGCR solver (JIT-compatible) |
200
+
201
+ ### Models
202
+
203
+ | Model | Description |
204
+ |-------|-------------|
205
+ | `GCN` | Graph Convolutional Network |
206
+ | `GCNLayer` | Single GCN layer |
207
+ | `GAT` | Graph Attention Network |
208
+ | `GATLayer` | Single GAT layer |
209
+ | `normalize_adjacency` | Symmetric adjacency normalization |
210
+
211
+ ## Design Constraints
212
+
213
+ - `state.step` must be a scalar `jax.Array` (never Python int)
214
+ - Metrics have static string keys and scalar values
215
+ - Metrics must include a `"loss"` key
216
+ - Multi-term + `batch_stats` is not supported
217
+
218
+ ## License
219
+
220
+ MIT
@@ -0,0 +1,172 @@
1
+ # OptTx
2
+
3
+ > **Research Code**: Co-developed with Claude Code, Gemini CLI, Codex CLI, and Cursor. No guarantees provided. Use at your own risk.
4
+
5
+ JAX/Flax/Optax optimizer library for PINNs and second-order methods.
6
+
7
+ ## Features
8
+
9
+ - **Multi-term objectives**: `Objective` with `TermSpec` for PINNs (PDE, BC, IC terms)
10
+ - **First-order and quasi-Newton optimizers**: Adam, SGD, AdamW, SOAP, MUON, Shampoo, L-BFGS
11
+ - **Second-order optimizers**: CGOptimizer (Fisher/GGN), CROptimizer (Hessian)
12
+ - **Acceleration methods**: TGS, NLTGCR, Anderson Acceleration (AA)
13
+ - **Graph neural networks**: GCN, GAT layers for node classification
14
+ - **Matrix-free curvature**: `build_hessian_matvec`, `build_fisher_matvec`
15
+ - **JIT-stable**: Works with `jax.jit` and `jax.lax.scan`
16
+
17
+ ## Install
18
+
19
+ ```bash
20
+ pip install opttx
21
+ ```
22
+
23
+ For development:
24
+ ```bash
25
+ pip install -e .[dev]
26
+ ```
27
+
28
+ ## Quickstart
29
+
30
+ ### First-order optimizer
31
+
32
+ ```python
33
+ import jax
34
+ import jax.numpy as jnp
35
+ from flax import linen as nn
36
+
37
+ from opttx import Adam, Objective, TermSpec, TrainState
38
+
39
+ # Define model
40
+ class MLP(nn.Module):
41
+ @nn.compact
42
+ def __call__(self, x):
43
+ x = nn.Dense(32)(x)
44
+ x = nn.relu(x)
45
+ x = nn.Dense(1)(x)
46
+ return x
47
+
48
+ # Define loss
49
+ def mse_loss(pred, batch):
50
+ x, y = batch
51
+ return jnp.mean((pred - y) ** 2)
52
+
53
+ # Create objective
54
+ term = TermSpec(name="mse", batch_key="data", loss_fn=mse_loss)
55
+ objective = Objective(terms=[term])
56
+
57
+ # Initialize
58
+ model = MLP()
59
+ params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))["params"]
60
+
61
+ state = TrainState(
62
+ step=jnp.array(0),
63
+ params=params,
64
+ opt_state=None,
65
+ apply_fn=lambda v, b: model.apply({"params": v["params"]}, b[0]),
66
+ )
67
+
68
+ # Create optimizer and train
69
+ optimizer = Adam(objective, learning_rate=1e-3)
70
+ state = optimizer.init(state)
71
+
72
+ batch = {"data": (jnp.ones((8, 3)), jnp.zeros((8, 1)))}
73
+ state, metrics = optimizer.step(state, batch)
74
+ print(f"Loss: {metrics['loss']}")
75
+ ```
76
+
77
+ ### Second-order optimizer (CR + Hessian)
78
+
79
+ ```python
80
+ from opttx import CROptimizer
81
+
82
+ optimizer = CROptimizer(
83
+ objective,
84
+ learning_rate=1.0,
85
+ damping=1e-3,
86
+ cr_iters=10,
87
+ curvature_type="hessian", # or "fisher"
88
+ )
89
+ state = optimizer.init(state)
90
+ state, metrics = optimizer.step(state, batch)
91
+ ```
92
+
93
+ ### Multi-term objective (PINNs)
94
+
95
+ ```python
96
+ def pde_loss(pred, batch):
97
+ return jnp.mean(pred ** 2)
98
+
99
+ def bc_loss(pred, batch):
100
+ return jnp.mean(pred ** 2)
101
+
102
+ pde_term = TermSpec(name="pde", batch_key="x_pde", loss_fn=pde_loss)
103
+ bc_term = TermSpec(name="bc", batch_key="x_bc", loss_fn=bc_loss)
104
+
105
+ objective = Objective(
106
+ terms=[pde_term, bc_term],
107
+ loss_weights={"pde": 1.0, "bc": 0.1},
108
+ )
109
+
110
+ batch = {
111
+ "x_pde": jnp.ones((100, 2)),
112
+ "x_bc": jnp.ones((20, 2)),
113
+ }
114
+ ```
115
+
116
+ ## API Reference
117
+
118
+ ### Optimizers
119
+
120
+ | Optimizer | Description |
121
+ |-----------|-------------|
122
+ | `Adam` | Adam optimizer |
123
+ | `SGD` | SGD with momentum |
124
+ | `AdamW` | Adam with weight decay |
125
+ | `SOAP` | Shampoo with Adam in the preconditioner's eigenbasis |
126
+ | `MUON` | Momentum with orthogonalization |
127
+ | `Shampoo` | Shampoo preconditioner |
128
+ | `LBFGSOptimizer` | L-BFGS quasi-Newton |
129
+ | `CGOptimizer` | Conjugate Gradient (Fisher/GGN) |
130
+ | `CROptimizer` | Conjugate Residual (Hessian) |
131
+ | `TGSOptimizer` | TGS acceleration |
132
+ | `TGSAccelerator` | TGS wrapper for any optimizer |
133
+ | `AAAccelerator` | Anderson Acceleration wrapper |
134
+ | `NLTGCROptimizer` | Nonlinear truncated GCR |
135
+
136
+ ### Curvature
137
+
138
+ | Function | Description |
139
+ |----------|-------------|
140
+ | `build_hessian_matvec` | Matrix-free Hessian-vector product |
141
+ | `build_fisher_matvec` | Matrix-free Fisher/GGN-vector product |
142
+ | `build_damped_matvec` | Add damping: (H + λI)v |
143
+
144
+ ### Solvers
145
+
146
+ | Function | Description |
147
+ |----------|-------------|
148
+ | `cg_solve` | Conjugate Gradient solver |
149
+ | `cr_solve` | Conjugate Residual solver |
150
+ | `tgs_solve_fori` | TGS solver (JIT-compatible) |
151
+ | `nltgcr_solve_fori` | NLTGCR solver (JIT-compatible) |
152
+
153
+ ### Models
154
+
155
+ | Model | Description |
156
+ |-------|-------------|
157
+ | `GCN` | Graph Convolutional Network |
158
+ | `GCNLayer` | Single GCN layer |
159
+ | `GAT` | Graph Attention Network |
160
+ | `GATLayer` | Single GAT layer |
161
+ | `normalize_adjacency` | Symmetric adjacency normalization |
162
+
163
+ ## Design Constraints
164
+
165
+ - `state.step` must be a scalar `jax.Array` (never Python int)
166
+ - Metrics have static string keys and scalar values
167
+ - Metrics must include a `"loss"` key
168
+ - Multi-term + `batch_stats` is not supported
169
+
170
+ ## License
171
+
172
+ MIT
@@ -0,0 +1,65 @@
1
+ """OptTx: JAX/Flax/Optax optimizer library for PINNs and second-order methods."""
2
+
3
+ from .objective import Objective
4
+ from .optimizers import (
5
+ OptaxOptimizer,
6
+ Adam,
7
+ SGD,
8
+ AdamW,
9
+ SOAP,
10
+ Shampoo,
11
+ MUON,
12
+ LBFGSOptimizer,
13
+ CGOptimizer,
14
+ CROptimizer,
15
+ NLTGCROptimizer,
16
+ CrossBatchNLTGCROptimizer,
17
+ TGSOptimizer,
18
+ TGSAccelerator,
19
+ AAAccelerator,
20
+ )
21
+ from .state import TrainState
22
+ from .terms import TermSpec
23
+ from .curvature import build_hessian_matvec, build_fisher_matvec, build_damped_matvec
24
+ from .solvers.cg import cg_solve
25
+ from .solvers.cr import cr_solve
26
+ from .solvers.tgs import tgs_solve_fori
27
+ from .solvers.nltgcr import nltgcr_solve_fori
28
+
29
+ __version__ = "0.1.0a1"
30
+
31
+ __all__ = [
32
+ # First-order optimizers
33
+ "OptaxOptimizer",
34
+ "Adam",
35
+ "SGD",
36
+ "AdamW",
37
+ "SOAP",
38
+ "Shampoo",
39
+ "MUON",
40
+ # Quasi-Newton optimizers
41
+ "LBFGSOptimizer",
42
+ # Second-order optimizers
43
+ "CGOptimizer",
44
+ "CROptimizer",
45
+ "NLTGCROptimizer",
46
+ "CrossBatchNLTGCROptimizer",
47
+ # First-order accelerated
48
+ "TGSOptimizer",
49
+ "TGSAccelerator",
50
+ "AAAccelerator",
51
+ # Curvature matvecs
52
+ "build_hessian_matvec",
53
+ "build_fisher_matvec",
54
+ "build_damped_matvec",
55
+ # Solvers
56
+ "cg_solve",
57
+ "cr_solve",
58
+ "tgs_solve_fori",
59
+ "nltgcr_solve_fori",
60
+ # Core
61
+ "Objective",
62
+ "TermSpec",
63
+ "TrainState",
64
+ "__version__",
65
+ ]
@@ -0,0 +1,80 @@
1
+ """Apply wrapper for Flax modules with method dispatch."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Callable, Dict, Optional, Sequence, Union
6
+
7
+
8
def apply_with_method(
    apply_fn: Callable[..., Any],
    variables: Dict[str, Any],
    *args: Any,
    method: Optional[Union[Callable, str]] = None,
    training: Optional[bool] = None,
    rngs: Optional[Dict[str, Any]] = None,
    mutable: Union[bool, Sequence[str]] = False,
    **kwargs: Any,
) -> Any:
    """Call a Flax ``apply`` function, optionally routing to a non-default method.

    Multi-term objectives (e.g. PINNs) often need different forward passes per
    term -- a PDE residual, boundary values, initial conditions. This helper
    forwards the chosen entry point to ``apply_fn`` and only passes the
    optional keyword arguments that were actually requested.

    Args:
        apply_fn: A Flax module's ``apply`` function (or a compatible callable).
        variables: Model variables dict (expected to contain ``"params"``).
        *args: Positional arguments forwarded to the selected method.
        method: Which forward to run:
            - ``None``: the module's ``__call__``; no ``method=`` is forwarded.
            - ``str``: the name of a module method (e.g. ``"pde_residual"``).
            - callable: forwarded as ``method=``; Flax invokes it as
              ``method(module, *args)``.
        training: Forwarded as ``training=`` only when not ``None``.
        rngs: Forwarded as ``rngs=`` only when not ``None``.
        mutable: Forwarded as ``mutable=`` only when truthy
            (e.g. ``["batch_stats"]``).
        **kwargs: Extra keyword arguments forwarded to ``apply_fn``.

    Returns:
        Whatever ``apply_fn`` returns. With Flax this is the model output when
        ``mutable`` is falsy, else an ``(output, updated_collections)`` pair.

    Raises:
        ValueError: If ``method`` is not ``None``, a ``str``, or callable.

    Example:
        >>> out = apply_with_method(model.apply, variables, x)
        >>> residual = apply_with_method(
        ...     model.apply, variables, x, method="compute_residual"
        ... )
        >>> out, updates = apply_with_method(
        ...     model.apply, variables, x,
        ...     training=True, mutable=["batch_stats"]
        ... )
    """
    # ``training`` goes first so that, when it is omitted, a caller-supplied
    # kwarg of the same name is still honoured.
    call_kwargs: Dict[str, Any] = {} if training is None else {"training": training}
    call_kwargs.update(kwargs)

    if rngs is not None:
        call_kwargs["rngs"] = rngs
    if mutable:
        call_kwargs["mutable"] = mutable

    # Default forward: do not pass ``method=`` at all, so plain callables
    # without a ``method`` keyword keep working.
    if method is None:
        return apply_fn(variables, *args, **call_kwargs)

    # Method names and method callables are both handed to Flax unchanged.
    if isinstance(method, str) or callable(method):
        return apply_fn(variables, *args, method=method, **call_kwargs)

    raise ValueError(f"method must be None, str, or callable, got {type(method)}")