flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,262 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from flax import linen as nn
4
+ from typing import Any, Optional
5
+ import einops
6
+ from flax.typing import Dtype, PrecisionLike
7
+
8
+ from .attention import NormalAttention
9
+
10
def unpatchify(x, channels=3):
    """Reassemble a sequence of flattened square patches into an image grid.

    Args:
        x: Array whose axis 1 indexes patches (a square ``h * w`` grid with
            ``h == w``) and whose axis 2 holds one flattened patch laid out
            as ``(p1 p2 C)``.
        channels: Number of channels ``C`` stored in every patch.

    Returns:
        The array rearranged as ``B (h p1) (w p2) C``.

    Raises:
        AssertionError: If axis 1 is not a perfect square, or axis 2 is not
            ``patch_size ** 2 * channels`` for an integer ``patch_size``.
    """
    # Local import: the module header does not import ``math``.
    import math

    # Exact integer square roots; avoids float rounding in ``n ** 0.5``.
    patch_size = math.isqrt(x.shape[2] // channels)
    h = w = math.isqrt(x.shape[1])

    grid_ok = h * w == x.shape[1]
    patch_ok = patch_size ** 2 * channels == x.shape[2]
    if not (grid_ok and patch_ok):
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O`` while callers see the same exception type.
        raise AssertionError(
            f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}")

    x = einops.rearrange(
        x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
    return x
18
+
19
+
20
class PatchEmbedding(nn.Module):
    """Project non-overlapping image patches to embedding vectors.

    A convolution whose kernel size and stride both equal ``patch_size`` is
    applied to the input, and the result is flattened to
    ``(batch, -1, embedding_dim)``.
    """
    # Side length of each square patch (used as conv kernel size and stride).
    patch_size: int
    # Number of output features produced for every patch.
    embedding_dim: int
    # dtype forwarded to the convolution.
    dtype: Any = jnp.float32
    # Precision forwarded to the convolution.
    precision: Any = jax.lax.Precision.HIGH

    @nn.compact
    def __call__(self, x):
        """Embed ``x`` (unpacked as batch, height, width, channels) into tokens."""
        batch, height, width, _ = x.shape
        side = self.patch_size
        fits_grid = height % side == 0 and width % side == 0
        assert fits_grid, "Image dimensions must be divisible by patch size"

        window = (side, side)
        projection = nn.Conv(
            features=self.embedding_dim,
            kernel_size=window,
            strides=window,
            dtype=self.dtype,
            precision=self.precision,
        )
        tokens = projection(x)
        # Collapse the spatial grid into a single sequence axis.
        return jnp.reshape(tokens, (batch, -1, self.embedding_dim))
38
+
39
+
40
class PositionalEncoding(nn.Module):
    """Add a learned positional table to a sequence.

    The table is a parameter named ``pos_encoding`` of shape
    ``(1, max_len, embedding_dim)`` initialised to zeros.
    """
    # Number of positions stored in the table.
    max_len: int
    # Feature size of each position vector.
    embedding_dim: int

    @nn.compact
    def __call__(self, x):
        """Return ``x`` plus the first ``x.shape[1]`` rows of the table."""
        table_shape = (1, self.max_len, self.embedding_dim)
        table = self.param('pos_encoding', jax.nn.initializers.zeros, table_shape)
        seq_len = x.shape[1]
        return x + table[:, :seq_len, :]
50
+
51
+
52
+ # --- Rotary Positional Embedding (RoPE) ---
53
+ # Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
54
+
55
+
56
def _rotate_half(x: jax.Array) -> jax.Array:
    """Split the last axis at its midpoint and return ``(-second, first)`` joined."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return jnp.concatenate([-back, front], axis=-1)
61
+
62
def apply_rotary_embedding(
    x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
) -> jax.Array:
    """Apply rotary embedding to ``x`` using the rotate-half formulation.

    Args:
        x: Array of shape ``[..., Sequence, Dimension]``, e.g. ``[B, H, S, D]``
            or ``[B, S, D]``. Any number of leading axes is accepted.
        freqs_cos: Cosine table of shape ``[Sequence, Dimension / 2]``.
        freqs_sin: Sine table of shape ``[Sequence, Dimension / 2]``.

    Returns:
        ``x * cos + rotate_half(x) * sin`` cast back to ``x.dtype``.
    """
    # Duplicate the tables along the feature axis: [S, D/2] -> [S, D].
    cos_freqs = jnp.concatenate([freqs_cos, freqs_cos], axis=-1)
    sin_freqs = jnp.concatenate([freqs_sin, freqs_sin], axis=-1)

    # The [S, D] tables align with the trailing axes of ``x``, so broadcasting
    # supplies the leading axes for every rank. The previous rank-specific
    # ``expand_dims`` left the tables unbound for ranks other than 3 and 4.
    x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
    return x_rotated.astype(x.dtype)
85
+
86
def _rope_cos_sin(dim, base, seq_len):
    """Return the (cos, sin) tables for positions ``0 .. seq_len - 1``."""
    exponents = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
    inv_freq = 1.0 / (base ** exponents)
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)
    return jnp.cos(angles), jnp.sin(angles)


class RotaryEmbedding(nn.Module):
    """Provide cosine/sine tables for rotary positional embedding.

    Tables for ``max_seq_len`` positions are built in ``setup``; longer
    requests are computed on the fly and are not cached.
    """
    # Feature dimension; the tables have one column per even index below it.
    dim: int
    # Number of positions precomputed in ``setup``.
    max_seq_len: int = 4096
    # Base of the geometric frequency progression.
    base: int = 10000
    # Declared for configuration; not read by the code in this class.
    dtype: Dtype = jnp.float32

    def setup(self):
        """Precompute the tables for ``max_seq_len`` positions."""
        self.freqs_cos, self.freqs_sin = _rope_cos_sin(
            self.dim, self.base, self.max_seq_len)

    def __call__(self, seq_len: int):
        """Return ``(freqs_cos, freqs_sin)`` covering ``seq_len`` positions."""
        if seq_len <= self.max_seq_len:
            return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
        # Longer than the cached tables: build fresh ones for this request.
        return _rope_cos_sin(self.dim, self.base, seq_len)
118
+
119
+
120
+ # --- Attention with RoPE ---
121
+
122
+
123
class RoPEAttention(NormalAttention):
    """Attention layer that rotates queries and keys with RoPE before attending.

    The projections ``self.query``, ``self.key``, ``self.value`` and
    ``self.proj_attn`` are not defined here; presumably they come from
    ``NormalAttention`` (verify against that class).
    """
    # Optional module that yields (cos, sin) tables for a sequence length.
    rope_emb: RotaryEmbedding = None

    @nn.compact
    def __call__(self, x, context=None, freqs_cis=None):
        """Run attention over ``x`` (and optionally ``context``) with RoPE.

        Args:
            x: 3-D ``[B, S, C]`` or 4-D ``[B, H, W, C]`` input; 4-D input is
                flattened to a sequence and the output is reshaped back.
            context: Optional key/value source; defaults to ``x``. 4-D
                contexts are flattened like ``x``.
            freqs_cis: Optional ``(freqs_cos, freqs_sin)`` pair. When omitted,
                ``rope_emb`` is queried with the query sequence length.

        Returns:
            The projected attention output, in the original shape of ``x``
            when ``x`` was 4-D.

        Raises:
            ValueError: If neither ``freqs_cis`` nor ``rope_emb`` is available.
        """
        input_shape = x.shape
        spatial_input = len(input_shape) == 4
        if spatial_input:
            B, H, W, C = x.shape
            x = x.reshape((B, H * W, C))
        else:
            B, _, _ = x.shape

        if context is None:
            context = x
        if len(context.shape) == 4:
            _, ctx_h, ctx_w, ctx_c = context.shape
            # Note: the batch size is taken from ``x``, as in the original.
            context = context.reshape((B, ctx_h * ctx_w, ctx_c))

        query = self.query(x)        # [B, S, H, D]
        key = self.key(context)      # [B, S_ctx, H, D]
        value = self.value(context)  # [B, S_ctx, H, D]

        if freqs_cis is not None:
            freqs_cos, freqs_sin = freqs_cis
        elif self.rope_emb is not None:
            # Tables are sized by the query's sequence length.
            freqs_cos, freqs_sin = self.rope_emb(query.shape[1])
        else:
            # Should not happen if rope_emb is provided or freqs_cis are passed
            raise ValueError("RoPE frequencies not provided.")

        def rotate(tensor):
            # RoPE is applied in [B, H, S, D] layout, then permuted back.
            heads_first = einops.rearrange(tensor, 'b s h d -> b h s d')
            rotated = apply_rotary_embedding(heads_first, freqs_cos, freqs_sin)
            return einops.rearrange(rotated, 'b h s d -> b s h d')

        # The same tables are used for query and key (assumes the context has
        # the same sequence length as the query).
        query = rotate(query)
        key = rotate(key)

        hidden_states = nn.dot_product_attention(
            query,
            key,
            value,
            dtype=self.dtype,
            broadcast_dropout=False,
            dropout_rng=None,
            precision=self.precision,
            force_fp32_for_softmax=self.force_fp32_for_softmax,
            deterministic=True,
        )

        proj = self.proj_attn(hidden_states)
        return proj.reshape(input_shape) if spatial_input else proj
184
+
185
+
186
+ # --- adaLN-Zero ---
187
+
188
+
189
class AdaLNZero(nn.Module):
    """Adaptive LayerNorm modulation with a zero-initialised projection.

    The conditioning signal is projected to ``6 * features`` values that are
    split into scale/shift/gate triples for the MLP and attention paths.
    """
    # Size of each of the six modulation chunks.
    features: int
    # dtype forwarded to the projection and the LayerNorm.
    dtype: Optional[Dtype] = None
    # Precision forwarded to the projection.
    precision: PrecisionLike = None
    # Epsilon forwarded to the LayerNorm.
    norm_epsilon: float = 1e-5

    @nn.compact
    def __call__(self, x, conditioning):
        """Normalise ``x`` and modulate it for the attention and MLP paths.

        Args:
            x: Input to normalise.
            conditioning: Conditioning signal. When ``x`` is 3-D and this is
                2-D, an axis is inserted at position 1 so it broadcasts.

        Returns:
            ``(x_attn, gate_attn, x_mlp, gate_mlp)``.
        """
        if x.ndim == 3 and conditioning.ndim == 2:
            # x=[B,S,F], cond=[B,D_cond] -> cond=[B,1,D_cond]
            conditioning = jnp.expand_dims(conditioning, axis=1)

        # Kernel starts at zero (Zero init).
        projection = nn.Dense(
            features=6 * self.features,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=nn.initializers.zeros,
            name="ada_proj",
        )
        chunks = jnp.split(projection(conditioning), 6, axis=-1)
        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = chunks

        # Only the MLP scale and shift are clipped.
        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)

        # LayerNorm without its own learnable scale or bias.
        layer_norm = nn.LayerNorm(
            epsilon=self.norm_epsilon,
            use_scale=False,
            use_bias=False,
            dtype=self.dtype,
        )
        normed = layer_norm(x)

        x_attn = normed * (1 + scale_attn) + shift_attn
        x_mlp = normed * (1 + scale_mlp) + shift_mlp
        return x_attn, gate_attn, x_mlp, gate_mlp
239
+
240
class AdaLNParams(nn.Module):
    """Project a conditioning signal to ``6 * features`` modulation values."""
    # One sixth of the projection's output width.
    features: int
    # dtype forwarded to the projection.
    dtype: Optional[Dtype] = None
    # Precision forwarded to the projection.
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, conditioning):
        """Return the unsplit projection of ``conditioning``.

        A 2-D ``conditioning`` gets an axis inserted at position 1 first, so
        the result is then ``[B, 1, 6*F]``.
        """
        cond = conditioning
        if cond.ndim == 2:
            cond = jnp.expand_dims(cond, axis=1)

        # Kernel is zero-initialised.
        projection = nn.Dense(
            features=6 * self.features,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=nn.initializers.zeros,
            name="ada_proj",
        )
        # Returned whole; callers can split into six chunks on the last axis.
        return projection(cond)
262
+
@@ -129,6 +129,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
129
129
  frames_per_sample: int = None,
130
130
  wandb_config: Dict[str, Any] = None,
131
131
  eval_metrics: List[EvaluationMetric] = None,
132
+ best_tracker_metric: str = "train/best_loss",
132
133
  **kwargs
133
134
  ):
134
135
  """
@@ -196,6 +197,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
196
197
  **kwargs
197
198
  )
198
199
 
200
+ self.best_tracker_metric = best_tracker_metric
201
+
199
202
  # Store video-specific parameters
200
203
  self.frames_per_sample = frames_per_sample
201
204
 
@@ -203,6 +206,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
203
206
  self.conditional_inputs = input_config.conditions
204
207
  # Determine if we're working with video or images
205
208
  self.is_video = self._is_video_data()
209
+ self.best_val_metrics = {}
206
210
 
207
211
  def _is_video_data(self):
208
212
  sample_data_shape = self.input_config.sample_data_shape
@@ -423,7 +427,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
423
427
  process_index = jax.process_index()
424
428
  generate_samples = val_step_fn
425
429
 
426
- val_ds = iter(val_ds()) if val_ds else None
430
+ val_ds = iter(val_ds) if val_ds else None
431
+ print(f"Validation loop started for process index {process_index} with {global_device_count} devices.")
427
432
  # Evaluation step
428
433
  try:
429
434
  metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
@@ -465,11 +470,17 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
465
470
  else: # [B,H,W,C] - Image data
466
471
  self._log_image_samples(samples, current_step)
467
472
 
468
- if getattr(self, 'wandb', None) is not None and self.wandb:
469
- # metrics is a dict of metrics
470
- if metrics and type(metrics) == dict:
471
- # Flatten the metrics
472
- metrics = {k: np.mean(v) for k, v in metrics.items()}
473
+ # Flatten the metrics
474
+ if metrics:
475
+ metrics = {k: np.mean(v) for k, v in metrics.items()}
476
+ # Update the best validation metrics
477
+ for key, value in metrics.items():
478
+ if key not in self.best_val_metrics:
479
+ self.best_val_metrics[key] = value
480
+ else:
481
+ self.best_val_metrics[key] = min(self.best_val_metrics[key], value)
482
+ # Log the best validation metrics
483
+ if getattr(self, 'wandb', None) is not None and self.wandb:
473
484
  # Log the metrics
474
485
  for key, value in metrics.items():
475
486
  if isinstance(value, jnp.ndarray):
@@ -477,7 +488,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
477
488
  self.wandb.log({
478
489
  f"val/{key}": value,
479
490
  }, step=current_step)
480
-
491
+ print(f"Validation metrics for process index {process_index}: {metrics}")
492
+
493
+ # Close validation dataset iterator
494
+ del val_ds
481
495
  except StopIteration:
482
496
  print(f"Validation dataset exhausted for process index {process_index}")
483
497
  except Exception as e:
@@ -602,9 +616,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
602
616
  """
603
617
  if self.wandb is None:
604
618
  raise ValueError("Wandb is not initialized. Cannot get best runs.")
605
-
619
+ import wandb
606
620
  # Get the sweep runs
607
- runs = sorted(self.wandb.runs, key=lambda x: x.summary.get(metric, float('inf')))
621
+ runs = [i for i in wandb.Api().runs(path=f"{self.wandb.entity}/{self.wandb.project}", filters={"config.dataset.name": self.wandb.config['dataset']['name']})]
622
+ if not runs:
623
+ raise ValueError("No runs found in wandb.")
624
+ print(f"Getting best runs from wandb {self.wandb.id}...")
625
+ runs = sorted(runs, key=lambda x: x.summary.get(metric, float('inf')))
608
626
  best_runs = runs[:top_k]
609
627
  lower_bound = best_runs[-1].summary.get(metric, float('inf'))
610
628
  upper_bound = best_runs[0].summary.get(metric, float('inf'))
@@ -636,6 +654,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
636
654
  # Check if current run is one of the best
637
655
  if metric == "train/best_loss":
638
656
  current_run_metric = self.best_loss
657
+ elif metric in self.best_val_metrics:
658
+ current_run_metric = self.best_val_metrics[metric]
639
659
  else:
640
660
  current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
641
661
 
@@ -653,7 +673,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
653
673
  if self.wandb is not None:
654
674
  checkpoint = get_latest_checkpoint(self.checkpoint_path())
655
675
  try:
656
- is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss", from_sweeps=hasattr(self, "wandb_sweep"))
676
+ is_good, is_best = self.__compare_run_against_best__(top_k=5, metric=self.best_tracker_metric, from_sweeps=hasattr(self, "wandb_sweep"))
657
677
  if is_good:
658
678
  # Push to registry with appropriate aliases
659
679
  aliases = []
@@ -26,6 +26,7 @@ from flax.training.dynamic_scale import DynamicScale
26
26
  from flaxdiff.utils import RandomMarkovState
27
27
  from flax.training import dynamic_scale as dynamic_scale_lib
28
28
  from dataclasses import dataclass
29
+ import shutil
29
30
  import gc
30
31
 
31
32
  PROCESS_COLOR_MAP = {
@@ -73,6 +74,76 @@ class SimpleTrainState(train_state.TrainState):
73
74
  metrics: Metrics
74
75
  dynamic_scale: dynamic_scale_lib.DynamicScale
75
76
 
77
def move_contents_to_subdir(target_dir, new_subdir_name):
    """Move every entry of ``target_dir`` into a subdirectory of it.

    The subdirectory is created if needed. Progress and errors are reported
    with ``print``; failures never raise out of this function.

    Args:
        target_dir: Directory whose contents are moved.
        new_subdir_name: Name (or relative path) of the destination directory,
            joined onto ``target_dir``.

    Returns:
        None.
    """
    if not os.path.isdir(target_dir):
        print(f"Error: Target directory '{target_dir}' not found or is not a directory.")
        return

    new_subdir_path = os.path.join(target_dir, new_subdir_name)

    try:
        # exist_ok=True prevents an error if the directory already exists.
        os.makedirs(new_subdir_path, exist_ok=True)
        print(f"Subdirectory '{new_subdir_path}' created or already exists.")
    except OSError as e:
        print(f"Error creating subdirectory '{new_subdir_path}': {e}")
        return

    try:
        items_to_move = os.listdir(target_dir)
    except OSError as e:
        print(f"Error listing contents of '{target_dir}': {e}")
        return

    # Name of the top-level entry that holds the destination. Comparing names
    # after normalisation (instead of raw joined strings) also covers a
    # trailing separator ("sub/") and nested destinations ("a/b"), which would
    # otherwise trigger an attempt to move a directory into itself.
    try:
        relative_dest = os.path.relpath(new_subdir_path, target_dir)
    except ValueError:
        # No relative path exists (e.g. different drives on Windows), so the
        # destination cannot be inside target_dir and nothing needs skipping.
        relative_dest = None
    protected_entry = relative_dest.split(os.sep, 1)[0] if relative_dest else None

    print(f"Moving items from '{target_dir}' to '{new_subdir_path}'...")
    moved_count = 0
    error_count = 0
    for item_name in items_to_move:
        # IMPORTANT: skip the entry that contains the destination itself.
        if item_name == protected_entry:
            continue
        source_path = os.path.join(target_dir, item_name)
        destination_path = os.path.join(new_subdir_path, item_name)
        try:
            shutil.move(source_path, destination_path)
        except Exception as e:
            # Best effort: report the failure and continue with the rest.
            print(f" Error moving '{item_name}': {e}")
            error_count += 1
        else:
            moved_count += 1

    print("\nOperation complete.")
    print(f" Successfully moved: {moved_count} item(s).")
    if error_count > 0:
        print(f" Errors encountered: {error_count} item(s).")
124
+
125
def load_from_checkpoint(
    checkpoint_dir: str,
):
    """Best-effort restore of ``state`` and ``best_state`` from a local directory.

    Args:
        checkpoint_dir: Checkpoint directory; converted to an absolute path
            before use.

    Returns:
        ``(state, best_state)``. Each element is ``None`` when its key is
        missing from the restored checkpoint; both are ``None`` when anything
        in the restore fails (a warning is printed in that case).
    """
    try:
        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        options = orbax.checkpoint.CheckpointManagerOptions(create=False)
        # Work with the absolute form of the path from here on.
        checkpoint_dir = os.path.abspath(checkpoint_dir)
        manager = orbax.checkpoint.CheckpointManager(
            checkpoint_dir, checkpointer, options)
        restored = manager.restore(checkpoint_dir)

        state = restored['state'] if 'state' in restored else None
        best_state = restored['best_state'] if 'best_state' in restored else None

        print(f"Loaded checkpoint from local dir {checkpoint_dir}")
        return state, best_state
    except Exception as e:
        # Deliberately broad: every failure degrades to "nothing loaded".
        print(f"Warning: Failed to load checkpoint from local dir: {e}")
        return None, None
146
+
76
147
  @dataclass
77
148
  class SimpleTrainer:
78
149
  state: SimpleTrainState
@@ -97,6 +168,7 @@ class SimpleTrainer:
97
168
  checkpoint_step: int = None,
98
169
  use_dynamic_scale: bool = False,
99
170
  max_checkpoints_to_keep: int = 2,
171
+ train_start_step_override: int = None,
100
172
  ):
101
173
  if distributed_training is None or distributed_training is True:
102
174
  # Auto-detect if we are running on multiple devices
@@ -112,11 +184,32 @@ class SimpleTrainer:
112
184
  self.input_shapes = input_shapes
113
185
  self.checkpoint_base_path = checkpoint_base_path
114
186
 
187
+ load_directly_from_dir = False
188
+
115
189
  if wandb_config is not None and jax.process_index() == 0:
116
190
  import wandb
117
191
  run = wandb.init(resume='allow', **wandb_config)
118
192
  self.wandb = run
119
193
 
194
+ if 'id' in wandb_config:
195
+ # If resuming from a previous run, and train_start_step_override is not set,
196
+ # set the start step to the last step of the previous run
197
+ if train_start_step_override is None:
198
+ train_start_step_override = run.summary['train/step'] + 1
199
+ print(f"Resuming from previous run {wandb_config['id']} with start step {train_start_step_override}")
200
+
201
+ # If load_from_checkpoint is not set, and an artifact is found, load the artifact
202
+ if load_from_checkpoint is None:
203
+ api_run = wandb.Api().run(f"{wandb_config['entity']}/{wandb_config['project']}/{wandb_config['id']}")
204
+ model_artifacts = [i for i in api_run.logged_artifacts() if i.type == 'model']
205
+ if model_artifacts:
206
+ artifact = model_artifacts[0]
207
+ artifact_dir = artifact.download()
208
+ print(f"Loading model from artifact {artifact.name} at {artifact_dir}")
209
+ # Move the artifact's contents
210
+ load_from_checkpoint = artifact_dir
211
+ load_directly_from_dir = True
212
+
120
213
  # define our custom x axis metric
121
214
  self.wandb.define_metric("train/step")
122
215
  self.wandb.define_metric("train/epoch")
@@ -142,12 +235,16 @@ class SimpleTrainer:
142
235
  self.checkpoint_path(), async_checkpointer, options)
143
236
 
144
237
  if load_from_checkpoint is not None:
145
- latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
238
+ latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step, load_directly_from_dir)
146
239
  else:
147
- latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
240
+ latest_step, old_state, old_best_state, rngstate = 0, None, None, None
148
241
 
149
242
  self.latest_step = latest_step
150
243
 
244
+ if train_start_step_override is not None:
245
+ self.latest_step = train_start_step_override
246
+ print(f"Overriding start step to {self.latest_step}")
247
+
151
248
  if rngstate:
152
249
  self.rngstate = RandomMarkovState(**rngstate)
153
250
  else:
@@ -239,15 +336,12 @@ class SimpleTrainer:
239
336
  os.makedirs(path)
240
337
  return path
241
338
 
242
- def load(self, checkpoint_path=None, checkpoint_step=None):
243
- if checkpoint_path is None:
244
- checkpointer = self.checkpointer
245
- else:
246
- checkpointer = orbax.checkpoint.PyTreeCheckpointer()
247
- options = orbax.checkpoint.CheckpointManagerOptions(
248
- max_to_keep=4, create=False)
249
- checkpointer = orbax.checkpoint.CheckpointManager(
250
- checkpoint_path, checkpointer, options)
339
+ def load(self, checkpoint_path, checkpoint_step=None, load_directly_from_dir=False):
340
+ checkpointer = orbax.checkpoint.PyTreeCheckpointer()
341
+ options = orbax.checkpoint.CheckpointManagerOptions(
342
+ max_to_keep=4, create=False)
343
+ checkpointer = orbax.checkpoint.CheckpointManager(
344
+ checkpoint_path, checkpointer, options)
251
345
 
252
346
  if checkpoint_step is None:
253
347
  step = checkpointer.latest_step()
@@ -259,7 +353,8 @@ class SimpleTrainer:
259
353
  checkpoint_path if checkpoint_path else self.checkpoint_path(),
260
354
  f"{step}")
261
355
  self.loaded_checkpoint_path = loaded_checkpoint_path
262
- ckpt = checkpointer.restore(step)
356
+ ckpt = checkpointer.restore(step) if not load_directly_from_dir else checkpointer.restore(checkpoint_path)
357
+
263
358
  state = ckpt['state']
264
359
  best_state = ckpt['best_state']
265
360
  rngstate = ckpt['rngs']
@@ -268,10 +363,8 @@ class SimpleTrainer:
268
363
  if self.best_loss == 0:
269
364
  # It cant be zero as that must have been some problem
270
365
  self.best_loss = 1e9
271
- current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
272
- print(
273
- f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
274
- return current_epoch, step, state, best_state, rngstate
366
+ print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
367
+ return step, state, best_state, rngstate
275
368
 
276
369
  def save(self, epoch=0, step=0, state=None, rngstate=None):
277
370
  print(f"Saving model at epoch {epoch} step {step}")
@@ -507,6 +600,7 @@ class SimpleTrainer:
507
600
 
508
601
  def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
509
602
  train_ds = iter(data['train']())
603
+ val_ds = data.get('val', data.get('test', None))()
510
604
  train_step = self._define_train_step(**train_step_args)
511
605
  val_step = self._define_validation_step(**validation_step_args)
512
606
  train_state = self.state
@@ -520,7 +614,7 @@ class SimpleTrainer:
520
614
  self.validation_loop(
521
615
  train_state,
522
616
  val_step,
523
- data.get('val', data.get('test', None)),
617
+ val_ds,
524
618
  val_steps_per_epoch,
525
619
  self.latest_step,
526
620
  )
@@ -571,11 +665,11 @@ class SimpleTrainer:
571
665
  self.validation_loop(
572
666
  train_state,
573
667
  val_step,
574
- data.get('val', data.get('test', None)),
668
+ val_ds,
575
669
  val_steps_per_epoch,
576
670
  current_step,
577
671
  )
578
672
  print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
579
673
 
580
- self.save(epochs)
674
+ self.save(epochs)#
581
675
  return self.state
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.7
3
+ Version: 0.2.9
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -2,20 +2,20 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
3
3
  flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
4
4
  flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
5
- flaxdiff/data/dataloaders.py,sha256=LV8ugqoB86yihfYeOJZHHdRZJNmZ63A2NQkdILMR9QA,23564
6
- flaxdiff/data/dataset_map.py,sha256=_6SYnmrYO-URDd8vPAmALTV6r0eMGWWmwUtsdjKGXnA,5072
5
+ flaxdiff/data/dataloaders.py,sha256=k_3YGJhiY2Wt_-7qK0Yjl4pmF2QJjX_-BlSFuXbH5-M,23628
6
+ flaxdiff/data/dataset_map.py,sha256=p30U23RkfgMbR8kfPBDIjrjfzDBszWQ9Q1ff2BvDYZk,5116
7
7
  flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
8
8
  flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
9
9
  flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
10
10
  flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
11
- flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
12
- flaxdiff/data/sources/images.py,sha256=RFLtKW1xzw6ZPVXtCMmnTg1MPb8dc7rP77rZWbK7qpo,11796
11
+ flaxdiff/data/sources/base.py,sha256=4Rm9pCtXxzoB8FO0lkDHsrX3ULoU_PNNcid978e6ir0,4610
12
+ flaxdiff/data/sources/images.py,sha256=ZHBmZ2fnPN75Hc2kiog-Wcs_NZJZOiqw4WcSH5WZJHA,16572
13
13
  flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
14
- flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
14
+ flaxdiff/data/sources/videos.py,sha256=NkxwEruNpAwDCM53q4WurQ802gSjQMOqjNLxYOqjoNE,9545
15
15
  flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
16
16
  flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
18
- flaxdiff/inference/utils.py,sha256=MVnWl0LnC-1ILk0SsLd1YFu6igaQFR7mGhzo0jE797E,12323
18
+ flaxdiff/inference/utils.py,sha256=JEBZYSgj-0DLJTV-TNmIAllAqqVJMn0KfryHwFO-MFs,12606
19
19
  flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
20
20
  flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
21
21
  flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -27,14 +27,17 @@ flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
28
28
  flaxdiff/models/__init__.py,sha256=amtDF07DfiAdnZsvWX4eaW79nwNEU1s8Zb4PB3ewtg4,118
29
29
  flaxdiff/models/attention.py,sha256=YkED3_MRTjI9aTMTTQdsuReHhG8MK0Z4OVuU2j8ZAHs,13524
30
- flaxdiff/models/better_uvit.py,sha256=wPxvYBjuWQH6-OqW79VedzN6_WRY1f2mysPxaciWLww,15598
31
- flaxdiff/models/common.py,sha256=0j9AAjGPgkBLHo2DlYj0R6OsUNw2QaoDjaXSKq2mqkA,12647
30
+ flaxdiff/models/common.py,sha256=QpciwuJldvLUwyAyWBQqiPPGVI-c9qLR7h7C1YoRX7w,10510
32
31
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
33
32
  flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
33
+ flaxdiff/models/hilbert.py,sha256=AjlAv49dL6UAYWslMJfCMLiFqY4kTgpiUWr2nc1mk34,24823
34
+ flaxdiff/models/simple_dit.py,sha256=l238MYHRTArv_pS57aY24C2PTfxeL8EmzJ24iQqdoWI,11702
35
+ flaxdiff/models/simple_mmdit.py,sha256=ARk0juopn2k7giln5BAUrnYD1pTFwgTJoSzrhozQ0A8,31356
34
36
  flaxdiff/models/simple_unet.py,sha256=pjeixszG_6gEY5PNFbQ7KbOyg4z5bfn4RUbINCJexOM,10758
35
- flaxdiff/models/simple_vit.py,sha256=no0o3os8dEKGU5I4PMBJlXq6XKjhUex8S8uZ9BDPZS4,7971
37
+ flaxdiff/models/simple_vit.py,sha256=J9s3hBF87_iVrJDBe2cs9a56N7ect6pux_f_ge07XXc,17357
36
38
  flaxdiff/models/unet_3d.py,sha256=LF0PMxBKGU-_lAMtO_Coxy1yRE02yKKdgb7i6YZxI_4,20163
37
39
  flaxdiff/models/unet_3d_blocks.py,sha256=lRYDc9X1VEu54Kg7xEEphXYiQ09tabPXKi-hEcKFYug,19687
40
+ flaxdiff/models/vit_common.py,sha256=1OGu4ezY3uzKinTnw3p8YkQAslHDqEbN78JheXnTleI,9831
38
41
  flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
39
42
  flaxdiff/models/autoencoder/autoencoder.py,sha256=8XWdsWvsPsyWGtzpCT8w0KXi_ZLGpRuQpn4oXo1gHKw,6039
40
43
  flaxdiff/models/autoencoder/diffusers.py,sha256=tPz77YuctrT--jF2AOL8G6vr0NiIr3RXANNrZCxe0bg,5921
@@ -60,9 +63,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
60
63
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
61
64
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
62
65
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
63
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=BeDpJzgR8bUClJI4epQXlAul27MwiSfRW0lIBZSiPWk,28342
64
- flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
65
- flaxdiff-0.2.7.dist-info/METADATA,sha256=nwglJYeF2lH_MNq5PeFLR8TSPU-I9tzJUcBbTaLYxRM,24057
66
- flaxdiff-0.2.7.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
67
- flaxdiff-0.2.7.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
68
- flaxdiff-0.2.7.dist-info/RECORD,,
66
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=OtE2spZIBFPpY6q-ijYol5Y-CaP2UHJYIDX3PFBiPtg,29492
67
+ flaxdiff/trainer/simple_trainer.py,sha256=Hdltuo3lgF61N04Lxc7L3z6NLveW4_h1ff7_5mu3Wbg,28730
68
+ flaxdiff-0.2.9.dist-info/METADATA,sha256=a8btxHRkAZVieuZfTyXgPkJbEG9fZRknEhq2Ti3_7m4,24057
69
+ flaxdiff-0.2.9.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
70
+ flaxdiff-0.2.9.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
71
+ flaxdiff-0.2.9.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.1.0)
2
+ Generator: setuptools (80.3.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5