flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,260 @@
1
+ import jax
2
+ import flax.linen as nn
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional, Dict, Any, Union, List, Tuple, Type
5
+
6
+ from flaxdiff.trainer import (
7
+ SimpleTrainState,
8
+ TrainState,
9
+ )
10
+ from flaxdiff.samplers import (
11
+ DiffusionSampler,
12
+ )
13
+ from flaxdiff.schedulers import (
14
+ NoiseScheduler,
15
+ )
16
+ from flaxdiff.predictors import (
17
+ DiffusionPredictionTransform,
18
+ )
19
+ from flaxdiff.models.autoencoder import AutoEncoder
20
+ from flaxdiff.inputs import DiffusionInputConfig
21
+ from flaxdiff.utils import defaultTextEncodeModel, RandomMarkovState
22
+ from flaxdiff.samplers.euler import EulerAncestralSampler
23
+ from .utils import parse_config, load_from_wandb_run, load_from_wandb_registry
24
+
25
@dataclass
class InferencePipeline:
    """Base container holding a model together with its trained parameter states.

    Subclasses are expected to add model-specific loading and sampling logic.
    """
    # Flax module whose parameters are stored in `state` / `best_state`.
    model: Optional[nn.Module] = None
    # Most recent training state, and the best-scoring one if it was saved.
    state: Optional[SimpleTrainState] = None
    best_state: Optional[SimpleTrainState] = None

    def from_wandb(self, wandb_run: str, wandb_project: str, wandb_entity: str):
        """Load the pipeline from a wandb run (not supported by this base class).

        Raises:
            NotImplementedError: always.
        """
        message = "InferencePipeline does not support from_wandb."
        raise NotImplementedError(message)
39
+
40
@dataclass
class DiffusionInferencePipeline(InferencePipeline):
    """Inference pipeline for diffusion models.

    This pipeline handles loading models from wandb and generating samples using
    the DiffusionSampler from FlaxDiff.
    """
    state: TrainState = None
    best_state: TrainState = None
    rngstate: Optional[RandomMarkovState] = None
    noise_schedule: NoiseScheduler = None
    model_output_transform: DiffusionPredictionTransform = None
    autoencoder: AutoEncoder = None
    input_config: DiffusionInputConfig = None
    # Sampler cache: sampler class -> guidance scale -> sampler instance.
    samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
    # Parsed configuration dictionary, as produced by `parse_config`.
    config: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _parse_loaded(states, config):
        """Validate a `(states, config)` pair returned by a wandb loader.

        The loaders report failure by returning `None` for the config and
        `None` or `(None, None)` for the states instead of raising, so both
        cases are checked explicitly here.

        Returns:
            `(state, best_state, parsed_config)`.

        Raises:
            ValueError: if no parameter state or no config could be loaded.
        """
        if states is None or all(s is None for s in states):
            raise ValueError("Failed to load model parameters from wandb.")
        if config is None:
            raise ValueError("Failed to load model config from wandb.")
        state, best_state = states
        return state, best_state, parse_config(config)

    @classmethod
    def from_wandb_run(
        cls,
        wandb_run: str,
        project: str,
        entity: str,
    ):
        """Create an inference pipeline from a wandb run.

        The newest model artifact logged by the run is downloaded, and the
        run's config is used to rebuild the model.

        Args:
            wandb_run: Run ID or display name.
            project: Wandb project name.
            entity: Wandb entity name.

        Returns:
            DiffusionInferencePipeline instance, seeded with PRNGKey(42).

        Raises:
            ValueError: if parameters or config could not be loaded.
        """
        states, config = load_from_wandb_run(
            wandb_run,
            project=project,
            entity=entity,
        )
        state, best_state, parsed_config = cls._parse_loaded(states, config)
        return cls.create(
            config=parsed_config,
            state=state,
            best_state=best_state,
            rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
        )

    @classmethod
    def from_wandb_registry(
        cls,
        modelname: str,
        project: str,
        entity: str = None,
        version: str = 'latest',
        registry: str = 'wandb-registry-model',
    ):
        """Create an inference pipeline from a wandb model registry.

        Args:
            modelname: Model name in wandb registry.
            project: Wandb project name.
            entity: Wandb entity name.
            version: Version of the model to load (default is 'latest').
            registry: Registry name (default is 'wandb-registry-model').

        Returns:
            DiffusionInferencePipeline instance, seeded with PRNGKey(42).

        Raises:
            ValueError: if parameters or config could not be loaded.
        """
        states, config = load_from_wandb_registry(
            modelname=modelname,
            project=project,
            entity=entity,
            version=version,
            registry=registry,
        )
        state, best_state, parsed_config = cls._parse_loaded(states, config)
        return cls.create(
            config=parsed_config,
            state=state,
            best_state=best_state,
            rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
        )

    @classmethod
    def create(
        cls,
        config: Dict[str, Any],
        state: Dict[str, Any],
        best_state: Optional[Dict[str, Any]] = None,
        rngstate: Optional[RandomMarkovState] = None,
    ):
        """Build a pipeline from a parsed config and loaded parameter states.

        Args:
            config: Parsed config; must contain the keys 'model',
                'noise_schedule', 'prediction_transform', 'autoencoder' and
                'input_config'.
            state: Training state; its 'params' / 'ema_params' entries are
                read by `generate_samples`.
            best_state: Optional best-scoring training state.
            rngstate: RNG state; defaults to PRNGKey(42).
        """
        if rngstate is None:
            rngstate = RandomMarkovState(jax.random.PRNGKey(42))
        return cls(
            model=config['model'],
            state=state,
            best_state=best_state,
            rngstate=rngstate,
            noise_schedule=config['noise_schedule'],
            model_output_transform=config['prediction_transform'],
            autoencoder=config['autoencoder'],
            input_config=config['input_config'],
            config=config,
        )

    def get_sampler(
        self,
        guidance_scale: float = 3.0,
        sampler_class=EulerAncestralSampler,
    ) -> DiffusionSampler:
        """Get (or create) a sampler for generating samples.

        Samplers are cached by their class and guidance scale for reuse.

        Args:
            guidance_scale: Classifier-free guidance scale (0.0 to disable).
            sampler_class: Class for the diffusion sampler.

        Returns:
            DiffusionSampler instance.
        """
        cache = self.samplers.setdefault(sampler_class, {})

        if guidance_scale not in cache:
            if guidance_scale > 0.0:
                # Unconditional models have no conditions to build a null
                # input from; indexing [0] would raise IndexError.
                conditions = getattr(self.input_config, 'conditions', None) or []
                if conditions:
                    null_embeddings = conditions[0].get_unconditional()
                    print(f"Created null embeddings for guidance with shape {getattr(null_embeddings, 'shape', None)}")

            cache[guidance_scale] = sampler_class(
                model=self.model,
                noise_schedule=self.noise_schedule,
                model_output_transform=self.model_output_transform,
                guidance_scale=guidance_scale,
                input_config=self.input_config,
                autoencoder=self.autoencoder,
            )

        return cache[guidance_scale]

    def generate_samples(
        self,
        num_samples: int,
        resolution: int,
        conditioning_data: Optional[List[Union[Tuple, Dict]]] = None,  # one list per modality or list of tuples
        sequence_length: Optional[int] = None,
        diffusion_steps: int = 50,
        guidance_scale: float = 1.0,
        sampler_class=EulerAncestralSampler,
        timestep_spacing: str = 'linear',
        seed: Optional[int] = None,
        start_step: Optional[int] = None,
        end_step: int = 0,
        steps_override=None,
        priors=None,
        use_best_params: bool = False,
        use_ema: bool = False,
    ):
        """Generate samples with a (cached) sampler.

        Args:
            num_samples, resolution, sequence_length, diffusion_steps,
            start_step, end_step, steps_override, priors: Forwarded to
                `sampler.generate_samples`.
            conditioning_data: Forwarded to the sampler as `conditioning`.
            guidance_scale: Classifier-free guidance scale.
            sampler_class: Sampler class to use.
            timestep_spacing: Set on the sampler if it has that attribute.
            seed: If given, sampling uses `PRNGKey(seed)`. Otherwise the
                pipeline's `rngstate` is used, or `PRNGKey(0)` if that is unset.
            use_best_params: Use `best_state` instead of `state`.
            use_ema: Use the state's 'ema_params' instead of 'params'.

        Raises:
            ValueError: if the selected state was never loaded.
        """
        # An explicit seed must take precedence; `create` always sets
        # self.rngstate, so otherwise `seed` would be silently ignored.
        # NOTE(review): self.rngstate is not advanced here, so repeated calls
        # without a seed reuse the same RNG state. Confirm this is intended.
        if seed is not None:
            rngstate = RandomMarkovState(jax.random.PRNGKey(seed))
        elif self.rngstate is not None:
            rngstate = self.rngstate
        else:
            rngstate = RandomMarkovState(jax.random.PRNGKey(0))

        sampler = self.get_sampler(
            guidance_scale=guidance_scale,
            sampler_class=sampler_class,
        )
        if hasattr(sampler, 'timestep_spacing'):
            sampler.timestep_spacing = timestep_spacing
        print(f"Generating samples: steps={diffusion_steps}, num_samples={num_samples}, guidance={guidance_scale}")

        state = self.best_state if use_best_params else self.state
        if state is None:
            which = 'best_state' if use_best_params else 'state'
            raise ValueError(f"No {which} loaded in this pipeline; cannot generate samples.")

        params = state['ema_params'] if use_ema else state['params']

        return sampler.generate_samples(
            params=params,
            num_samples=num_samples,
            resolution=resolution,
            sequence_length=sequence_length,
            diffusion_steps=diffusion_steps,
            start_step=start_step,
            end_step=end_step,
            steps_override=steps_override,
            priors=priors,
            rngstate=rngstate,
            conditioning=conditioning_data,
        )
@@ -0,0 +1,320 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import json
4
+ from flaxdiff.schedulers import (
5
+ CosineNoiseScheduler,
6
+ KarrasVENoiseScheduler,
7
+ )
8
+ from flaxdiff.predictors import (
9
+ VPredictionTransform,
10
+ KarrasPredictionTransform,
11
+ )
12
+ from flaxdiff.models.common import kernel_init
13
+ from flaxdiff.models.simple_unet import Unet
14
+ from flaxdiff.models.simple_vit import UViT
15
+ from flaxdiff.models.general import BCHWModelWrapper
16
+ from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
17
+ from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig
18
+ from flaxdiff.utils import defaultTextEncodeModel
19
+ from diffusers import FlaxUNet2DConditionModel
20
+ import wandb
21
+ from flaxdiff.models.simple_unet import Unet
22
+ from flaxdiff.models.simple_vit import UViT
23
+ from flaxdiff.models.general import BCHWModelWrapper
24
+ from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
25
+ from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig
26
+ from flaxdiff.utils import defaultTextEncodeModel
27
+
28
+ from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions, PyTreeCheckpointer
29
+ import os
30
+
31
+ import warnings
32
+
33
def get_wandb_run(wandb_run: str, project, entity):
    """Look up a wandb run by ID, falling back to a lookup by display name.

    Args:
        wandb_run: Run ID or display name.
        project: Wandb project name.
        entity: Wandb entity name. If falsy, the path is built without it so
            that wandb resolves its default entity.

    Returns:
        The first matching public-API run, or None if none is found. When
        several runs share a display name, the first one returned by the API
        is used.
    """
    import wandb
    wandb_api = wandb.Api()
    # Without this guard a missing entity would be formatted as the literal
    # path segment "None".
    project_path = f"{entity}/{project}" if entity else project

    # First, treat wandb_run as a run ID.
    try:
        run = wandb_api.run(f"{project_path}/{wandb_run}")
        print(f"Found run: {run.name} ({run.id})")
        return run
    except wandb.Error as e:
        print(f"Run not found by ID: {e}")

    # Fall back to treating wandb_run as a display name (may match several runs).
    print(f"Trying to get run by display name: {wandb_run}")
    runs = wandb_api.runs(path=project_path, filters={"displayName": wandb_run})
    run = next(iter(runs), None)
    if run is not None:
        print(f"Found run: {run.name} ({run.id})")
    return run
56
+
57
def parse_config(config, overrides=None):
    """Parse a training-run configuration into inference components.

    Args:
        config: Configuration dictionary from a wandb run. It must contain a
            'model' entry. Other settings are read from the top level first and
            then from config['arguments'].
        overrides: Optional dictionary of overrides. Each override is merged
            into config['arguments'] (when present) and replaces any top-level
            key of the same name. `config` itself is not mutated.

    Returns:
        Dictionary with keys 'model', 'model_config', 'architecture',
        'autoencoder', 'noise_schedule', 'prediction_transform', 'input_config'
        and 'raw_config'.

    Raises:
        KeyError: if `config` has no 'model' entry.
        ValueError: for an unknown architecture or noise schedule.
    """
    # NOTE(review): this silences *all* warnings for the rest of the process,
    # not only during parsing. It is kept for backward compatibility.
    warnings.filterwarnings("ignore")

    # Shallow copy only. The nested 'arguments' dict is rebuilt rather than
    # updated in place, so the caller's config is never modified.
    if overrides is not None:
        conf = dict(config)
        if 'arguments' in conf:
            conf['arguments'] = {**conf['arguments'], **overrides}
        for key, value in overrides.items():
            if key in conf:
                conf[key] = value
    else:
        conf = config

    # Fallback source for settings missing at the top level (older configs).
    arguments = conf.get('arguments') or {}

    DTYPE_MAP = {
        'bfloat16': jnp.bfloat16,
        'float32': jnp.float32,
        'jax.numpy.float32': jnp.float32,
        'jax.numpy.bfloat16': jnp.bfloat16,
        'None': None,
        None: None,
    }

    PRECISION_MAP = {
        'high': jax.lax.Precision.HIGH,
        'HIGH': jax.lax.Precision.HIGH,
        'default': jax.lax.Precision.DEFAULT,
        'DEFAULT': jax.lax.Precision.DEFAULT,
        'highest': jax.lax.Precision.HIGHEST,
        'HIGHEST': jax.lax.Precision.HIGHEST,
        'None': None,
        None: None,
    }

    ACTIVATION_MAP = {
        'swish': jax.nn.swish,
        'silu': jax.nn.silu,
        'jax._src.nn.functions.silu': jax.nn.silu,
        'mish': jax.nn.mish,
    }

    MODEL_CLASSES = {
        'unet': Unet,
        'uvit': UViT,
        'diffusers_unet_simple': FlaxUNet2DConditionModel
    }

    def map_nested_config(cfg):
        """Recursively convert serialized string leaves to Python/JAX objects.

        Known dtype, precision and activation names are mapped; the string
        'None' becomes None (via DTYPE_MAP). Any other string containing a dot
        (e.g. a serialized object path) is dropped from the result.
        """
        new_config = {}
        for key, value in cfg.items():
            if isinstance(value, dict):
                new_config[key] = map_nested_config(value)
            elif isinstance(value, list):
                new_config[key] = [map_nested_config(item) if isinstance(item, dict) else item for item in value]
            elif isinstance(value, str):
                if value in DTYPE_MAP:
                    new_config[key] = DTYPE_MAP[value]
                elif value in PRECISION_MAP:
                    new_config[key] = PRECISION_MAP[value]
                elif value in ACTIVATION_MAP:
                    new_config[key] = ACTIVATION_MAP[value]
                elif '.' in value:
                    print(f"Ignoring key {key} with value {value} as it contains a dot.")
                else:
                    new_config[key] = value
            else:
                new_config[key] = value
        return new_config

    model_config = conf['model']

    architecture = conf.get('architecture', arguments.get('architecture', 'unet'))

    # Autoencoder (latent diffusion), optional.
    autoencoder_name = conf.get('autoencoder', arguments.get('autoencoder'))
    autoencoder_opts_raw = conf.get('autoencoder_opts', arguments.get('autoencoder_opts', '{}'))
    autoencoder = None

    if autoencoder_name:
        print(f"Using autoencoder: {autoencoder_name}")
        if isinstance(autoencoder_opts_raw, str):
            autoencoder_opts = json.loads(autoencoder_opts_raw) if autoencoder_opts_raw.strip() else {}
        else:
            autoencoder_opts = autoencoder_opts_raw
        # A null value stored in the config means "no options".
        autoencoder_opts = autoencoder_opts or {}

        if autoencoder_name == 'stable_diffusion':
            print("Using Stable Diffusion Autoencoder for Latent Diffusion Modeling")
            autoencoder_opts = map_nested_config(autoencoder_opts)
            autoencoder = StableDiffusionVAE(**autoencoder_opts)
        else:
            print(f"Warning: Unknown autoencoder '{autoencoder_name}'; no autoencoder will be used.")

    input_config = conf.get('input_config', None)

    if input_config is None:
        # Older configs do not store an input_config; rebuild the default
        # text-conditioned image setup for backward compatibility.
        print("No input_config provided, creating a default one.")
        image_size = arguments.get('image_size', 128)
        image_channels = 3
        text_encoder = defaultTextEncodeModel()
        text_conditional_config = ConditionalInputConfig(
            encoder=text_encoder,
            conditioning_data_key='text',
            pretokenized=True,
            unconditional_input="",
            model_key_override="textcontext"
        )
        input_config = DiffusionInputConfig(
            sample_data_key='image',
            sample_data_shape=(image_size, image_size, image_channels),
            conditions=[text_conditional_config]
        )
    else:
        input_config = DiffusionInputConfig.deserialize(input_config)

    model_kwargs = map_nested_config(model_config)

    print(f"Model kwargs after mapping: {model_kwargs}")

    model_class = MODEL_CLASSES.get(architecture)
    if not model_class:
        raise ValueError(f"Unknown architecture: {architecture}. Supported architectures: {', '.join(MODEL_CLASSES.keys())}")

    model = model_class(**model_kwargs)

    # Diffusers UNets are wrapped so they expose the same interface as the
    # native models.
    if 'diffusers' in architecture:
        model = BCHWModelWrapper(model)

    noise_schedule_type = conf.get('noise_schedule', arguments.get('noise_schedule', 'edm'))
    if noise_schedule_type in ['edm', 'karras']:
        # Both EDM- and karras-trained models are sampled with the Karras VE scheduler.
        noise_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
        prediction_transform = KarrasPredictionTransform(sigma_data=noise_schedule.sigma_data)
    elif noise_schedule_type == 'cosine':
        noise_schedule = CosineNoiseScheduler(1000, beta_end=1)
        prediction_transform = VPredictionTransform()
    else:
        raise ValueError(f"Unknown noise schedule: {noise_schedule_type}")

    return {
        'model': model,
        'model_config': model_kwargs,
        'architecture': architecture,
        'autoencoder': autoencoder,
        'noise_schedule': noise_schedule,
        'prediction_transform': prediction_transform,
        'input_config': input_config,
        'raw_config': conf,
    }
241
+
242
def load_from_checkpoint(
    checkpoint_dir: str,
):
    """Restore `(state, best_state)` from a local Orbax checkpoint directory.

    Each element is None when the restored checkpoint has no 'state' /
    'best_state' entry. Failures are best-effort: any exception is printed as
    a warning and `(None, None)` is returned.
    """
    try:
        checkpointer, options = PyTreeCheckpointer(), CheckpointManagerOptions(create=False)
        checkpoint_dir = os.path.abspath(checkpoint_dir)
        ckpt_manager = CheckpointManager(checkpoint_dir, checkpointer, options)
        # NOTE(review): the absolute directory itself is passed as the
        # restore "step". This appears to rely on path joining, where an
        # absolute component replaces the base, to restore from
        # checkpoint_dir directly. Confirm against the orbax version in use.
        restored = ckpt_manager.restore(checkpoint_dir)
        state = restored['state'] if 'state' in restored else None
        best_state = restored['best_state'] if 'best_state' in restored else None
        print(f"Loaded checkpoint from local dir {checkpoint_dir}")
        return state, best_state
    except Exception as e:
        print(f"Warning: Failed to load checkpoint from local dir: {e}")
        return None, None
263
+
264
def load_from_wandb_run(
    run,
    project: str,
    entity: str = None,
):
    """Load the newest model checkpoint logged by a wandb run, plus its config.

    Args:
        run: A wandb public-API run object, or a run ID / display name string
            resolved with `get_wandb_run`.
        project: Wandb project name (only used when `run` is a string).
        entity: Wandb entity name (only used when `run` is a string).

    Returns:
        `(states, config)`. `states` is the `(state, best_state)` tuple from
        `load_from_checkpoint` and `config` is the run's config. Both are None
        if loading fails before they are set; failures are printed as warnings
        rather than raised.
    """
    states = None
    config = None
    try:
        if isinstance(run, str):
            resolved = get_wandb_run(run, project, entity)
            if resolved is None:
                raise ValueError(f"Could not find wandb run '{run}' in project '{project}'")
            run = resolved

        models = [a for a in run.logged_artifacts() if a.type == 'model']
        if not models:
            raise ValueError(f"No model artifacts found in run {run.id}")

        # Artifact versions look like "v<N>"; choose the highest N.
        artifact = max(models, key=lambda a: int(a.version[1:]))
        wandb_modelname = artifact.qualified_name

        print(f"Loading model from wandb: {wandb_modelname} out of versions {[a.version for a in models]}")
        # Download the already-fetched artifact directly. Calling
        # run.use_artifact() here would record the artifact as an *input* of
        # the run that produced it, and would require write access to the project.
        ckpt_dir = artifact.download()
        print(f"Loaded model from wandb: {wandb_modelname} at path {ckpt_dir}")
        states = load_from_checkpoint(ckpt_dir)
        config = run.config
    except Exception as e:
        print(f"Warning: Failed to load model from wandb: {e}")
    return states, config
296
+
297
def load_from_wandb_registry(
    modelname: str,
    project: str,
    entity: str = None,
    version: str = 'latest',
    registry: str = 'wandb-registry-model',
):
    """Load a model checkpoint and its training config from a wandb registry.

    The artifact "{registry}/{modelname}:{version}" is downloaded and restored
    with `load_from_checkpoint`. The config comes from the run that logged the
    artifact.

    NOTE(review): `project` and `entity` are accepted but not used to build
    the artifact path.

    Returns:
        `(states, config)`. Whatever was obtained before a failure is
        returned, and the failure is printed as a warning rather than raised.
    """
    states, config = None, None
    try:
        registry_artifact = wandb.Api().artifact(f"{registry}/{modelname}:{version}")
        download_path = registry_artifact.download()
        print(f"Loaded model from wandb registry: {modelname} at path {download_path}")
        states = load_from_checkpoint(download_path)
        config = registry_artifact.logged_by().config
    except Exception as e:
        print(f"Warning: Failed to load model from wandb: {e}")
    return states, config