genforge 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140) hide show
  1. forge/__init__.py +3 -0
  2. forge/cli.py +143 -0
  3. forge/configs/config.yaml +22 -0
  4. forge/configs/control/.gitkeep +1 -0
  5. forge/configs/control/cbf.yaml +3 -0
  6. forge/configs/control/fbsde_control.yaml +4 -0
  7. forge/configs/control/guidance.yaml +4 -0
  8. forge/configs/control/projection.yaml +2 -0
  9. forge/configs/control/value_guidance.yaml +5 -0
  10. forge/configs/cost/.gitkeep +1 -0
  11. forge/configs/cost/barrier.yaml +4 -0
  12. forge/configs/cost/halfspace.yaml +4 -0
  13. forge/configs/cost/reward.yaml +4 -0
  14. forge/configs/dataset/.gitkeep +1 -0
  15. forge/configs/environment/.gitkeep +1 -0
  16. forge/configs/method/.gitkeep +1 -0
  17. forge/configs/method/conditional.yaml +3 -0
  18. forge/configs/method/d3pm.yaml +2 -0
  19. forge/configs/method/ddpm.yaml +2 -0
  20. forge/configs/method/ddpm_huber.yaml +3 -0
  21. forge/configs/method/fbsde.yaml +3 -0
  22. forge/configs/method/flow_matching.yaml +2 -0
  23. forge/configs/method/mdlm.yaml +2 -0
  24. forge/configs/method/ot_cfm.yaml +2 -0
  25. forge/configs/method/sedd.yaml +2 -0
  26. forge/configs/method/value_training.yaml +4 -0
  27. forge/configs/model/.gitkeep +1 -0
  28. forge/configs/model/categorical_mlp.yaml +5 -0
  29. forge/configs/model/mlp.yaml +6 -0
  30. forge/configs/model/temporal_unet.yaml +6 -0
  31. forge/configs/model/temporal_unet_janner.yaml +10 -0
  32. forge/configs/model/transformer.yaml +8 -0
  33. forge/configs/model/value_mlp.yaml +5 -0
  34. forge/configs/preprocessor/.gitkeep +1 -0
  35. forge/configs/preprocessor/minmax.yaml +2 -0
  36. forge/configs/preprocessor/standardize.yaml +2 -0
  37. forge/configs/runner/.gitkeep +1 -0
  38. forge/configs/runner/planning.yaml +7 -0
  39. forge/configs/runner/policy_training.yaml +9 -0
  40. forge/configs/runner/training.yaml +20 -0
  41. forge/configs/runner/value_training.yaml +7 -0
  42. forge/configs/sampler/.gitkeep +1 -0
  43. forge/configs/sampler/ddim.yaml +3 -0
  44. forge/configs/sampler/ddpm.yaml +2 -0
  45. forge/configs/sampler/flow.yaml +3 -0
  46. forge/configs/sampler/interpolant.yaml +3 -0
  47. forge/configs/sampler/sedd.yaml +2 -0
  48. forge/configs/sampler/tau_leaping.yaml +2 -0
  49. forge/configs/schedule/.gitkeep +1 -0
  50. forge/configs/schedule/absorbing.yaml +3 -0
  51. forge/configs/schedule/cfm_linear.yaml +3 -0
  52. forge/configs/schedule/linear_flow.yaml +2 -0
  53. forge/configs/schedule/si_trig.yaml +2 -0
  54. forge/configs/schedule/uniform_discrete.yaml +3 -0
  55. forge/configs/schedule/vp_cosine.yaml +4 -0
  56. forge/configs/schedule/vp_linear.yaml +4 -0
  57. forge/configs/space/.gitkeep +1 -0
  58. forge/configs/space/discrete.yaml +4 -0
  59. forge/configs/space/euclidean.yaml +3 -0
  60. forge/configs/visualizer/.gitkeep +1 -0
  61. forge/configs/visualizer/trajectory.yaml +2 -0
  62. forge/control/__init__.py +1 -0
  63. forge/control/cbf.py +35 -0
  64. forge/control/fbsde_control.py +23 -0
  65. forge/control/guidance.py +35 -0
  66. forge/control/projection.py +65 -0
  67. forge/control/value_guidance.py +79 -0
  68. forge/core/__init__.py +1 -0
  69. forge/core/builder.py +236 -0
  70. forge/core/checkpoint.py +88 -0
  71. forge/core/compose.py +44 -0
  72. forge/core/interfaces.py +429 -0
  73. forge/core/plugins.py +69 -0
  74. forge/core/protocols.py +143 -0
  75. forge/core/registry.py +105 -0
  76. forge/core/resolvers.py +23 -0
  77. forge/core/types.py +33 -0
  78. forge/costs/__init__.py +1 -0
  79. forge/costs/ball.py +44 -0
  80. forge/costs/barrier.py +40 -0
  81. forge/costs/box.py +34 -0
  82. forge/costs/halfspace.py +45 -0
  83. forge/costs/likelihood.py +32 -0
  84. forge/costs/reward.py +35 -0
  85. forge/datasets/__init__.py +7 -0
  86. forge/environments/__init__.py +7 -0
  87. forge/methods/__init__.py +1 -0
  88. forge/methods/conditional.py +54 -0
  89. forge/methods/d3pm.py +36 -0
  90. forge/methods/ddpm.py +43 -0
  91. forge/methods/ddpm_huber.py +48 -0
  92. forge/methods/fbsde.py +50 -0
  93. forge/methods/flow_matching.py +40 -0
  94. forge/methods/mdlm.py +50 -0
  95. forge/methods/ot_cfm.py +59 -0
  96. forge/methods/sedd.py +82 -0
  97. forge/methods/value_training.py +38 -0
  98. forge/models/__init__.py +1 -0
  99. forge/models/categorical.py +52 -0
  100. forge/models/mlp.py +70 -0
  101. forge/models/temporal_unet.py +104 -0
  102. forge/models/temporal_unet_janner.py +332 -0
  103. forge/models/transformer.py +147 -0
  104. forge/models/value.py +32 -0
  105. forge/preprocessing/__init__.py +1 -0
  106. forge/preprocessing/minmax.py +54 -0
  107. forge/preprocessing/standardize.py +57 -0
  108. forge/runners/__init__.py +1 -0
  109. forge/runners/multistep.py +226 -0
  110. forge/runners/planning.py +66 -0
  111. forge/runners/policy_training.py +138 -0
  112. forge/runners/training.py +385 -0
  113. forge/runners/value_training.py +18 -0
  114. forge/samplers/__init__.py +1 -0
  115. forge/samplers/ddim.py +45 -0
  116. forge/samplers/ddpm.py +56 -0
  117. forge/samplers/flow.py +55 -0
  118. forge/samplers/interpolant.py +61 -0
  119. forge/samplers/sedd.py +45 -0
  120. forge/samplers/tau_leaping.py +22 -0
  121. forge/schedules/__init__.py +1 -0
  122. forge/schedules/discrete.py +146 -0
  123. forge/schedules/flow.py +125 -0
  124. forge/schedules/vp.py +110 -0
  125. forge/spaces/__init__.py +1 -0
  126. forge/spaces/discrete.py +55 -0
  127. forge/spaces/euclidean.py +44 -0
  128. forge/utils/__init__.py +1 -0
  129. forge/utils/ema.py +82 -0
  130. forge/utils/logging.py +118 -0
  131. forge/utils/lora.py +126 -0
  132. forge/utils/seeding.py +34 -0
  133. forge/utils/torch_utils.py +23 -0
  134. forge/visualizations/__init__.py +1 -0
  135. forge/visualizations/trajectory.py +42 -0
  136. genforge-0.1.0.dist-info/METADATA +124 -0
  137. genforge-0.1.0.dist-info/RECORD +140 -0
  138. genforge-0.1.0.dist-info/WHEEL +4 -0
  139. genforge-0.1.0.dist-info/entry_points.txt +3 -0
  140. genforge-0.1.0.dist-info/licenses/LICENSE +21 -0
forge/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """genforge — a unified, PyTorch-based framework for generative-modeling techniques."""
2
+
3
+ __version__ = "0.1.0"
forge/cli.py ADDED
@@ -0,0 +1,143 @@
1
+ """The ``forge`` command-line entrypoint: ``list`` / ``train`` / ``sample``.
2
+
3
+ ``list`` imports the built-ins so registrations fire, then prints the registered components by
4
+ category. ``train`` / ``sample`` compose a Hydra config from an ``experiment=`` selection, build the
5
+ runner, and run. ``sample checkpoint=<path.pt>`` rebuilds everything from the self-contained
6
+ checkpoint alone (Invariant 5). All three fail loudly on misconfiguration.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import sys
13
+ from typing import Optional, Sequence
14
+
15
+ from .core import registry
16
+ from .core.builder import build, import_builtin_components
17
+
18
+ _EXPERIMENT_HINT = (
19
+ "Select one with `experiment=<env>/<params>/<method>` "
20
+ "(e.g. `forge train experiment=distributions/ddpm/base`)."
21
+ )
22
+
23
+
24
+ def _split_overrides(overrides: Sequence[str]) -> dict:
25
+ """Parse ``key=value`` overrides into a flat dict (first '=')."""
26
+ flat: dict[str, str] = {}
27
+ for o in overrides:
28
+ if "=" in o:
29
+ k, v = o.split("=", 1)
30
+ flat[k] = v
31
+ return flat
32
+
33
+
34
def _cmd_list(_args: argparse.Namespace) -> int:
    """Handle ``forge list``: print every registered component, grouped by category.

    The known categories (``registry.CATEGORIES``) are printed first, in their declared order,
    with ``(none yet)`` for empty ones. Any extra categories found in the registry follow.
    Always returns exit code 0.
    """
    # Importing the built-ins is what fires their `@register` decorators.
    import_builtin_components()
    # Concrete envs are plugins (no experiment is selected here), so the bundled env packages must be
    # imported too — otherwise `list` would omit environments/datasets/env-preprocessors.
    from .core.plugins import load_bundled_envs

    load_bundled_envs()
    by_category = registry.registered()
    known = registry.CATEGORIES

    print("forge components")
    print("===================")
    for category in known:
        components = by_category.get(category, {})
        listing = ", ".join(components) if components else "(none yet)"
        print(f" {category:<13} {listing}")

    extras = [name for name in by_category if name not in known]
    for category in extras:
        print(f" {category:<13} {', '.join(by_category[category])}")
    return 0
51
+
52
+
53
def _run_from_config(overrides: Sequence[str], action: str) -> int:
    """Compose the config for ``overrides``, build its runner, then train or sample.

    Args:
        overrides: Hydra-style ``key=value`` strings forwarded unchanged to ``compose_config``.
        action: ``"train"`` runs ``runner.train()`` and then ``runner.evaluate()``. Any other
            value is treated as sampling: the runner's configured ``ckpt_path`` is loaded (if it
            has one) and then ``runner.evaluate()`` is called.

    Returns:
        Exit code 0 on success, or 1 when a configured checkpoint path does not exist.

    NOTE(review): when the runner exposes no (or an empty) ``ckpt_path``, sampling evaluates the
    freshly built runner without loading any weights — confirm this is intended.
    """
    from omegaconf import OmegaConf

    from .core.compose import compose_config

    cfg = compose_config(overrides)
    runner = build(cfg)
    # Hand the runner a plain, fully-resolved copy of the config (the root config comments say
    # run settings are resolved into the checkpoint).
    runner.resolved_config = OmegaConf.to_container(cfg, resolve=True)

    if action == "train":
        runner.train()
        print(f"[train] done. eval: {runner.evaluate()}")
        return 0

    # Sampling from an experiment: restore its configured checkpoint when one is set.
    checkpoint = getattr(runner, "ckpt_path", None)
    if checkpoint:
        from pathlib import Path

        from .core.checkpoint import load_checkpoint

        if not Path(checkpoint).exists():
            print(
                f"`forge sample` found no checkpoint at {checkpoint!r}. Train first "
                f"(`forge train experiment=...`) or pass `checkpoint=<path.pt>`.",
                file=sys.stderr,
            )
            return 1
        runner.load_state(load_checkpoint(checkpoint))

    print(f"[sample] {runner.evaluate()}")
    return 0
86
+
87
+
88
def _cmd_train(args: argparse.Namespace) -> int:
    """Handle ``forge train``: require ``experiment=...`` and then train via the composed config.

    Returns exit code 2 (with a hint on stderr) when no experiment is selected. Otherwise it returns
    whatever ``_run_from_config`` returns.
    """
    if "experiment" in _split_overrides(args.overrides):
        return _run_from_config(args.overrides, "train")
    print(f"`forge train` requires an experiment selection. {_EXPERIMENT_HINT}", file=sys.stderr)
    return 2
94
+
95
+
96
def _cmd_sample(args: argparse.Namespace) -> int:
    """Handle ``forge sample``.

    Two modes, checked in this order:

    * ``checkpoint=<path.pt>`` — rebuild the runner from the self-contained checkpoint alone
      (Invariant 5) via ``TrainingRunner.from_checkpoint`` and evaluate it. Any ``experiment=``
      override is not consulted in this mode.
    * ``experiment=<...>`` — compose the experiment config and sample through
      ``_run_from_config``.

    Returns:
        The process exit code: 0 on success; 1 when the given checkpoint file does not exist
        (or the value is empty); 2 when neither ``checkpoint=`` nor ``experiment=`` is supplied.
        In experiment mode it returns whatever ``_run_from_config`` returns.
    """
    flat = _split_overrides(args.overrides)
    if "checkpoint" in flat:
        from pathlib import Path

        ckpt_path = flat["checkpoint"]
        # Mirror the experiment path in `_run_from_config`: report a missing path (an empty value
        # resolves to "." and is caught here too) as a clean CLI error, rather than deferring to
        # whatever the checkpoint loader raises.
        if not Path(ckpt_path).is_file():
            print(
                f"`forge sample` found no checkpoint file at {ckpt_path!r}. Check the path, or "
                "train first (`forge train experiment=...`).",
                file=sys.stderr,
            )
            return 1
        # Self-contained path: rebuild from the .pt alone (Invariant 5).
        from .runners.training import TrainingRunner

        runner = TrainingRunner.from_checkpoint(ckpt_path, build_fn=build)
        metrics = runner.evaluate()
        print(f"[sample] from checkpoint {ckpt_path}: {metrics}")
        return 0
    if "experiment" not in flat:
        print(
            f"`forge sample` requires `experiment=...` or `checkpoint=<path.pt>`. {_EXPERIMENT_HINT}",
            file=sys.stderr,
        )
        return 2
    return _run_from_config(args.overrides, "sample")
113
+
114
+
115
def _build_parser() -> argparse.ArgumentParser:
    """Build the ``forge`` argument parser with its ``list`` / ``train`` / ``sample`` subcommands.

    A subcommand is mandatory. Each subcommand stores its handler in ``args.func``. ``train`` and
    ``sample`` collect any number of free-form ``key=value`` overrides into ``args.overrides``.
    """
    parser = argparse.ArgumentParser(
        prog="forge",
        description="A unified framework for generative modeling with a clean control layer.",
    )
    commands = parser.add_subparsers(dest="command", required=True)

    list_cmd = commands.add_parser("list", help="List registered components by category.")
    list_cmd.set_defaults(func=_cmd_list)

    # (subcommand, subcommand help, overrides help, handler) for the override-taking commands.
    override_commands = (
        (
            "train",
            "Train a model from an experiment config.",
            "Hydra-style overrides, e.g. experiment=...",
            _cmd_train,
        ),
        (
            "sample",
            "Sample from a trained model or checkpoint.",
            "experiment=... or checkpoint=<path.pt>",
            _cmd_sample,
        ),
    )
    for name, cmd_help, overrides_help, handler in override_commands:
        cmd = commands.add_parser(name, help=cmd_help)
        cmd.add_argument("overrides", nargs="*", help=overrides_help)
        cmd.set_defaults(func=handler)

    return parser
134
+
135
+
136
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Parse ``argv`` (``sys.argv[1:]`` when None) and dispatch to the chosen subcommand.

    Returns the subcommand handler's exit code.
    """
    namespace = _build_parser().parse_args(argv)
    handler = namespace.func
    return handler(namespace)


if __name__ == "__main__":
    raise SystemExit(main())
@@ -0,0 +1,22 @@
1
+ # genforge root config (Hydra base+delta).
2
+ #
3
+ # Per-category config groups live in the sibling directories (space/, schedule/, model/, ...).
4
+ # Experiments are base+delta bundles in the repo-root `experiment/` tree, which is added to the
5
+ # Hydra searchpath at runtime by the CLI (env var GENFORGE_EXP_ROOT). Select an experiment with
6
+ # `experiment=<env>/<params>/<method>` — it composes the component groups it needs.
7
+
8
+ defaults:
9
+ - _self_
10
+ - experiment: ??? # mandatory: train/sample require an experiment selection
11
+
12
+ # Global run settings (resolved into the checkpoint, Invariant 5).
13
+ seed: 0
14
+
15
+ hydra:
16
+ searchpath:
17
+ - file://${oc.env:GENFORGE_EXP_ROOT}
18
+ job:
19
+ chdir: false
20
+ output_subdir: null
21
+ run:
22
+ dir: .
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `control`. Concrete options arrive in later phases.
@@ -0,0 +1,3 @@
1
+ name: cbf
2
+ params:
3
+ alpha: 1.0
@@ -0,0 +1,4 @@
1
+ name: fbsde_control
2
+ params:
3
+ value_checkpoint: checkpoints/lq/fbsde/values.pt
4
+ scale: 1.0
@@ -0,0 +1,4 @@
1
+ name: guidance
2
+ params:
3
+ scale: 2.0
4
+ sigma_weight: true
@@ -0,0 +1,2 @@
1
+ name: projection
2
+ params: {}
@@ -0,0 +1,5 @@
1
+ name: value_guidance
2
+ params:
3
+ value_checkpoint: checkpoints/distributions/value/values.pt
4
+ scale: 3.0
5
+ sigma_weight: true
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `cost`. Concrete options arrive in later phases.
@@ -0,0 +1,4 @@
1
+ name: barrier
2
+ params:
3
+ normal: [1.0, 0.0]
4
+ offset: 0.0
@@ -0,0 +1,4 @@
1
+ name: halfspace
2
+ params:
3
+ normal: [1.0, 0.0]
4
+ offset: 0.0
@@ -0,0 +1,4 @@
1
+ name: reward
2
+ params:
3
+ target: [2.0, 0.0]
4
+ weight: 1.0
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `dataset`. Concrete options arrive in later phases.
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `environment`. Concrete options arrive in later phases.
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `method`. Concrete options arrive in later phases.
@@ -0,0 +1,3 @@
1
+ name: conditional
2
+ params:
3
+ pin_positions: [0, -1]
@@ -0,0 +1,2 @@
1
+ name: d3pm
2
+ params: {}
@@ -0,0 +1,2 @@
1
+ name: ddpm
2
+ params: {}
@@ -0,0 +1,3 @@
1
+ name: ddpm_huber
2
+ params:
3
+ delta: 1.0 # residual threshold; |r| > delta is penalized linearly instead of quadratically
@@ -0,0 +1,3 @@
1
+ name: fbsde
2
+ params:
3
+ q: 4.0
@@ -0,0 +1,2 @@
1
+ name: flow_matching
2
+ params: {}
@@ -0,0 +1,2 @@
1
+ name: mdlm
2
+ params: {}
@@ -0,0 +1,2 @@
1
+ name: ot_cfm
2
+ params: {}
@@ -0,0 +1,2 @@
1
+ name: sedd
2
+ params: {}
@@ -0,0 +1,4 @@
1
+ name: value_training
2
+ params:
3
+ target: [2.0, 0.0]
4
+ weight: 1.0
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `model`. Concrete options arrive in later phases.
@@ -0,0 +1,5 @@
1
+ name: categorical_mlp
2
+ params:
3
+ num_classes: 5
4
+ hidden: 128
5
+ depth: 3
@@ -0,0 +1,6 @@
1
+ name: mlp
2
+ params:
3
+ dim: 2
4
+ hidden: 128
5
+ depth: 3
6
+ output_type: eps
@@ -0,0 +1,6 @@
1
+ name: temporal_unet
2
+ params:
3
+ dim: 2
4
+ horizon: 32
5
+ base: 32
6
+ output_type: eps
@@ -0,0 +1,10 @@
1
+ name: temporal_unet_janner
2
+ params:
3
+ dim: 32
4
+ dim_mults: [1, 2, 4, 8]
5
+ horizon: 32
6
+ transition_dim: 14
7
+ cond_dim: 0
8
+ cond_predict_scale: false
9
+ output_type: x0
10
+ obs_dim: 0
@@ -0,0 +1,8 @@
1
+ name: transformer
2
+ params:
3
+ vocab_size: 32 # = char_text unique chars (31) + 1 mask token
4
+ length: 32
5
+ d_model: 128
6
+ depth: 4
7
+ n_heads: 4
8
+ output_type: logits
@@ -0,0 +1,5 @@
1
+ name: value_mlp
2
+ params:
3
+ dim: 2
4
+ hidden: 128
5
+ depth: 3
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `preprocessor`. Concrete options arrive in later phases.
@@ -0,0 +1,2 @@
1
+ name: minmax
2
+ params: {}
@@ -0,0 +1,2 @@
1
+ name: standardize
2
+ params: {}
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `runner`. Concrete options arrive in later phases.
@@ -0,0 +1,7 @@
1
+ name: planning
2
+ params:
3
+ steps: 2000
4
+ batch_size: 128
5
+ lr: 1.0e-3
6
+ device: cpu
7
+ seed: 0
@@ -0,0 +1,9 @@
1
+ name: policy_training
2
+ params:
3
+ steps: 100000
4
+ batch_size: 256
5
+ lr: 1.0e-4
6
+ ema_decay: 0.9999
7
+ n_sample_steps: 100
8
+ device: cuda
9
+ seed: 0
@@ -0,0 +1,20 @@
1
+ name: training
2
+ params:
3
+ steps: 2000
4
+ batch_size: 256
5
+ lr: 1.0e-3
6
+ ema_decay: 0.999
7
+ n_sample_steps: 100
8
+ n_eval_samples: 2000
9
+ eval_radius: 0.6
10
+ device: cpu
11
+ seed: 0
12
+ # Optional experiment logging — needs the `logging` extra (pip install genforge[logging]).
13
+ # Absent/false => no-op logger + plain loop; a default run is unchanged. Any runner accepts
14
+ # these via CLI too, e.g. runner.params.log.wandb=true (or env FORGE_WANDB=1).
15
+ log:
16
+ wandb: false # off by default; no wandb import, no login prompt
17
+ project: forge
18
+ mode: null # online | offline | disabled
19
+ name: null # run name; defaults to <method>-<steps>steps
20
+ progress: false # tqdm bar; plain loop if tqdm absent or stderr is non-TTY
@@ -0,0 +1,7 @@
1
+ name: value_training
2
+ params:
3
+ steps: 2000
4
+ batch_size: 512
5
+ lr: 1.0e-3
6
+ device: cpu
7
+ seed: 0
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `sampler`. Concrete options arrive in later phases.
@@ -0,0 +1,3 @@
1
+ name: ddim
2
+ params:
3
+ eta: 0.0
@@ -0,0 +1,2 @@
1
+ name: ddpm
2
+ params: {}
@@ -0,0 +1,3 @@
1
+ name: flow
2
+ params:
3
+ integrator: heun
@@ -0,0 +1,3 @@
1
+ name: interpolant
2
+ params:
3
+ epsilon: 1.0 # free diffusion-coefficient scale; ε(t)=epsilon·σ(t)². 0 ⇒ probability-flow ODE.
@@ -0,0 +1,2 @@
1
+ name: sedd
2
+ params: {}
@@ -0,0 +1,2 @@
1
+ name: tau_leaping
2
+ params: {}
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `schedule`. Concrete options arrive in later phases.
@@ -0,0 +1,3 @@
1
+ name: absorbing
2
+ params:
3
+ num_classes: 5
@@ -0,0 +1,3 @@
1
+ name: cfm_linear
2
+ params:
3
+ sigma_min: 0.01
@@ -0,0 +1,2 @@
1
+ name: linear_flow
2
+ params: {}
@@ -0,0 +1,2 @@
1
+ name: si_trig
2
+ params: {}
@@ -0,0 +1,3 @@
1
+ name: uniform_discrete
2
+ params:
3
+ num_classes: 4
@@ -0,0 +1,4 @@
1
+ name: vp_cosine
2
+ params:
3
+ s: 0.008
4
+ parameterization: sqrt_alpha_bar
@@ -0,0 +1,4 @@
1
+ name: vp_linear
2
+ params:
3
+ beta_min: 0.1
4
+ beta_max: 20.0
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `space`. Concrete options arrive in later phases.
@@ -0,0 +1,4 @@
1
+ name: discrete
2
+ params:
3
+ num_classes: 5
4
+ length: 1
@@ -0,0 +1,3 @@
1
+ name: euclidean
2
+ params:
3
+ dim: 2
@@ -0,0 +1 @@
1
+ # Phase 0: empty config group for `visualizer`. Concrete options arrive in later phases.
@@ -0,0 +1,2 @@
1
+ name: trajectory
2
+ params: {}
@@ -0,0 +1 @@
1
+ """The control layer: HOW you approximate the tilt Q ∝ exp(log_h)·P (the thesis surface)."""
forge/control/cbf.py ADDED
@@ -0,0 +1,35 @@
1
+ """CBF — a control-barrier-function safety filter on the reverse DRIFT (the rate ẋ).
2
+
3
+ The first real `drift`-surface controller (Invariant 6), and the proof the surface works: a CBF
4
+ *reads the rate*, so it has no faithful x̂₀ reduction. Given a barrier ``h(x)≥0``, it minimally
5
+ edits the drift so the barrier's forward-invariance condition ``ḣ = ∇h·ẋ ≥ −α·h(x)`` holds. For one
6
+ linear barrier this is the closed-form CBF-QP solution (project the drift onto the safe halfspace):
7
+
8
+ slack = ∇h·drift + α·h ; if slack < 0: drift += (−slack/‖∇h‖²)·∇h (else unchanged).
9
+
10
+ With the data-ward heading drift and α=1 this keeps the clean estimate feasible every step
11
+ (``h(x̂₀') ≥ (1−α)·h(xₜ) = 0``), so the formed sample lands in the feasible set.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import torch
17
+
18
+ from ..core.interfaces import Controller
19
+ from ..core.registry import register
20
+
21
+
22
@register("control", "cbf")
class CBF(Controller):
    """Control-barrier-function safety filter acting on the reverse drift (the ``drift`` surface).

    It minimally edits the drift so that ``∇h·drift + α·h(x) ≥ 0``. The correction is made along
    ``∇h`` only when that margin is negative; otherwise the drift is returned unchanged.

    Args:
        cost: barrier object exposing ``grad_h(x)`` (∇h) and ``value(x)`` (h).
        alpha: class-K gain ``α`` in the invariance condition.
    """

    surface = "drift"

    def __init__(self, cost, alpha: float = 1.0):
        super().__init__(cost)
        self.alpha = float(alpha)

    def modify_drift(self, drift: torch.Tensor, x: torch.Tensor, t, schedule) -> torch.Tensor:
        """Return ``drift`` projected onto the barrier's safe halfspace.

        ``t`` and ``schedule`` are accepted for the controller interface but are not used here.
        """
        cost = self.cost
        normal = cost.grad_h(x)  # ∇h(x), trailing axis is the state dimension
        barrier = cost.value(x)  # h(x), one value per state
        # Required margin: ḣ + α·h ≥ 0 along the proposed drift.
        margin = (normal * drift).sum(dim=-1) + self.alpha * barrier
        deficit = torch.clamp(-margin, min=0.0)  # zero wherever the drift is already safe
        norm_sq = (normal * normal).sum(dim=-1).clamp_min(1e-12)  # guard against ∇h = 0
        return drift + (deficit / norm_sq).unsqueeze(-1) * normal
@@ -0,0 +1,23 @@
1
+ """FBSDEControl — consume an FBSDE-learned value (cost-to-go) to steer sampling.
2
+
3
+ Same amortized machinery as `ValueGuidance` (loads the value from a checkpoint, never imports the
4
+ method), but descends the cost-to-go: the optimal control follows ``−∇V`` (``sign = −1``).
5
+
6
+ FBSDE control belongs on the ``drift`` surface (``u* = Gᵀ∇V`` added to b_θ). This implementation
7
+ keeps it on the inherited ``x0`` surface (a −∇V shift of x̂₀) with identical behavior; drift-level
8
+ FBSDE is future work (it depends on the dual-surface refactor landing first).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from ..core.registry import register
14
+ from .value_guidance import AmortizedValueController
15
+
16
+
17
@register("control", "fbsde_control")
class FBSDEControl(AmortizedValueController):
    """Steer sampling with an FBSDE-learned cost-to-go loaded from ``value_checkpoint``.

    It behaves as the amortized value controller, but with ``sign=-1.0`` so the correction
    descends the value (optimal control follows ``−∇V``).
    """

    # Kept on the inherited x̂₀ surface for now. TODO: reclassify to the "drift" surface.
    surface = "x0"

    def __init__(
        self,
        value_checkpoint: str,
        cost=None,
        scale: float = 1.0,
        sigma_weight: bool = True,
    ):
        descend = -1.0  # cost-to-go is minimized, so follow −∇V
        super().__init__(
            value_checkpoint,
            cost=cost,
            scale=scale,
            sigma_weight=sigma_weight,
            sign=descend,
        )
@@ -0,0 +1,35 @@
1
+ """Guidance controller — first-order ∇log h on the clean-sample estimate (DPS/RePaint style).
2
+
3
+ Each reverse step, nudge x̂₀ along ∇_{x̂₀} log h (computed by autograd), never the noisy iterate.
4
+ The step is scaled by ``scale`` and, optionally, by σ(t)² so the correction is gentler near the data
5
+ (t→0) — a soft tilt rather than a hard projection. The base model is never touched (Invariant 6).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import torch
11
+
12
+ from ..core.interfaces import Controller
13
+ from ..core.registry import register
14
+
15
+
16
@register("control", "guidance")
class Guidance(Controller):
    """First-order guidance: shift the clean-sample estimate x̂₀ along ``∇_{x̂₀} log h``.

    The base model is never touched. Only its x̂₀ estimate is nudged each reverse step.

    Args:
        cost: object exposing ``log_h(x, t)``. Its values are summed before differentiation, so
            the gradient is taken with respect to every element of x̂₀ at once.
        scale: constant step size multiplying the gradient.
        sigma_weight: if true, additionally multiply the step by ``schedule.sigma(t) ** 2`` so the
            correction fades as σ(t) shrinks.
    """

    surface = "x0"

    def __init__(self, cost, scale: float = 1.0, sigma_weight: bool = True):
        super().__init__(cost)
        self.scale = float(scale)
        self.sigma_weight = bool(sigma_weight)

    def modify_x0(self, x0_hat: torch.Tensor, x: torch.Tensor, t, schedule) -> torch.Tensor:
        """Return ``x0_hat + step * ∇log h(x0_hat)``, detached from any autograd graph.

        ``x`` (the noisy iterate) is accepted for the controller interface but is not used.
        ``step`` is ``scale``, times ``σ(t)²`` when ``sigma_weight`` is set. σ(t) may be a scalar or
        a per-sample tensor (e.g. shape ``(batch,)``). A per-sample σ is right-padded with
        singleton axes so it broadcasts over x̂₀'s trailing dimensions instead of its last axis.
        """
        # Autograd may be globally disabled during sampling; re-enable it just for ∇log h.
        with torch.enable_grad():
            z = x0_hat.detach().requires_grad_(True)
            lh = self.cost.log_h(z, t).sum()
            (grad,) = torch.autograd.grad(lh, z)
        step = self.scale
        if self.sigma_weight:
            # σ(t)² fades the correction as the estimate sharpens toward the data manifold.
            sigma = torch.as_tensor(schedule.sigma(t), device=x0_hat.device)
            if 0 < sigma.dim() < grad.dim():
                # e.g. σ of shape (batch,) against grad of shape (batch, dim) -> (batch, 1).
                sigma = sigma.reshape(tuple(sigma.shape) + (1,) * (grad.dim() - sigma.dim()))
            step = step * (sigma**2)
        return (x0_hat + step * grad).detach()
+ return (x0_hat + step * grad).detach()