genforge 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- forge/__init__.py +3 -0
- forge/cli.py +143 -0
- forge/configs/config.yaml +22 -0
- forge/configs/control/.gitkeep +1 -0
- forge/configs/control/cbf.yaml +3 -0
- forge/configs/control/fbsde_control.yaml +4 -0
- forge/configs/control/guidance.yaml +4 -0
- forge/configs/control/projection.yaml +2 -0
- forge/configs/control/value_guidance.yaml +5 -0
- forge/configs/cost/.gitkeep +1 -0
- forge/configs/cost/barrier.yaml +4 -0
- forge/configs/cost/halfspace.yaml +4 -0
- forge/configs/cost/reward.yaml +4 -0
- forge/configs/dataset/.gitkeep +1 -0
- forge/configs/environment/.gitkeep +1 -0
- forge/configs/method/.gitkeep +1 -0
- forge/configs/method/conditional.yaml +3 -0
- forge/configs/method/d3pm.yaml +2 -0
- forge/configs/method/ddpm.yaml +2 -0
- forge/configs/method/ddpm_huber.yaml +3 -0
- forge/configs/method/fbsde.yaml +3 -0
- forge/configs/method/flow_matching.yaml +2 -0
- forge/configs/method/mdlm.yaml +2 -0
- forge/configs/method/ot_cfm.yaml +2 -0
- forge/configs/method/sedd.yaml +2 -0
- forge/configs/method/value_training.yaml +4 -0
- forge/configs/model/.gitkeep +1 -0
- forge/configs/model/categorical_mlp.yaml +5 -0
- forge/configs/model/mlp.yaml +6 -0
- forge/configs/model/temporal_unet.yaml +6 -0
- forge/configs/model/temporal_unet_janner.yaml +10 -0
- forge/configs/model/transformer.yaml +8 -0
- forge/configs/model/value_mlp.yaml +5 -0
- forge/configs/preprocessor/.gitkeep +1 -0
- forge/configs/preprocessor/minmax.yaml +2 -0
- forge/configs/preprocessor/standardize.yaml +2 -0
- forge/configs/runner/.gitkeep +1 -0
- forge/configs/runner/planning.yaml +7 -0
- forge/configs/runner/policy_training.yaml +9 -0
- forge/configs/runner/training.yaml +20 -0
- forge/configs/runner/value_training.yaml +7 -0
- forge/configs/sampler/.gitkeep +1 -0
- forge/configs/sampler/ddim.yaml +3 -0
- forge/configs/sampler/ddpm.yaml +2 -0
- forge/configs/sampler/flow.yaml +3 -0
- forge/configs/sampler/interpolant.yaml +3 -0
- forge/configs/sampler/sedd.yaml +2 -0
- forge/configs/sampler/tau_leaping.yaml +2 -0
- forge/configs/schedule/.gitkeep +1 -0
- forge/configs/schedule/absorbing.yaml +3 -0
- forge/configs/schedule/cfm_linear.yaml +3 -0
- forge/configs/schedule/linear_flow.yaml +2 -0
- forge/configs/schedule/si_trig.yaml +2 -0
- forge/configs/schedule/uniform_discrete.yaml +3 -0
- forge/configs/schedule/vp_cosine.yaml +4 -0
- forge/configs/schedule/vp_linear.yaml +4 -0
- forge/configs/space/.gitkeep +1 -0
- forge/configs/space/discrete.yaml +4 -0
- forge/configs/space/euclidean.yaml +3 -0
- forge/configs/visualizer/.gitkeep +1 -0
- forge/configs/visualizer/trajectory.yaml +2 -0
- forge/control/__init__.py +1 -0
- forge/control/cbf.py +35 -0
- forge/control/fbsde_control.py +23 -0
- forge/control/guidance.py +35 -0
- forge/control/projection.py +65 -0
- forge/control/value_guidance.py +79 -0
- forge/core/__init__.py +1 -0
- forge/core/builder.py +236 -0
- forge/core/checkpoint.py +88 -0
- forge/core/compose.py +44 -0
- forge/core/interfaces.py +429 -0
- forge/core/plugins.py +69 -0
- forge/core/protocols.py +143 -0
- forge/core/registry.py +105 -0
- forge/core/resolvers.py +23 -0
- forge/core/types.py +33 -0
- forge/costs/__init__.py +1 -0
- forge/costs/ball.py +44 -0
- forge/costs/barrier.py +40 -0
- forge/costs/box.py +34 -0
- forge/costs/halfspace.py +45 -0
- forge/costs/likelihood.py +32 -0
- forge/costs/reward.py +35 -0
- forge/datasets/__init__.py +7 -0
- forge/environments/__init__.py +7 -0
- forge/methods/__init__.py +1 -0
- forge/methods/conditional.py +54 -0
- forge/methods/d3pm.py +36 -0
- forge/methods/ddpm.py +43 -0
- forge/methods/ddpm_huber.py +48 -0
- forge/methods/fbsde.py +50 -0
- forge/methods/flow_matching.py +40 -0
- forge/methods/mdlm.py +50 -0
- forge/methods/ot_cfm.py +59 -0
- forge/methods/sedd.py +82 -0
- forge/methods/value_training.py +38 -0
- forge/models/__init__.py +1 -0
- forge/models/categorical.py +52 -0
- forge/models/mlp.py +70 -0
- forge/models/temporal_unet.py +104 -0
- forge/models/temporal_unet_janner.py +332 -0
- forge/models/transformer.py +147 -0
- forge/models/value.py +32 -0
- forge/preprocessing/__init__.py +1 -0
- forge/preprocessing/minmax.py +54 -0
- forge/preprocessing/standardize.py +57 -0
- forge/runners/__init__.py +1 -0
- forge/runners/multistep.py +226 -0
- forge/runners/planning.py +66 -0
- forge/runners/policy_training.py +138 -0
- forge/runners/training.py +385 -0
- forge/runners/value_training.py +18 -0
- forge/samplers/__init__.py +1 -0
- forge/samplers/ddim.py +45 -0
- forge/samplers/ddpm.py +56 -0
- forge/samplers/flow.py +55 -0
- forge/samplers/interpolant.py +61 -0
- forge/samplers/sedd.py +45 -0
- forge/samplers/tau_leaping.py +22 -0
- forge/schedules/__init__.py +1 -0
- forge/schedules/discrete.py +146 -0
- forge/schedules/flow.py +125 -0
- forge/schedules/vp.py +110 -0
- forge/spaces/__init__.py +1 -0
- forge/spaces/discrete.py +55 -0
- forge/spaces/euclidean.py +44 -0
- forge/utils/__init__.py +1 -0
- forge/utils/ema.py +82 -0
- forge/utils/logging.py +118 -0
- forge/utils/lora.py +126 -0
- forge/utils/seeding.py +34 -0
- forge/utils/torch_utils.py +23 -0
- forge/visualizations/__init__.py +1 -0
- forge/visualizations/trajectory.py +42 -0
- genforge-0.1.0.dist-info/METADATA +124 -0
- genforge-0.1.0.dist-info/RECORD +140 -0
- genforge-0.1.0.dist-info/WHEEL +4 -0
- genforge-0.1.0.dist-info/entry_points.txt +3 -0
- genforge-0.1.0.dist-info/licenses/LICENSE +21 -0
forge/__init__.py
ADDED
forge/cli.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""The ``forge`` command-line entrypoint: ``list`` / ``train`` / ``sample``.
|
|
2
|
+
|
|
3
|
+
``list`` imports the built-ins so registrations fire, then prints the registered components by
|
|
4
|
+
category. ``train`` / ``sample`` compose a Hydra config from an ``experiment=`` selection, build the
|
|
5
|
+
runner, and run. ``sample checkpoint=<path.pt>`` rebuilds everything from the self-contained
|
|
6
|
+
checkpoint alone (Invariant 5). All three fail loudly on misconfiguration.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import argparse
|
|
12
|
+
import sys
|
|
13
|
+
from typing import Optional, Sequence
|
|
14
|
+
|
|
15
|
+
from .core import registry
|
|
16
|
+
from .core.builder import build, import_builtin_components
|
|
17
|
+
|
|
18
|
+
# Appended to the "no experiment selected" errors so the user sees how to pick one.
_EXPERIMENT_HINT = (
    "Select one with `experiment=<env>/<params>/<method>` (e.g. "
    "`forge train experiment=distributions/ddpm/base`)."
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _split_overrides(overrides: Sequence[str]) -> dict:
    """Turn ``key=value`` override strings into a flat ``{key: value}`` mapping.

    Only the first ``=`` separates key from value, so ``a=b=c`` maps ``a`` to ``"b=c"``.
    Tokens that contain no ``=`` are skipped; when a key repeats, its last value wins.
    """
    pairs = (token.split("=", 1) for token in overrides if "=" in token)
    return {key: value for key, value in pairs}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _cmd_list(_args: argparse.Namespace) -> int:
    """Print every registered component grouped by category; always returns 0.

    The built-in components and the bundled environment packages are imported first so their
    registrations have fired. Environments are plugins and no experiment is selected here, so
    without ``load_bundled_envs`` the environment/dataset/env-preprocessor entries would be missing.
    """
    import_builtin_components()
    from .core.plugins import load_bundled_envs

    load_bundled_envs()

    by_category = registry.registered()
    print("forge components")
    print("===================")
    # Canonical categories first, in their declared order; empty ones get a placeholder.
    for category in registry.CATEGORIES:
        entries = by_category.get(category, {})
        listing = ", ".join(entries) if entries else "(none yet)"
        print(f" {category:<13} {listing}")
    # Then any other category present in the registry, in registry order.
    for category, entries in by_category.items():
        if category in registry.CATEGORIES:
            continue
        print(f" {category:<13} {', '.join(entries)}")
    return 0
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _run_from_config(overrides: Sequence[str], action: str) -> int:
    """Compose a config from Hydra-style ``overrides``, build its runner, then train or sample.

    Args:
        overrides: raw ``key=value`` CLI overrides, forwarded unchanged to ``compose_config``.
        action: ``"train"`` (train, then evaluate) or ``"sample"`` (load the runner's
            configured checkpoint when it has one, then evaluate).

    Returns:
        Process exit code: 0 on success, 1 when the configured checkpoint file is missing.

    Raises:
        ValueError: if ``action`` is neither ``"train"`` nor ``"sample"``.
    """
    if action not in ("train", "sample"):
        # Validate before the (potentially expensive) compose/build; an unknown action used to
        # fall through silently to the sampling path.
        raise ValueError(f"unknown action {action!r}; expected 'train' or 'sample'")

    from omegaconf import OmegaConf

    from .core.compose import compose_config

    cfg = compose_config(overrides)
    runner = build(cfg)
    # Hand the runner the fully resolved config; presumably persisted with its checkpoint
    # (Invariant 5) — confirm in the runner implementation.
    runner.resolved_config = OmegaConf.to_container(cfg, resolve=True)

    if action == "train":
        runner.train()
        metrics = runner.evaluate()
        print(f"[train] done. eval: {metrics}")
        return 0

    # Sampling from an experiment: restore its configured checkpoint when it has one.
    ckpt_path = getattr(runner, "ckpt_path", None)
    if ckpt_path:
        from pathlib import Path

        from .core.checkpoint import load_checkpoint

        if not Path(ckpt_path).exists():
            print(
                f"`forge sample` found no checkpoint at {ckpt_path!r}. Train first "
                f"(`forge train experiment=...`) or pass `checkpoint=<path.pt>`.",
                file=sys.stderr,
            )
            return 1
        runner.load_state(load_checkpoint(ckpt_path))
    else:
        # Kept non-fatal for backward compatibility, but no longer silent: without a checkpoint
        # the metrics below come from the freshly built model, not trained weights.
        print(
            "`forge sample`: the runner has no configured checkpoint (`ckpt_path`); "
            "evaluating the freshly built model without loading trained weights.",
            file=sys.stderr,
        )
    metrics = runner.evaluate()
    print(f"[sample] {metrics}")
    return 0
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _cmd_train(args: argparse.Namespace) -> int:
    """Handle ``forge train``: require an ``experiment=`` override, then train and evaluate.

    Returns 2 (usage error) with a hint on stderr when no experiment is selected; otherwise
    returns the exit code of the shared compose/build/run path.
    """
    selected = _split_overrides(args.overrides)
    if "experiment" in selected:
        return _run_from_config(args.overrides, "train")
    print(f"`forge train` requires an experiment selection. {_EXPERIMENT_HINT}", file=sys.stderr)
    return 2
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _cmd_sample(args: argparse.Namespace) -> int:
    """Handle ``forge sample`` from either a checkpoint file or an experiment config.

    ``checkpoint=<path.pt>`` takes precedence: the runner is rebuilt from that file alone
    (Invariant 5), and any other overrides are reported on stderr and ignored. Otherwise
    ``experiment=...`` is required and sampling goes through the shared compose/build path.

    Returns:
        0 on success; 1 when the given checkpoint file does not exist; 2 on a usage error
        (empty ``checkpoint=`` value, or neither ``checkpoint=`` nor ``experiment=`` given).
    """
    flat = _split_overrides(args.overrides)
    if "checkpoint" in flat:
        ckpt = flat["checkpoint"]
        if not ckpt:
            # Must be rejected explicitly: Path("") resolves to "." and would pass the check below.
            print(
                "`forge sample` got an empty `checkpoint=`; pass `checkpoint=<path.pt>`.",
                file=sys.stderr,
            )
            return 2

        from pathlib import Path

        if not Path(ckpt).exists():
            # Same friendly failure as the experiment path, rather than relying on whatever
            # `from_checkpoint` raises for a missing file.
            print(f"`forge sample` found no checkpoint at {ckpt!r}.", file=sys.stderr)
            return 1

        ignored = [o for o in args.overrides if o.partition("=")[0] != "checkpoint"]
        if ignored:
            print(
                f"`forge sample`: ignoring {ignored}; a checkpoint is rebuilt from its own config.",
                file=sys.stderr,
            )

        # Self-contained path: rebuild from the .pt alone (Invariant 5).
        from .runners.training import TrainingRunner

        runner = TrainingRunner.from_checkpoint(ckpt, build_fn=build)
        metrics = runner.evaluate()
        print(f"[sample] from checkpoint {ckpt}: {metrics}")
        return 0

    if "experiment" not in flat:
        print(
            f"`forge sample` requires `experiment=...` or `checkpoint=<path.pt>`. {_EXPERIMENT_HINT}",
            file=sys.stderr,
        )
        return 2
    return _run_from_config(args.overrides, "sample")
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the ``forge`` parser with its ``list`` / ``train`` / ``sample`` subcommands.

    Each subcommand stores its handler as ``func`` (which ``main`` invokes). ``train`` and
    ``sample`` also collect any trailing tokens into the ``overrides`` list.
    """
    parser = argparse.ArgumentParser(
        prog="forge",
        description="A unified framework for generative modeling with a clean control layer.",
    )
    commands = parser.add_subparsers(dest="command", required=True)

    # (name, help text, handler, help for the positional `overrides`, or None if it takes none)
    specs = (
        ("list", "List registered components by category.", _cmd_list, None),
        (
            "train",
            "Train a model from an experiment config.",
            _cmd_train,
            "Hydra-style overrides, e.g. experiment=...",
        ),
        (
            "sample",
            "Sample from a trained model or checkpoint.",
            _cmd_sample,
            "experiment=... or checkpoint=<path.pt>",
        ),
    )
    for name, summary, handler, overrides_help in specs:
        subparser = commands.add_parser(name, help=summary)
        if overrides_help is not None:
            subparser.add_argument("overrides", nargs="*", help=overrides_help)
        subparser.set_defaults(func=handler)

    return parser
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Parse ``argv`` (argparse falls back to ``sys.argv[1:]`` when it is None) and dispatch.

    Returns the exit code produced by the selected subcommand's handler.
    """
    namespace = _build_parser().parse_args(argv)
    handler = namespace.func
    return handler(namespace)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
if __name__ == "__main__":
    # Executed as a script (e.g. `python -m forge.cli`): the handler's return value is the exit status.
    sys.exit(main())
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# genforge root config (Hydra base+delta).
|
|
2
|
+
#
|
|
3
|
+
# Per-category config groups live in the sibling directories (space/, schedule/, model/, ...).
|
|
4
|
+
# Experiments are base+delta bundles in the repo-root `experiment/` tree, which is added to the
|
|
5
|
+
# Hydra searchpath at runtime by the CLI (env var GENFORGE_EXP_ROOT). Select an experiment with
|
|
6
|
+
# `experiment=<env>/<params>/<method>` — it composes the component groups it needs.
|
|
7
|
+
|
|
8
|
+
defaults:
|
|
9
|
+
- _self_
|
|
10
|
+
- experiment: ??? # mandatory: train/sample require an experiment selection
|
|
11
|
+
|
|
12
|
+
# Global run settings (resolved into the checkpoint, Invariant 5).
|
|
13
|
+
seed: 0
|
|
14
|
+
|
|
15
|
+
hydra:
|
|
16
|
+
searchpath:
|
|
17
|
+
- file://${oc.env:GENFORGE_EXP_ROOT}
|
|
18
|
+
job:
|
|
19
|
+
chdir: false
|
|
20
|
+
output_subdir: null
|
|
21
|
+
run:
|
|
22
|
+
dir: .
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `control`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `cost`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `dataset`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `environment`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `method`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `model`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `preprocessor`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `runner`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
name: training
|
|
2
|
+
params:
|
|
3
|
+
steps: 2000
|
|
4
|
+
batch_size: 256
|
|
5
|
+
lr: 1.0e-3
|
|
6
|
+
ema_decay: 0.999
|
|
7
|
+
n_sample_steps: 100
|
|
8
|
+
n_eval_samples: 2000
|
|
9
|
+
eval_radius: 0.6
|
|
10
|
+
device: cpu
|
|
11
|
+
seed: 0
|
|
12
|
+
# Optional experiment logging — needs the `logging` extra (pip install genforge[logging]).
|
|
13
|
+
# Absent/false => no-op logger + plain loop; a default run is unchanged. Any runner accepts
|
|
14
|
+
# these via CLI too, e.g. runner.params.log.wandb=true (or env FORGE_WANDB=1).
|
|
15
|
+
log:
|
|
16
|
+
wandb: false # off by default; no wandb import, no login prompt
|
|
17
|
+
project: forge
|
|
18
|
+
mode: null # online | offline | disabled
|
|
19
|
+
name: null # run name; defaults to <method>-<steps>steps
|
|
20
|
+
progress: false # tqdm bar; plain loop if tqdm absent or stderr is non-TTY
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `sampler`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `schedule`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `space`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Phase 0: empty config group for `visualizer`. Concrete options arrive in later phases.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""The control layer: HOW you approximate the tilt Q ∝ exp(log_h)·P (the thesis surface)."""
|
forge/control/cbf.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""CBF — a control-barrier-function safety filter on the reverse DRIFT (the rate ẋ).
|
|
2
|
+
|
|
3
|
+
The first real `drift`-surface controller (Invariant 6), and the proof the surface works: a CBF
|
|
4
|
+
*reads the rate*, so it has no faithful x̂₀ reduction. Given a barrier ``h(x)≥0``, it minimally
|
|
5
|
+
edits the drift so the barrier's forward-invariance condition ``ḣ = ∇h·ẋ ≥ −α·h(x)`` holds. For one
|
|
6
|
+
linear barrier this is the closed-form CBF-QP solution (project the drift onto the safe halfspace):
|
|
7
|
+
|
|
8
|
+
slack = ∇h·drift + α·h ; if slack < 0: drift += (−slack/‖∇h‖²)·∇h (else unchanged).
|
|
9
|
+
|
|
10
|
+
With the data-ward heading drift and α=1 this keeps the clean estimate feasible every step
|
|
11
|
+
(``h(x̂₀') ≥ (1−α)·h(xₜ) = 0``), so the formed sample lands in the feasible set.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from ..core.interfaces import Controller
|
|
19
|
+
from ..core.registry import register
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@register("control", "cbf")
class CBF(Controller):
    """Control-barrier-function safety filter acting on the reverse drift.

    ``cost.value(x)`` supplies the barrier h(x) and ``cost.grad_h(x)`` its gradient ∇h(x).
    The drift is edited minimally so that ``∇h·drift + α·h ≥ 0`` holds. With one barrier this is
    the closed-form CBF-QP solution: when the condition is violated, add exactly enough of ∇h to
    satisfy it with equality; otherwise return the drift unchanged.

    Args:
        cost: barrier provider exposing ``value(x)`` and ``grad_h(x)``.
        alpha: the α in ``ḣ ≥ −α·h`` (stored as ``float``).
    """

    surface = "drift"

    def __init__(self, cost, alpha: float = 1.0):
        super().__init__(cost)
        self.alpha = float(alpha)

    def modify_drift(self, drift: torch.Tensor, x: torch.Tensor, t, schedule) -> torch.Tensor:
        """Project ``drift`` onto the safe halfspace ``{v : ∇h(x)·v ≥ −α·h(x)}``.

        ``t`` and ``schedule`` belong to the controller interface and are not used here.
        """
        grad_h = self.cost.grad_h(x)  # ∇h(x), (..., dim)
        h_val = self.cost.value(x)  # h(x), (...,)
        # How far the condition ∇h·drift + α·h ≥ 0 is violated (non-positive when it holds).
        deficit = -((grad_h * drift).sum(dim=-1) + self.alpha * h_val)
        # The floor only guards the division; the numerator is 0 whenever no correction is needed.
        denom = (grad_h * grad_h).sum(dim=-1).clamp_min(1e-12)
        step = deficit.clamp_min(0.0) / denom
        return drift + step[..., None] * grad_h
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""FBSDEControl — consume an FBSDE-learned value (cost-to-go) to steer sampling.
|
|
2
|
+
|
|
3
|
+
Same amortized machinery as `ValueGuidance` (loads the value from a checkpoint, never imports the
|
|
4
|
+
method), but descends the cost-to-go: the optimal control follows ``−∇V`` (``sign = −1``).
|
|
5
|
+
|
|
6
|
+
FBSDE control belongs on the ``drift`` surface (``u* = Gᵀ∇V`` added to b_θ). This implementation
|
|
7
|
+
keeps it on the inherited ``x0`` surface (a −∇V shift of x̂₀) with identical behavior; drift-level
|
|
8
|
+
FBSDE is future work (it depends on the dual-surface refactor landing first).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from ..core.registry import register
|
|
14
|
+
from .value_guidance import AmortizedValueController
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@register("control", "fbsde_control")
class FBSDEControl(AmortizedValueController):
    """Steer sampling with an FBSDE-learned cost-to-go ``V`` loaded from ``value_checkpoint``.

    Behaves like the amortized value controller it extends, except that it descends the
    cost-to-go: the parent is constructed with ``sign=-1.0``, so the correction follows ``−∇V``.
    """

    # Still on the inherited x̂₀ surface; TODO: move this controller to the "drift" surface.
    surface = "x0"

    def __init__(self, value_checkpoint: str, cost=None, scale: float = 1.0, sigma_weight: bool = True):
        descend = -1.0  # cost-to-go is minimized, so follow −∇V rather than +∇V
        super().__init__(
            value_checkpoint,
            sign=descend,
            cost=cost,
            scale=scale,
            sigma_weight=sigma_weight,
        )
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Guidance controller — first-order ∇log h on the clean-sample estimate (DPS/RePaint style).
|
|
2
|
+
|
|
3
|
+
Each reverse step, nudge x̂₀ along ∇_{x̂₀} log h (computed by autograd), never the noisy iterate.
|
|
4
|
+
The step is scaled by ``scale`` and, optionally, by σ(t)² so the correction is gentler near the data
|
|
5
|
+
(t→0) — a soft tilt rather than a hard projection. The base model is never touched (Invariant 6).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from ..core.interfaces import Controller
|
|
13
|
+
from ..core.registry import register
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@register("control", "guidance")
class Guidance(Controller):
    """First-order guidance: nudge the clean-sample estimate x̂₀ along ∇_{x̂₀} log h.

    Args:
        cost: provides ``log_h(x0, t)``; its gradient w.r.t. x̂₀ is taken with autograd.
        scale: step-size multiplier (stored as ``float``).
        sigma_weight: if true, the step is further multiplied by ``schedule.sigma(t) ** 2``.
    """

    surface = "x0"

    def __init__(self, cost, scale: float = 1.0, sigma_weight: bool = True):
        super().__init__(cost)
        self.scale = float(scale)
        self.sigma_weight = bool(sigma_weight)

    def modify_x0(self, x0_hat: torch.Tensor, x: torch.Tensor, t, schedule) -> torch.Tensor:
        """Return ``x0_hat + step * ∇log h(x0_hat)``, detached from any autograd graph.

        ``step`` is ``scale``, times ``σ(t)²`` when ``sigma_weight`` is set. A non-scalar σ(t)
        (e.g. one noise level per batch element) is right-padded with singleton dims so it lines
        up with the gradient's leading dims. ``x`` is not used.
        """
        with torch.enable_grad():
            z = x0_hat.detach().requires_grad_(True)
            lh = self.cost.log_h(z, t).sum()
            (grad,) = torch.autograd.grad(lh, z)
        step = self.scale
        if self.sigma_weight:
            # σ(t)² fades the correction as the estimate sharpens toward the data manifold.
            sigma = torch.as_tensor(schedule.sigma(t), device=x0_hat.device)
            if 0 < sigma.ndim < grad.ndim:
                # Broadcasting aligns trailing dims: a per-sample σ of shape (B,) would otherwise be
                # matched against the feature dim of grad — an error, or silent axis mixing if B == dim.
                sigma = sigma.reshape(tuple(sigma.shape) + (1,) * (grad.ndim - sigma.ndim))
            step = step * (sigma**2)
        return (x0_hat + step * grad).detach()
|