flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,260 @@
|
|
1
|
+
import jax
|
2
|
+
import flax.linen as nn
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
from typing import Optional, Dict, Any, Union, List, Tuple, Type
|
5
|
+
|
6
|
+
from flaxdiff.trainer import (
|
7
|
+
SimpleTrainState,
|
8
|
+
TrainState,
|
9
|
+
)
|
10
|
+
from flaxdiff.samplers import (
|
11
|
+
DiffusionSampler,
|
12
|
+
)
|
13
|
+
from flaxdiff.schedulers import (
|
14
|
+
NoiseScheduler,
|
15
|
+
)
|
16
|
+
from flaxdiff.predictors import (
|
17
|
+
DiffusionPredictionTransform,
|
18
|
+
)
|
19
|
+
from flaxdiff.models.autoencoder import AutoEncoder
|
20
|
+
from flaxdiff.inputs import DiffusionInputConfig
|
21
|
+
from flaxdiff.utils import defaultTextEncodeModel, RandomMarkovState
|
22
|
+
from flaxdiff.samplers.euler import EulerAncestralSampler
|
23
|
+
from .utils import parse_config, load_from_wandb_run, load_from_wandb_registry
|
24
|
+
|
25
|
+
@dataclass
class InferencePipeline:
    """Base container pairing a Flax model with its training states.

    Subclasses add model-family specific loading and sampling logic.
    """

    # Flax module used for inference.
    model: nn.Module = None
    # Primary training state and an alternative "best" state, if available.
    state: SimpleTrainState = None
    best_state: SimpleTrainState = None

    def from_wandb(
        self,
        wandb_run: str,
        wandb_project: str,
        wandb_entity: str,
    ):
        """Placeholder loader; the base pipeline cannot be built from wandb.

        Raises:
            NotImplementedError: Always.
        """
        message = "InferencePipeline does not support from_wandb."
        raise NotImplementedError(message)
|
39
|
+
|
40
|
+
@dataclass
class DiffusionInferencePipeline(InferencePipeline):
    """Inference pipeline for diffusion models.

    This pipeline handles loading models from wandb and generating samples using the
    DiffusionSampler from FlaxDiff. Instances are normally built through
    ``from_wandb_run``, ``from_wandb_registry`` or ``create``.
    """
    state: TrainState = None
    best_state: TrainState = None
    # RNG used by ``generate_samples`` when no explicit ``seed`` is passed.
    rngstate: Optional[RandomMarkovState] = None
    noise_schedule: NoiseScheduler = None
    model_output_transform: DiffusionPredictionTransform = None
    autoencoder: AutoEncoder = None
    input_config: DiffusionInputConfig = None
    # Sampler cache: sampler class -> guidance scale -> sampler instance.
    samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
    # Parsed configuration (the dictionary passed to ``create``).
    config: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def _from_loaded(cls, states, config, config_overrides=None):
        """Validate the output of a wandb loader and build a pipeline from it.

        Raises:
            ValueError: If no model state or no run config was loaded.
        """
        # ``load_from_checkpoint`` reports failure as ``(None, None)`` rather than
        # ``None``, so both forms must be treated as "nothing loaded".
        if states is None or all(s is None for s in states):
            raise ValueError("Failed to load model parameters from wandb.")
        if config is None:
            raise ValueError("Failed to load the run configuration from wandb.")
        state, best_state = states
        parsed_config = parse_config(config, overrides=config_overrides)
        return cls.create(
            config=parsed_config,
            state=state,
            best_state=best_state,
            rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
        )

    @classmethod
    def from_wandb_run(
        cls,
        wandb_run: str,
        project: str,
        entity: str,
        config_overrides: Optional[Dict[str, Any]] = None,
    ):
        """Create an inference pipeline from the model artifact logged by a wandb run.

        Args:
            wandb_run: Run ID or display name.
            project: Wandb project name.
            entity: Wandb entity name.
            config_overrides: Optional dictionary overriding values of the run's
                config before it is parsed (forwarded to ``parse_config``).

        Returns:
            DiffusionInferencePipeline instance.

        Raises:
            ValueError: If the model parameters or the run config could not be loaded.
        """
        states, config = load_from_wandb_run(
            wandb_run,
            project=project,
            entity=entity,
        )
        return cls._from_loaded(states, config, config_overrides)

    @classmethod
    def from_wandb_registry(
        cls,
        modelname: str,
        project: str,
        entity: str = None,
        version: str = 'latest',
        registry: str = 'wandb-registry-model',
        config_overrides: Optional[Dict[str, Any]] = None,
    ):
        """Create an inference pipeline from a wandb model registry.

        Args:
            modelname: Model name in wandb registry.
            project: Wandb project name.
            entity: Wandb entity name.
            version: Version of the model to load (default is 'latest').
            registry: Registry name (default is 'wandb-registry-model').
            config_overrides: Optional dictionary overriding values of the logging
                run's config before it is parsed (forwarded to ``parse_config``).

        Returns:
            DiffusionInferencePipeline instance.

        Raises:
            ValueError: If the model parameters or the run config could not be loaded.
        """
        states, config = load_from_wandb_registry(
            modelname=modelname,
            project=project,
            entity=entity,
            version=version,
            registry=registry,
        )
        return cls._from_loaded(states, config, config_overrides)

    @classmethod
    def create(
        cls,
        config: Dict[str, Any],
        state: Dict[str, Any],
        best_state: Optional[Dict[str, Any]] = None,
        rngstate: Optional[RandomMarkovState] = None,
    ):
        """Build a pipeline from a parsed config and loaded states.

        Args:
            config: Parsed config; must provide the keys 'model', 'noise_schedule',
                'prediction_transform', 'autoencoder' and 'input_config'.
            state: Loaded training state (indexed by 'params' / 'ema_params'
                during sampling).
            best_state: Optional alternative state used with ``use_best_params``.
            rngstate: RNG state; defaults to one seeded with 42.

        Returns:
            DiffusionInferencePipeline instance.
        """
        if rngstate is None:
            rngstate = RandomMarkovState(jax.random.PRNGKey(42))
        return cls(
            model=config['model'],
            state=state,
            best_state=best_state,
            rngstate=rngstate,
            noise_schedule=config['noise_schedule'],
            model_output_transform=config['prediction_transform'],
            autoencoder=config['autoencoder'],
            input_config=config['input_config'],
            config=config,
        )

    def get_sampler(
        self,
        guidance_scale: float = 3.0,
        sampler_class=EulerAncestralSampler,
    ) -> DiffusionSampler:
        """Get (or create) a sampler for generating samples.

        Samplers are cached per sampler class and guidance scale, so repeated
        calls with the same arguments return the same instance.

        Args:
            guidance_scale: Classifier-free guidance scale (0.0 to disable).
            sampler_class: Class for the diffusion sampler.

        Returns:
            DiffusionSampler instance.
        """
        cache = self.samplers.setdefault(sampler_class, {})
        if guidance_scale not in cache:
            conditions = getattr(self.input_config, 'conditions', None)
            # Unconditional models have no condition to drop; indexing [0] on an
            # empty list would raise IndexError here.
            if guidance_scale > 0.0 and conditions:
                # The null embedding is only computed for this log line; it is
                # not passed to the sampler constructed below.
                null_embeddings = conditions[0].get_unconditional()
                shape = getattr(null_embeddings, 'shape', None)
                print(f"Created null embeddings for guidance with shape {shape}")

            cache[guidance_scale] = sampler_class(
                model=self.model,
                noise_schedule=self.noise_schedule,
                model_output_transform=self.model_output_transform,
                guidance_scale=guidance_scale,
                input_config=self.input_config,
                autoencoder=self.autoencoder,
            )

        return cache[guidance_scale]

    def generate_samples(
        self,
        num_samples: int,
        resolution: int,
        conditioning_data: Optional[List[Union[Tuple, Dict]]] = None,  # one list per modality or list of tuples
        sequence_length: Optional[int] = None,
        diffusion_steps: int = 50,
        guidance_scale: float = 1.0,
        sampler_class=EulerAncestralSampler,
        timestep_spacing: str = 'linear',
        seed: Optional[int] = None,
        start_step: Optional[int] = None,
        end_step: int = 0,
        steps_override=None,
        priors=None,
        use_best_params: bool = False,
        use_ema: bool = False,
    ):
        """Generate samples with the cached sampler for the given class and guidance.

        Args:
            num_samples: Number of samples to generate.
            resolution: Resolution forwarded to the sampler.
            conditioning_data: Conditioning inputs, forwarded as ``conditioning``.
            sequence_length: Optional sequence length forwarded to the sampler.
            diffusion_steps: Number of diffusion steps.
            guidance_scale: Classifier-free guidance scale.
            sampler_class: Sampler class to use.
            timestep_spacing: Assigned to the sampler when it has a
                ``timestep_spacing`` attribute (mutates the cached sampler).
            seed: If given, sampling uses a fresh RNG seeded with it; otherwise the
                pipeline's ``rngstate`` is used (or seed 0 when that is unset).
            start_step, end_step, steps_override, priors: Forwarded unchanged to
                ``sampler.generate_samples``.
            use_best_params: Use ``best_state`` instead of ``state``.
            use_ema: Use the state's 'ema_params' instead of 'params'.

        Returns:
            Whatever ``sampler.generate_samples`` returns.

        Raises:
            ValueError: If the selected state was never loaded.
        """
        # An explicit seed takes precedence over the stored rngstate; ``create``
        # always sets one, so falling back to it first would ignore ``seed``.
        if seed is not None:
            rngstate = RandomMarkovState(jax.random.PRNGKey(seed))
        elif self.rngstate is not None:
            rngstate = self.rngstate
        else:
            rngstate = RandomMarkovState(jax.random.PRNGKey(0))
        # NOTE(review): self.rngstate is never advanced here, so repeated unseeded
        # calls may reuse the same randomness — confirm whether the sampler
        # returns an updated RNG state that should be stored back.

        sampler = self.get_sampler(
            guidance_scale=guidance_scale,
            sampler_class=sampler_class,
        )
        if hasattr(sampler, 'timestep_spacing'):
            sampler.timestep_spacing = timestep_spacing
        print(f"Generating samples: steps={diffusion_steps}, num_samples={num_samples}, guidance={guidance_scale}")

        state = self.best_state if use_best_params else self.state
        if state is None:
            which = "best_state" if use_best_params else "state"
            raise ValueError(f"Pipeline has no {which} loaded; cannot generate samples.")

        params = state['ema_params'] if use_ema else state['params']

        return sampler.generate_samples(
            params=params,
            num_samples=num_samples,
            resolution=resolution,
            sequence_length=sequence_length,
            diffusion_steps=diffusion_steps,
            start_step=start_step,
            end_step=end_step,
            steps_override=steps_override,
            priors=priors,
            rngstate=rngstate,
            conditioning=conditioning_data
        )
|
@@ -0,0 +1,320 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import json
|
4
|
+
from flaxdiff.schedulers import (
|
5
|
+
CosineNoiseScheduler,
|
6
|
+
KarrasVENoiseScheduler,
|
7
|
+
)
|
8
|
+
from flaxdiff.predictors import (
|
9
|
+
VPredictionTransform,
|
10
|
+
KarrasPredictionTransform,
|
11
|
+
)
|
12
|
+
from flaxdiff.models.common import kernel_init
|
13
|
+
from flaxdiff.models.simple_unet import Unet
|
14
|
+
from flaxdiff.models.simple_vit import UViT
|
15
|
+
from flaxdiff.models.general import BCHWModelWrapper
|
16
|
+
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
|
17
|
+
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig
|
18
|
+
from flaxdiff.utils import defaultTextEncodeModel
|
19
|
+
from diffusers import FlaxUNet2DConditionModel
|
20
|
+
import wandb
|
21
|
+
from flaxdiff.models.simple_unet import Unet
|
22
|
+
from flaxdiff.models.simple_vit import UViT
|
23
|
+
from flaxdiff.models.general import BCHWModelWrapper
|
24
|
+
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
|
25
|
+
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig
|
26
|
+
from flaxdiff.utils import defaultTextEncodeModel
|
27
|
+
|
28
|
+
from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions, PyTreeCheckpointer
|
29
|
+
import os
|
30
|
+
|
31
|
+
import warnings
|
32
|
+
|
33
|
+
def get_wandb_run(wandb_run: str, project, entity):
    """
    Try to get the wandb run for the given experiment name and project.

    The lookup first treats ``wandb_run`` as a run ID. If wandb reports an error,
    it falls back to matching runs whose display name equals ``wandb_run``.

    Args:
        wandb_run: Run ID or display name.
        project: Wandb project name.
        entity: Wandb entity name. When falsy (e.g. ``None``), the entity segment
            is omitted from the path so wandb falls back to its default entity.

    Returns:
        The matching run object (the first match for a display name), or
        ``None`` if not found.
    """
    import wandb
    wandb_api = wandb.Api()
    # Avoid building a literal "None/<project>" path when no entity is given.
    project_path = f"{entity}/{project}" if entity else project

    # First try to get the run by treating wandb_run as a run ID
    try:
        run = wandb_api.run(f"{project_path}/{wandb_run}")
        print(f"Found run: {run.name} ({run.id})")
        return run
    except wandb.Error as e:
        print(f"Run not found by ID: {e}")

    # Fallback: look the run up by display name. Display names are not unique,
    # so the first run returned by the query wins.
    print(f"Trying to get run by display name: {wandb_run}")
    runs = wandb_api.runs(path=project_path, filters={"displayName": wandb_run})
    for run in runs:
        print(f"Found run: {run.name} ({run.id})")
        return run
    return None
|
56
|
+
|
57
|
+
def parse_config(config, overrides=None):
    """Parse configuration for inference pipeline.

    Args:
        config: Configuration dictionary from wandb run.
        overrides: Optional dictionary of overrides for config parameters. They are
            merged into ``config['arguments']`` (when that section exists) and
            replace top-level keys that already exist. The caller's ``config``
            object itself is not mutated.

    Returns:
        Dictionary with the keys 'model', 'model_config', 'architecture',
        'autoencoder', 'noise_schedule', 'prediction_transform', 'input_config'
        (a DiffusionInputConfig for the general diffusion framework) and
        'raw_config'.

    Raises:
        KeyError: If the config has no 'model' entry.
        ValueError: If the architecture or noise schedule is not supported.
    """
    # NOTE(review): this installs a process-wide "ignore" filter that outlives
    # this call; kept as-is for backward compatibility.
    warnings.filterwarnings("ignore")

    # Merge config with overrides if provided
    if overrides is not None:
        # Shallow copy: only top-level keys and the 'arguments' dict are replaced
        # (never mutated in place), so the caller's config stays untouched.
        merged_config = dict(config)
        if 'arguments' in merged_config:
            merged_config['arguments'] = {**(merged_config['arguments'] or {}), **overrides}
        # Also update top-level config for key parameters
        for key, value in overrides.items():
            if key in merged_config:
                merged_config[key] = value
    else:
        merged_config = config

    conf = merged_config
    # Older/partial configs may lack an 'arguments' section or store None there.
    arguments = conf.get('arguments') or {}

    # Setup mappings for dtype, precision, and activation
    DTYPE_MAP = {
        'bfloat16': jnp.bfloat16,
        'float32': jnp.float32,
        'jax.numpy.float32': jnp.float32,
        'jax.numpy.bfloat16': jnp.bfloat16,
        'None': None,
        None: None,
    }

    PRECISION_MAP = {
        'high': jax.lax.Precision.HIGH,
        'HIGH': jax.lax.Precision.HIGH,
        'default': jax.lax.Precision.DEFAULT,
        'DEFAULT': jax.lax.Precision.DEFAULT,
        'highest': jax.lax.Precision.HIGHEST,
        'HIGHEST': jax.lax.Precision.HIGHEST,
        'None': None,
        None: None,
    }

    ACTIVATION_MAP = {
        'swish': jax.nn.swish,
        'silu': jax.nn.silu,
        'jax._src.nn.functions.silu': jax.nn.silu,
        'mish': jax.nn.mish,
    }

    # Get model class based on architecture
    MODEL_CLASSES = {
        'unet': Unet,
        'uvit': UViT,
        'diffusers_unet_simple': FlaxUNet2DConditionModel
    }

    def map_nested_config(config):
        """Recursively convert serialized leaf strings back into Python/JAX objects.

        Strings naming a dtype, precision or activation are mapped through the
        tables above (the string 'None' maps to None). Any other string that
        contains a dot is dropped, so its key is omitted from the result. All
        remaining values are kept. Dicts inside lists are converted too; other
        list items are left untouched.
        """
        new_config = {}
        for key, value in config.items():
            if isinstance(value, dict):
                new_config[key] = map_nested_config(value)
            elif isinstance(value, list):
                new_config[key] = [map_nested_config(item) if isinstance(item, dict) else item for item in value]
            elif isinstance(value, str):
                if value in DTYPE_MAP:
                    new_config[key] = DTYPE_MAP[value]
                elif value in PRECISION_MAP:
                    new_config[key] = PRECISION_MAP[value]
                elif value in ACTIVATION_MAP:
                    new_config[key] = ACTIVATION_MAP[value]
                elif '.' in value:
                    # Ignore any other string that contains a dot
                    print(f"Ignoring key {key} with value {value} as it contains a dot.")
                else:
                    new_config[key] = value
            else:
                new_config[key] = value
        return new_config

    # Parse architecture and model config
    model_config = conf['model']

    # Get architecture type
    architecture = conf.get('architecture', arguments.get('architecture', 'unet'))

    # Handle autoencoder
    autoencoder_name = conf.get('autoencoder', arguments.get('autoencoder'))
    autoencoder_opts_str = conf.get('autoencoder_opts', arguments.get('autoencoder_opts', '{}'))
    autoencoder = None
    autoencoder_opts = None

    if autoencoder_name:
        print(f"Using autoencoder: {autoencoder_name}")
        if isinstance(autoencoder_opts_str, str):
            autoencoder_opts = json.loads(autoencoder_opts_str)
        else:
            autoencoder_opts = autoencoder_opts_str
        # A missing/None/"null" options entry means "no extra options".
        autoencoder_opts = autoencoder_opts or {}

        if autoencoder_name == 'stable_diffusion':
            print("Using Stable Diffusion Autoencoder for Latent Diffusion Modeling")
            autoencoder_opts = map_nested_config(autoencoder_opts)
            autoencoder = StableDiffusionVAE(**autoencoder_opts)
        else:
            print(f"Warning: unsupported autoencoder '{autoencoder_name}'; no autoencoder will be used.")

    input_config = conf.get('input_config', None)

    # If not provided, create one based on the older format (backward compatibility)
    if input_config is None:
        # Warn if input_config is not provided
        print("No input_config provided, creating a default one.")
        image_size = arguments.get('image_size', 128)
        image_channels = 3  # Default number of channels
        # Create text encoder
        text_encoder = defaultTextEncodeModel()
        # Create a conditional input config for text conditioning
        text_conditional_config = ConditionalInputConfig(
            encoder=text_encoder,
            conditioning_data_key='text',
            pretokenized=True,
            unconditional_input="",
            model_key_override="textcontext"
        )

        # Create the main input config
        input_config = DiffusionInputConfig(
            sample_data_key='image',
            sample_data_shape=(image_size, image_size, image_channels),
            conditions=[text_conditional_config]
        )
    else:
        # Rebuild the DiffusionInputConfig from its stored (serialized) form
        input_config = DiffusionInputConfig.deserialize(input_config)

    model_kwargs = map_nested_config(model_config)

    print(f"Model kwargs after mapping: {model_kwargs}")

    model_class = MODEL_CLASSES.get(architecture)
    if not model_class:
        raise ValueError(f"Unknown architecture: {architecture}. Supported architectures: {', '.join(MODEL_CLASSES.keys())}")

    # Instantiate the model
    model = model_class(**model_kwargs)

    # If using diffusers UNet, wrap it for consistent interface
    if 'diffusers' in architecture:
        model = BCHWModelWrapper(model)

    # Create noise scheduler based on configuration
    noise_schedule_type = conf.get('noise_schedule', arguments.get('noise_schedule', 'edm'))
    if noise_schedule_type in ['edm', 'karras']:
        # For both EDM and karras, we use the karras scheduler for inference
        noise_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
        prediction_transform = KarrasPredictionTransform(sigma_data=noise_schedule.sigma_data)
    elif noise_schedule_type == 'cosine':
        noise_schedule = CosineNoiseScheduler(1000, beta_end=1)
        prediction_transform = VPredictionTransform()
    else:
        raise ValueError(f"Unknown noise schedule: {noise_schedule_type}")

    # Prepare return dictionary with all components
    result = {
        'model': model,
        'model_config': model_kwargs,
        'architecture': architecture,
        'autoencoder': autoencoder,
        'noise_schedule': noise_schedule,
        'prediction_transform': prediction_transform,
        'input_config': input_config,
        'raw_config': conf,
    }

    return result
|
241
|
+
|
242
|
+
def load_from_checkpoint(
    checkpoint_dir: str,
):
    """Restore ``state`` and ``best_state`` from a local orbax checkpoint directory.

    Args:
        checkpoint_dir: Path to the checkpoint directory; it is made absolute
            before use.

    Returns:
        A ``(state, best_state)`` pair. Each entry is ``None`` when the restored
        checkpoint lacks that key. Any exception raised while restoring is
        reported as a warning, and ``(None, None)`` is returned instead.
    """
    try:
        checkpointer = PyTreeCheckpointer()
        manager_options = CheckpointManagerOptions(create=False)
        absolute_dir = os.path.abspath(checkpoint_dir)
        manager = CheckpointManager(absolute_dir, checkpointer, manager_options)
        # NOTE(review): restore() receives the directory path where a step number
        # is normally expected. This presumably relies on orbax joining the
        # absolute "step" onto the manager directory, which resolves to the
        # directory itself — confirm against the pinned orbax version before
        # changing.
        restored = manager.restore(absolute_dir)
        state = restored['state'] if 'state' in restored else None
        best_state = restored['best_state'] if 'best_state' in restored else None
        print(f"Loaded checkpoint from local dir {absolute_dir}")
        return state, best_state
    except Exception as e:
        # Best effort: report the problem and signal failure with a pair of Nones.
        print(f"Warning: Failed to load checkpoint from local dir: {e}")
        return None, None
|
263
|
+
|
264
|
+
def load_from_wandb_run(
    run,
    project: str,
    entity: str = None,
):
    """
    Loads the highest-versioned model artifact logged by a wandb run.

    Args:
        run: A wandb run object, or a run ID / display name that is resolved
            with ``get_wandb_run``.
        project: Wandb project name (used when ``run`` is a string).
        entity: Wandb entity name (used when ``run`` is a string).

    Returns:
        ``(states, config)``. ``states`` is the ``(state, best_state)`` pair from
        ``load_from_checkpoint``, or ``None`` if the run, artifact or checkpoint
        could not be loaded. ``config`` is the run's config, or ``None`` if it was
        not reached. Errors are reported as warnings rather than raised.
    """
    states = None
    config = None
    try:
        if isinstance(run, str):
            resolved = get_wandb_run(run, project, entity)
            if resolved is None:
                raise ValueError(f"Could not find wandb run '{run}' in project '{project}'")
            run = resolved
        # Search for model artifacts
        models = [artifact for artifact in run.logged_artifacts() if artifact.type == 'model']
        if not models:
            raise ValueError(f"No model artifacts found in run {run.id}")
        # Versions are parsed by dropping their first character (presumably of
        # the form 'v<N>'); the highest number wins.
        newest = max(models, key=lambda artifact: int(artifact.version[1:]))
        wandb_modelname = newest.qualified_name

        print(f"Loading model from wandb: {wandb_modelname} out of versions {[i.version for i in models]}")
        artifact = run.use_artifact(wandb.Api().artifact(wandb_modelname))
        ckpt_dir = artifact.download()
        print(f"Loaded model from wandb: {wandb_modelname} at path {ckpt_dir}")
        # Load the model from the checkpoint directory
        loaded = load_from_checkpoint(ckpt_dir)
        config = run.config
        # load_from_checkpoint signals failure with (None, None); surface that as
        # None so callers' ``states is None`` checks catch it.
        if all(s is None for s in loaded):
            print(f"Warning: Checkpoint at {ckpt_dir} contained no model state.")
        else:
            states = loaded
    except Exception as e:
        print(f"Warning: Failed to load model from wandb: {e}")
    return states, config
|
296
|
+
|
297
|
+
def load_from_wandb_registry(
    modelname: str,
    project: str,
    entity: str = None,
    version: str = 'latest',
    registry: str = 'wandb-registry-model',
):
    """
    Loads a model checkpoint and its training config from the wandb model registry.

    The artifact ``{registry}/{modelname}:{version}`` is downloaded and restored
    with ``load_from_checkpoint``. The config is taken from the run that logged
    the artifact.

    Args:
        modelname: Model name in the registry.
        project: Wandb project name. NOTE(review): not used by the lookup below.
        entity: Wandb entity name. NOTE(review): not used by the lookup below.
        version: Artifact version or alias (default 'latest').
        registry: Registry name (default 'wandb-registry-model').

    Returns:
        ``(states, config)``. ``states`` is the ``(state, best_state)`` pair, or
        ``None`` if nothing could be restored. ``config`` is the logging run's
        config, or ``None`` if it could not be retrieved. Errors are reported as
        warnings rather than raised.
    """
    states = None
    config = None
    try:
        artifact = wandb.Api().artifact(f"{registry}/{modelname}:{version}")
        ckpt_dir = artifact.download()
        print(f"Loaded model from wandb registry: {modelname} at path {ckpt_dir}")
        # Load the model from the checkpoint directory
        loaded = load_from_checkpoint(ckpt_dir)
        # load_from_checkpoint signals failure with (None, None); surface that as
        # None so callers' ``states is None`` checks catch it.
        if all(s is None for s in loaded):
            print(f"Warning: Checkpoint at {ckpt_dir} contained no model state.")
        else:
            states = loaded
        run = artifact.logged_by()
        if run is None:
            print(f"Warning: Could not determine the run that logged {modelname}; config unavailable.")
        else:
            config = run.config
    except Exception as e:
        print(f"Warning: Failed to load model from wandb: {e}")
    return states, config
|