flaxdiff 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
flaxdiff/models/attention.py CHANGED
@@ -11,105 +11,6 @@ import functools
  import math
  from .common import kernel_init
 
- def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
-     """Multi-head dot product attention with a limited number of queries."""
-     num_kv, num_heads, k_features = key.shape[-3:]
-     v_features = value.shape[-1]
-     key_chunk_size = min(key_chunk_size, num_kv)
-     query = query / jnp.sqrt(k_features)
-
-     @functools.partial(jax.checkpoint, prevent_cse=False)
-     def summarize_chunk(query, key, value):
-         attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
-
-         max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
-         max_score = jax.lax.stop_gradient(max_score)
-         exp_weights = jnp.exp(attn_weights - max_score)
-
-         exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
-         max_score = jnp.einsum("...qhk->...qh", max_score)
-
-         return (exp_values, exp_weights.sum(axis=-1), max_score)
-
-     def chunk_scanner(chunk_idx):
-         # julienne key array
-         key_chunk = jax.lax.dynamic_slice(
-             operand=key,
-             start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
-             slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
-         )
-
-         # julienne value array
-         value_chunk = jax.lax.dynamic_slice(
-             operand=value,
-             start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
-             slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
-         )
-
-         return summarize_chunk(query, key_chunk, value_chunk)
-
-     chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
-
-     global_max = jnp.max(chunk_max, axis=0, keepdims=True)
-     max_diffs = jnp.exp(chunk_max - global_max)
-
-     chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
-     chunk_weights *= max_diffs
-
-     all_values = chunk_values.sum(axis=0)
-     all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
-
-     return all_values / all_weights
-
-
- def jax_memory_efficient_attention(
-     query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
- ):
-     r"""
-     Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
-     https://github.com/AminRezaei0x443/memory-efficient-attention
-
-     Args:
-         query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
-         key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
-         value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
-         precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
-             numerical precision for computation
-         query_chunk_size (`int`, *optional*, defaults to 1024):
-             chunk size to divide query array; value must divide query_length equally without remainder
-         key_chunk_size (`int`, *optional*, defaults to 4096):
-             chunk size to divide key and value array; value must divide key_value_length equally without remainder
-
-     Returns:
-         (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
-     """
-     num_q, num_heads, q_features = query.shape[-3:]
-
-     def chunk_scanner(chunk_idx, _):
-         # julienne query array
-         query_chunk = jax.lax.dynamic_slice(
-             operand=query,
-             start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
-             slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
-         )
-
-         return (
-             chunk_idx + query_chunk_size,  # unused; ignore it
-             _query_chunk_attention(
-                 query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
-             ),
-         )
-
-     _, res = jax.lax.scan(
-         f=chunk_scanner,
-         init=0,  # start counter
-         xs=None,
-         length=math.ceil(num_q / query_chunk_size),  # stop counter
-     )
-
-     return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
-
-
  class EfficientAttention(nn.Module):
      """
      Based on the pallas attention implementation.
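
The helpers removed above implemented the chunked attention of Rabe & Staats (arXiv:2112.05682): keys and values are processed in fixed-size chunks, each chunk tracks its own softmax max for numerical stability, and the partial results are merged at the end. A minimal standalone sketch of that merge step, with illustrative names and shapes of my own (not library code):

    import jax.numpy as jnp

    def merge_softmax_chunks(exp_values, exp_weights, max_scores):
        # exp_values: [n_chunks, q, f], exp_weights: [n_chunks, q], max_scores: [n_chunks, q]
        global_max = jnp.max(max_scores, axis=0, keepdims=True)
        scale = jnp.exp(max_scores - global_max)          # rescale each chunk to the global max
        values = (exp_values * scale[..., None]).sum(axis=0)
        weights = (exp_weights * scale).sum(axis=0)
        return values / weights[..., None]                # softmax-weighted average over all chunks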
@@ -125,41 +26,77 @@ class EfficientAttention(nn.Module):
      def setup(self):
          inner_dim = self.dim_head * self.heads
          # Weights were exported with old names {to_q, to_k, to_v, to_out}
-         self.query = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
-                                      kernel_init=self.kernel_init(), dtype=self.dtype, name="to_q")
-         self.key = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
-                                    kernel_init=self.kernel_init(), dtype=self.dtype, name="to_k")
-         self.value = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
-                                      kernel_init=self.kernel_init(), dtype=self.dtype, name="to_v")
+         dense = functools.partial(
+             nn.Dense,
+             self.heads * self.dim_head,
+             precision=self.precision,
+             use_bias=self.use_bias,
+             kernel_init=self.kernel_init(),
+             dtype=self.dtype
+         )
+         self.query = dense(name="to_q")
+         self.key = dense(name="to_k")
+         self.value = dense(name="to_v")
+
          self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
                                           kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
          # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
+
+     def _reshape_tensor_to_head_dim(self, tensor):
+         batch_size, _, seq_len, dim = tensor.shape
+         head_size = self.heads
+         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+         tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+         return tensor
+
+     def _reshape_tensor_from_head_dim(self, tensor):
+         batch_size, _, seq_len, dim = tensor.shape
+         head_size = self.heads
+         tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+         tensor = tensor.reshape(batch_size, 1, seq_len, dim * head_size)
+         return tensor
 
      @nn.compact
      def __call__(self, x:jax.Array, context=None):
+         # print(x.shape)
          # x has shape [B, H * W, C]
          context = x if context is None else context
+
+         B, H, W, C = x.shape
+         x = x.reshape((B, 1, H * W, C))
+
+         if len(context.shape) == 4:
+             B, _H, _W, _C = context.shape
+             context = context.reshape((B, 1, _H * _W, _C))
+         else:
+             B, SEQ, _C = context.shape
+             context = context.reshape((B, 1, SEQ, _C))
+
          query = self.query(x)
          key = self.key(context)
          value = self.value(context)
 
-         # print(query.shape, key.shape, value.shape)
+         query = self._reshape_tensor_to_head_dim(query)
+         key = self._reshape_tensor_to_head_dim(key)
+         value = self._reshape_tensor_to_head_dim(value)
 
-         # hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.mha_reference(
-         #     query, key, value, None
-         # )
-
-         hidden_states = nn.dot_product_attention(
-             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
+         hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention(
+             query, key, value, None
          )
-         # hidden_states = self.attnfn(
-         #     query, key, value, None
+
+         hidden_states = self._reshape_tensor_from_head_dim(hidden_states)
+
+
+         # hidden_states = nn.dot_product_attention(
+         #     query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
          # )
 
          proj = self.proj_attn(hidden_states)
+
+         proj = proj.reshape((B, H, W, C))
+
          return proj
 
-
  class NormalAttention(nn.Module):
      """
      Simple implementation of the normal attention.
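
The two reshape helpers added to EfficientAttention above exist because the pallas TPU flash-attention kernel consumes (batch, heads, seq_len, head_dim) tensors, while the module carries activations as (B, 1, seq_len, heads * head_dim). A quick standalone check of that round trip (shapes are illustrative, not from the package):

    import jax.numpy as jnp

    B, heads, seq, head_dim = 2, 8, 64, 32
    x = jnp.ones((B, 1, seq, heads * head_dim))                       # fused projection output

    h = x.reshape(B, seq, heads, head_dim).transpose(0, 2, 1, 3)      # -> (B, heads, seq, head_dim)
    assert h.shape == (B, heads, seq, head_dim)

    y = h.transpose(0, 2, 1, 3).reshape(B, 1, seq, heads * head_dim)  # back to the fused layout
    assert y.shape == x.shape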
@@ -201,7 +138,11 @@ class NormalAttention(nn.Module):
      @nn.compact
      def __call__(self, x, context=None):
          # x has shape [B, H, W, C]
+         B, H, W, C = x.shape
+         x = x.reshape((B, H*W, C))
          context = x if context is None else context
+         if len(context.shape) == 4:
+             context = context.reshape((B, H*W, C))
          query = self.query(x)
          key = self.key(context)
          value = self.value(context)
@@ -210,6 +151,7 @@ class NormalAttention(nn.Module):
              query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
          )
          proj = self.proj_attn(hidden_states)
+         proj = proj.reshape((B, H, W, C))
          return proj
 
  class AttentionBlock(nn.Module):
flaxdiff/models/common.py CHANGED
@@ -2,6 +2,6 @@ import jax.numpy as jnp
  from flax import linen as nn
 
  # Kernel initializer to use
- def kernel_init(scale):
+ def kernel_init(scale, dtype=jnp.float32):
      scale = max(scale, 1e-10)
-     return nn.initializers.variance_scaling(scale=scale, mode="fan_in", distribution="truncated_normal")
+     return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
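
As a usage note, the initializer returned by kernel_init plugs directly into Flax layers, and the new dtype argument lets the kernel be initialized in reduced precision. A hedged sketch (layer sizes are arbitrary):

    import jax
    import jax.numpy as jnp
    from flax import linen as nn
    from flaxdiff.models.common import kernel_init

    # fan_avg variance scaling with truncated-normal draws, as in the patched version
    layer = nn.Dense(16, kernel_init=kernel_init(1.0))
    params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))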
flaxdiff/models/simple_unet.py CHANGED
@@ -5,6 +5,7 @@ from typing import Dict, Callable, Sequence, Any, Union
  import einops
  from .common import kernel_init
  from .attention import TransformerBlock
+
  class WeightStandardizedConv(nn.Module):
      """
      apply weight standardization https://arxiv.org/abs/1903.10520
@@ -243,6 +244,7 @@ def l2norm(t, axis=1, eps=1e-12):
      denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
      out = t/denom
      return (out)
+
  class ResidualBlock(nn.Module):
      conv_type:str
      features:int
@@ -327,7 +329,7 @@ class Unet(nn.Module):
      precision: Any = jax.lax.Precision.HIGH
 
      @nn.compact
-     def __call__(self, x, temb, textcontext=None):
+     def __call__(self, x, temb, textcontext):
          # print("embedding features", self.emb_features)
          temb = FourierEmbedding(features=self.emb_features)(temb)
          temb = TimeProjection(features=self.emb_features)(temb)
@@ -340,7 +342,7 @@ class Unet(nn.Module):
 
          conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
          # middle_conv_type = "separable"
-
+
          x = ConvLayer(
              conv_type,
              features=self.feature_depths[0],
@@ -370,18 +372,13 @@ class Unet(nn.Module):
                  precision=self.precision
              )(x, temb)
              if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
-                 B, H, W, _ = x.shape
-                 if H > TS:
-                     padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
-                 else:
-                     padded_context = None
                  x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                       dim_head=dim_in // attention_config['heads'],
                                       use_flash_attention=attention_config.get("flash_attention", True),
                                       use_projection=attention_config.get("use_projection", False),
                                       use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                       precision=attention_config.get("precision", self.precision),
-                                      name=f"down_{i}_attention_{j}")(x, padded_context)
+                                      name=f"down_{i}_attention_{j}")(x, textcontext)
              # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
              downs.append(x)
          if i != len(feature_depths) - 1:
@@ -419,7 +416,7 @@ class Unet(nn.Module):
                                   use_projection=middle_attention.get("use_projection", False),
                                   use_self_and_cross=False,
                                   precision=attention_config.get("precision", self.precision),
-                                  name=f"middle_attention_{j}")(x)
+                                  name=f"middle_attention_{j}")(x, textcontext)
          x = ResidualBlock(
              middle_conv_type,
              name=f"middle_res2_{j}",
@@ -454,18 +451,13 @@ class Unet(nn.Module):
                  precision=self.precision
              )(x, temb)
              if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
-                 B, H, W, _ = x.shape
-                 if H > TS:
-                     padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
-                 else:
-                     padded_context = None
                  x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                       dim_head=dim_out // attention_config['heads'],
                                       use_flash_attention=attention_config.get("flash_attention", True),
                                       use_projection=attention_config.get("use_projection", False),
                                       use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                       precision=attention_config.get("precision", self.precision),
-                                      name=f"up_{i}_attention_{j}")(x, padded_context)
+                                      name=f"up_{i}_attention_{j}")(x, textcontext)
          # print("Upscaling ", i, x.shape)
          if i != len(feature_depths) - 1:
              x = Upsample(
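
Note that textcontext is now a required argument of Unet.__call__ and is passed to every TransformerBlock without the earlier padding logic, so callers must always supply an embedding sequence. A hedged initialization sketch under that new signature (shapes are illustrative; `unet` stands for an already-configured Unet instance):

    import jax
    import jax.numpy as jnp

    rng = jax.random.PRNGKey(0)
    x = jnp.ones((1, 64, 64, 3))           # image batch
    temb = jnp.ones((1,))                  # diffusion timestep
    textcontext = jnp.ones((1, 77, 768))   # e.g. a CLIP/T5-style embedding sequence

    # params = unet.init(rng, x, temb, textcontext)   # textcontext can no longer be omitted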
flaxdiff/models/simple_vit.py ADDED
@@ -0,0 +1,123 @@
+ # simple_vit.py
+
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Callable, Any
+ from .simply_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
+ from .attention import TransformerBlock
+
+ class PatchEmbedding(nn.Module):
+     patch_size: int
+     embedding_dim: int
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x):
+         batch, height, width, channels = x.shape
+         assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+         x = nn.Conv(features=self.embedding_dim,
+                     kernel_size=(self.patch_size, self.patch_size),
+                     strides=(self.patch_size, self.patch_size),
+                     dtype=self.dtype,
+                     precision=self.precision)(x)
+         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
+         return x
+
+ class PositionalEncoding(nn.Module):
+     max_len: int
+     embedding_dim: int
+
+     @nn.compact
+     def __call__(self, x):
+         pe = self.param('pos_encoding',
+                         jax.nn.initializers.zeros,
+                         (1, self.max_len, self.embedding_dim))
+         return x + pe[:, :x.shape[1], :]
+
+ class TransformerEncoder(nn.Module):
+     num_layers: int
+     num_heads: int
+     mlp_dim: int
+     dropout_rate: float = 0.1
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, training=True):
+         for _ in range(self.num_layers):
+             x = TransformerBlock(
+                 heads=self.num_heads,
+                 dim_head=x.shape[-1] // self.num_heads,
+                 mlp_dim=self.mlp_dim,
+                 dropout_rate=self.dropout_rate,
+                 dtype=self.dtype,
+                 precision=self.precision
+             )(x)
+         return x
+
+ class VisionTransformer(nn.Module):
+     patch_size: int = 16
+     embedding_dim: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     mlp_dim: int = 3072
+     emb_features: int = 256
+     dropout_rate: float = 0.1
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext=None):
+         # Time embedding
+         temb = FourierEmbedding(features=self.emb_features)(temb)
+         temb = TimeProjection(features=self.emb_features)(temb)
+
+         # Patch embedding
+         x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.embedding_dim,
+                            dtype=self.dtype, precision=self.precision)(x)
+
+         # Add positional encoding
+         x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x)
+
+         # Add time embedding
+         temb = jnp.expand_dims(temb, axis=1)
+         x = jnp.concatenate([x, temb], axis=1)
+
+         # Add text context
+         if textcontext is not None:
+             x = jnp.concatenate([x, textcontext], axis=1)
+
+         # Transformer encoder
+         x = TransformerEncoder(
+             num_layers=self.num_layers,
+             num_heads=self.num_heads,
+             mlp_dim=self.mlp_dim,
+             dropout_rate=self.dropout_rate,
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+
+         # Extract the image tokens (exclude time and text embeddings)
+         num_patches = (x.shape[1] - 1 - (0 if textcontext is None else textcontext.shape[1]))
+         x = x[:, :num_patches, :]
+
+         # Reshape to image dimensions
+         batch, _, _ = x.shape
+         height = width = int((num_patches) ** 0.5)
+         x = jnp.reshape(x, (batch, height, width, self.embedding_dim))
+
+         # Final convolution to get the desired output channels
+         x = ConvLayer(
+             conv_type="conv",
+             features=3,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             kernel_init=kernel_init(0.0),
+             dtype=self.dtype,
+             precision=self.precision
+         )(x)
+
+         return x
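
The token bookkeeping in VisionTransformer.__call__ is: patch tokens first, then one time token, then optional text tokens, with the tail sliced off before reshaping back to a feature map. A standalone check of that arithmetic (sizes are illustrative):

    import jax.numpy as jnp

    num_patches, text_len, dim = 256, 77, 768
    tokens = jnp.ones((1, num_patches + 1 + text_len, dim))   # [patches | time | text]

    recovered = tokens.shape[1] - 1 - text_len                # same slice as in __call__
    patches = tokens[:, :recovered, :]
    side = int(recovered ** 0.5)                              # assumes a square patch grid
    feature_map = patches.reshape(1, side, side, dim)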
flaxdiff/trainer/__init__.py CHANGED
@@ -17,18 +17,9 @@ from flax.training import orbax_utils
  from ..schedulers import NoiseScheduler
  from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 
- @struct.dataclass
- class Metrics(metrics.Collection):
-     loss: metrics.Average.from_output('loss')  # type: ignore
+ from .simple_trainer import SimpleTrainer, SimpleTrainState
 
- class ModelState():
-     model: nn.Module
-     params: dict
-     noise_schedule: NoiseScheduler
-     model_output_transform: DiffusionPredictionTransform
-
- # Define the TrainState with EMA parameters
- class TrainState(train_state.TrainState):
+ class TrainState(SimpleTrainState):
      rngs: jax.random.PRNGKey
      ema_params: dict
 
@@ -36,7 +27,7 @@ class TrainState(train_state.TrainState):
          rngs, subkey = jax.random.split(self.rngs)
          return self.replace(rngs=rngs), subkey
 
-     def apply_ema(self, decay: float=0.999):
+     def apply_ema(self, decay: float = 0.999):
          new_ema_params = jax.tree_util.tree_map(
              lambda ema, param: decay * ema + (1 - decay) * param,
              self.ema_params,
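
apply_ema applies the standard exponential moving average, ema <- decay * ema + (1 - decay) * param, across the whole parameter pytree. A toy check with scalar "parameters" (illustrative only):

    import jax
    import jax.numpy as jnp

    decay = 0.999
    ema_params = {"w": jnp.array(1.0)}
    params = {"w": jnp.array(0.0)}

    new_ema = jax.tree_util.tree_map(
        lambda ema, p: decay * ema + (1 - decay) * p, ema_params, params)
    # new_ema["w"] == 0.999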
@@ -44,141 +35,142 @@ class TrainState(train_state.TrainState):
          )
          return self.replace(ema_params=new_ema_params)
 
- class DiffusionTrainer:
-     state : TrainState
-     best_state : TrainState
-     best_loss : float
-     model : nn.Module
-     noise_schedule : NoiseScheduler
-     model_output_transform:DiffusionPredictionTransform
-     ema_decay:float = 0.999
-
-     def __init__(self,
-                  model:nn.Module,
+ class DiffusionTrainer(SimpleTrainer):
+     noise_schedule: NoiseScheduler
+     model_output_transform: DiffusionPredictionTransform
+     ema_decay: float = 0.999
+
+     def __init__(self,
+                  model: nn.Module,
+                  input_shapes: Dict[str, Tuple[int]],
                   optimizer: optax.GradientTransformation,
-                  noise_schedule:NoiseScheduler,
-                  rngs:jax.random.PRNGKey,
-                  train_state:TrainState=None,
-                  name:str="Diffusion",
-                  load_from_checkpoint:bool=False,
-                  param_transforms:Callable=None,
-                  model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
-                  loss_fn=optax.l2_loss,
+                  noise_schedule: NoiseScheduler,
+                  rngs: jax.random.PRNGKey,
+                  unconditional_prob: float = 0.2,
+                  name: str = "Diffusion",
+                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                  **kwargs
                   ):
-         self.model = model
+         super().__init__(
+             model=model,
+             input_shapes=input_shapes,
+             optimizer=optimizer,
+             rngs=rngs,
+             name=name,
+             **kwargs
+         )
          self.noise_schedule = noise_schedule
-         self.name = name
          self.model_output_transform = model_output_transform
-         self.loss_fn = loss_fn
-
-         checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-         options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
-         self.checkpointer = orbax.checkpoint.CheckpointManager(self.checkpoint_path(), checkpointer, options)
+         self.unconditional_prob = unconditional_prob
+
+     def __init_fn(
+         self,
+         optimizer: optax.GradientTransformation,
+         rngs: jax.random.PRNGKey,
+         existing_state: dict = None,
+         existing_best_state: dict = None,
+         model: nn.Module = None,
+         param_transforms: Callable = None
+     ) -> Tuple[TrainState, TrainState]:
+         rngs, subkey = jax.random.split(rngs)
 
-         if load_from_checkpoint:
-             params = self.load()
+         if existing_state == None:
+             input_vars = self.get_input_ones()
+             params = model.init(subkey, **input_vars)
+             new_state = {"params": params, "ema_params": params}
          else:
-             params = None
+             new_state = existing_state
 
-         if train_state == None:
-             self.init_state(optimizer, rngs, params=params, model=model, param_transforms=param_transforms)
-         else:
-             self.state = train_state
-             self.best_state = train_state
-             self.best_loss = 1e9
-
-     def init_state(self,
-                    optimizer: optax.GradientTransformation,
-                    rngs:jax.random.PRNGKey,
-                    params:dict=None,
-                    model:nn.Module=None,
-                    param_transforms:Callable=None,
-                    batch_size=16,
-                    image_size=64
-                    ):
-         inp = jnp.ones((batch_size, image_size, image_size, 3))
-         temb = jnp.ones((batch_size,))
-         rngs, subkey = jax.random.split(rngs)
-         if params == None:
-             params = model.init(subkey, inp, temb)
          if param_transforms is not None:
              params = param_transforms(params)
-         self.best_loss = 1e9
-         self.state = TrainState.create(
+
+         state = TrainState.create(
              apply_fn=model.apply,
-             params=params,
-             ema_params=params,
+             params=new_state['params'],
+             ema_params=new_state['ema_params'],
              tx=optimizer,
              rngs=rngs,
+             metrics=Metrics.empty()
          )
-         self.best_state = self.state
-
-     def checkpoint_path(self):
-         experiment_name = self.name
-         path = os.path.join(os.path.abspath('./models'), experiment_name)
-         if not os.path.exists(path):
-             os.makedirs(path)
-         return path
-
-     def load(self):
-         step = self.checkpointer.latest_step()
-         print("Loading model from checkpoint", step)
-         ckpt = self.checkpointer.restore(step)
-         state = ckpt['state']
-         # Convert the state to a TrainState
-         self.best_loss = ckpt['best_loss']
-         print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
-         return state.get('params', None)  #, ckpt.get('model', None)
-
-     def save(self, epoch=0, best=False):
-         print(f"Saving model at epoch {epoch}")
-         state = self.best_state if best else self.state
-         # filename = os.path.join(self.checkpoint_path(), f'model_{epoch}' if not best else 'best_model')
-         ckpt = {
-             'model': self.model,
-             'state': state,
-             'best_loss': self.best_loss
-         }
-         save_args = orbax_utils.save_args_from_target(ckpt)
-         self.checkpointer.save(epoch, ckpt, save_kwargs={'save_args': save_args})
-
-     def summary(self, image_size=64):
-         inp = jnp.ones((1, image_size, image_size, 3))
-         temb = jnp.ones((1,))
-         print(self.model.tabulate(jax.random.key(0), inp, temb, console_kwargs={"width": 200, "force_jupyter":True, }))
-
-     def _define_train_step(self):
+
+         if existing_best_state is not None:
+             best_state = state.replace(
+                 params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
+         else:
+             best_state = state
+
+         return state, best_state
+
+     def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
          noise_schedule = self.noise_schedule
          model = self.model
          model_output_transform = self.model_output_transform
         loss_fn = self.loss_fn
-         @jax.jit
-         def train_step(state:TrainState, batch):
+         unconditional_prob = self.unconditional_prob
+
+         # Determine the number of unconditional samples
+         num_unconditional = int(batch_size * unconditional_prob)
+
+         nS, nC = null_labels_seq.shape
+         null_labels_seq = jnp.broadcast_to(
+             null_labels_seq, (batch_size, nS, nC))
+
+         distributed_training = self.distributed_training
+
+         def train_step(state: TrainState, batch):
              """Train for a single step."""
-             images = batch
-             noise_level, state = noise_schedule.generate_timesteps(images.shape[0], state)
+             images = batch['image']
+             # normalize image
+             images = (images - 127.5) / 127.5
+
+             output = text_embedder(
+                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+             # output = infer(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+
+             label_seq = output.last_hidden_state
+
+             # Generate random probabilities to decide how much of this batch will be unconditional
+
+             label_seq = jnp.concat(
+                 [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
+
+             noise_level, state = noise_schedule.generate_timesteps(
+                 images.shape[0], state)
              state, rngs = state.get_random_key()
-             noise:jax.Array = jax.random.normal(rngs, shape=images.shape)
+             noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
              rates = noise_schedule.get_rates(noise_level)
-             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
+             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
+                 images, noise, rates)
+
              def model_loss(params):
-                 preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level))
-                 preds = model_output_transform.pred_transform(noisy_images, preds, rates)
+                 preds = model.apply(
+                     params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
+                 preds = model_output_transform.pred_transform(
+                     noisy_images, preds, rates)
                  nloss = loss_fn(preds, expected_output)
                  # nloss = jnp.mean(nloss, axis=1)
                  nloss *= noise_schedule.get_weights(noise_level)
                  nloss = jnp.mean(nloss)
                  loss = nloss
                  return loss
+
              loss, grads = jax.value_and_grad(model_loss)(state.params)
-             state = state.apply_gradients(grads=grads)
+             if distributed_training:
+                 grads = jax.lax.pmean(grads, "device")
+             state = state.apply_gradients(grads=grads)
              state = state.apply_ema(self.ema_decay)
              return state, loss
+
+         if distributed_training:
+             train_step = jax.pmap(axis_name="device")(train_step)
+         else:
+             train_step = jax.jit(train_step)
+
          return train_step
-
+
      def _define_compute_metrics(self):
          @jax.jit
-         def compute_metrics(state:TrainState, expected, pred):
+         def compute_metrics(state: TrainState, expected, pred):
              loss = jnp.mean(jnp.square(pred - expected))
              metric_updates = state.metrics.single_from_model_output(loss=loss)
              metrics = state.metrics.merge(metric_updates)
@@ -187,20 +179,13 @@ class DiffusionTrainer:
          return compute_metrics
 
      def fit(self, data, steps_per_epoch, epochs):
-         data = iter(data)
-         train_step = self._define_train_step()
-         compute_metrics = self._define_compute_metrics()
-         state = self.state
-         for epoch in range(epochs):
-             print(f"\nEpoch {epoch+1}/{epochs}")
-             start_time = time.time()
-             epoch_loss = 0
-             with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {epoch+1}', ncols=100, unit='step') as pbar:
-                 for i in range(steps_per_epoch):
-                     batch = next(data)
-                     state, loss = train_step(state, batch)
-                     epoch_loss += loss
-                     if i % 100 == 0:
+         null_labels_full = data['null_labels_full']
+         local_batch_size = data['local_batch_size']
+         text_embedder = data['model']
+         super().fit(data, steps_per_epoch, epochs, {
+             "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
+
+
                          pbar.set_postfix(loss=f'{loss:.4f}')
                          pbar.update(100)
              end_time = time.time()
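
The num_unconditional logic introduced in _define_train_step drops the text conditioning for a fixed fraction of each batch, classifier-free-guidance style, by splicing in a broadcast null-embedding sequence. A minimal sketch of that mixing (names and shapes are illustrative, not the trainer's API):

    import jax.numpy as jnp

    batch_size, seq, dim = 8, 77, 768
    unconditional_prob = 0.2
    num_unconditional = int(batch_size * unconditional_prob)   # first slots train unconditionally

    null_seq = jnp.broadcast_to(jnp.zeros((seq, dim)), (batch_size, seq, dim))
    label_seq = jnp.ones((batch_size, seq, dim))               # stand-in for text-encoder outputs

    mixed = jnp.concatenate(
        [null_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
    assert mixed.shape == (batch_size, seq, dim)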
flaxdiff/trainer/simple_trainer.py ADDED
@@ -0,0 +1,323 @@
+ import orbax.checkpoint
+ import tqdm
+ from flax import linen as nn
+ import jax
+ from typing import Callable
+ from dataclasses import field
+ import jax.numpy as jnp
+ from clu import metrics
+ from flax.training import train_state  # Useful dataclass to keep train state
+ import optax
+ from flax import struct  # Flax dataclasses
+ import time
+ import os
+ import orbax
+ from flax.training import orbax_utils
+
+ @struct.dataclass
+ class Metrics(metrics.Collection):
+     accuracy: metrics.Accuracy
+     loss: metrics.Average.from_output('loss')
+
+ # Define the TrainState
+ class SimpleTrainState(train_state.TrainState):
+     rngs: jax.random.PRNGKey
+     metrics: Metrics
+
+     def get_random_key(self):
+         rngs, subkey = jax.random.split(self.rngs)
+         return self.replace(rngs=rngs), subkey
+
+ class SimpleTrainer:
+     state: SimpleTrainState
+     best_state: SimpleTrainState
+     best_loss: float
+     model: nn.Module
+     ema_decay: float = 0.999
+
+     def __init__(self,
+                  model: nn.Module,
+                  input_shapes: Dict[str, Tuple[int]],
+                  optimizer: optax.GradientTransformation,
+                  rngs: jax.random.PRNGKey,
+                  train_state: SimpleTrainState = None,
+                  name: str = "Simple",
+                  load_from_checkpoint: bool = False,
+                  checkpoint_suffix: str = "",
+                  loss_fn=optax.l2_loss,
+                  param_transforms: Callable = None,
+                  wandb_config: Dict[str, Any] = None,
+                  distributed_training: bool = None,
+                  ):
+         if distributed_training is None or distributed_training is True:
+             # Auto-detect if we are running on multiple devices
+             distributed_training = jax.device_count() > 1
+
+         self.distributed_training = distributed_training
+         self.model = model
+         self.name = name
+         self.loss_fn = loss_fn
+         self.input_shapes = input_shapes
+
+         if wandb_config is not None:
+             run = wandb.init(**wandb_config)
+             self.wandb = run
+
+         checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+         options = orbax.checkpoint.CheckpointManagerOptions(
+             max_to_keep=4, create=True)
+         self.checkpointer = orbax.checkpoint.CheckpointManager(
+             self.checkpoint_path() + checkpoint_suffix, checkpointer, options)
+
+         if load_from_checkpoint:
+             latest_epoch, old_state, old_best_state = self.load()
+         else:
+             latest_epoch, old_state, old_best_state = 0, None, None
+
+         self.latest_epoch = latest_epoch
+
+         if train_state == None:
+             self.init_state(optimizer, rngs, existing_state=old_state,
+                             existing_best_state=old_best_state, model=model, param_transforms=param_transforms)
+         else:
+             self.state = train_state
+             self.best_state = train_state
+             self.best_loss = 1e9
+
+     def get_input_ones(self):
+         return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
+
+     def __init_fn(
+         self,
+         optimizer: optax.GradientTransformation,
+         rngs: jax.random.PRNGKey,
+         existing_state: dict = None,
+         existing_best_state: dict = None,
+         model: nn.Module = None,
+         param_transforms: Callable = None
+     ) -> Tuple[SimpleTrainState, SimpleTrainState]:
+         rngs, subkey = jax.random.split(rngs)
+
+         if existing_state == None:
+             input_vars = self.get_input_ones()
+             params = model.init(subkey, **input_vars)
+
+         state = SimpleTrainState.create(
+             apply_fn=model.apply,
+             params=params,
+             tx=optimizer,
+             rngs=rngs,
+             metrics=Metrics.empty()
+         )
+         if existing_best_state is not None:
+             best_state = state.replace(
+                 params=existing_best_state['params'])
+         else:
+             best_state = state
+
+         return state, best_state
+
+     def init_state(
+         self,
+         optimizer: optax.GradientTransformation,
+         rngs: jax.random.PRNGKey,
+         existing_state: dict = None,
+         existing_best_state: dict = None,
+         model: nn.Module = None,
+         param_transforms: Callable = None
+     ):
+
+         state, best_state = self.__init_fn(
+             optimizer, rngs, existing_state, existing_best_state, model, param_transforms
+         )
+         self.best_loss = 1e9
+
+         if self.distributed_training:
+             devices = jax.local_devices()
+             if len(devices) > 1:
+                 print("Replicating state across devices ", devices)
+                 state = flax.jax_utils.replicate(state, devices)
+                 best_state = flax.jax_utils.replicate(best_state, devices)
+             else:
+                 print("Not replicating any state, Only single device connected to the process")
+
+         self.state = state
+         self.best_state = best_state
+
+     def get_state(self):
+         return flax.jax_utils.unreplicate(self.state)
+
+     def get_best_state(self):
+         return flax.jax_utils.unreplicate(self.best_state)
+
+     def checkpoint_path(self):
+         experiment_name = self.name
+         path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
+         if not os.path.exists(path):
+             os.makedirs(path)
+         return path
+
+     def tensorboard_path(self):
+         experiment_name = self.name
+         path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
+         if not os.path.exists(path):
+             os.makedirs(path)
+         return path
+
+     def load(self):
+         epoch = self.checkpointer.latest_step()
+         print("Loading model from checkpoint", epoch)
+         ckpt = self.checkpointer.restore(epoch)
+         state = ckpt['state']
+         best_state = ckpt['best_state']
+         # Convert the state to a TrainState
+         self.best_loss = ckpt['best_loss']
+         print(
+             f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
+         return epoch, state, best_state
+
+     def save(self, epoch=0):
+         print(f"Saving model at epoch {epoch}")
+         ckpt = {
+             # 'model': self.model,
+             'state': self.get_state(),
+             'best_state': self.get_best_state(),
+             'best_loss': self.best_loss
+         }
+         try:
+             save_args = orbax_utils.save_args_from_target(ckpt)
+             self.checkpointer.save(epoch, ckpt, save_kwargs={
+                 'save_args': save_args}, force=True)
+             pass
+         except Exception as e:
+             print("Error saving checkpoint", e)
+
+     def _define_train_step(self, **kwargs):
+         model = self.model
+         loss_fn = self.loss_fn
+         distributed_training = self.distributed_training
+
+         def train_step(state: SimpleTrainState, batch):
+             """Train for a single step."""
+             images = batch['image']
+             labels = batch['label']
+
+             def model_loss(params):
+                 preds = model.apply(params, images)
+                 expected_output = labels
+                 nloss = loss_fn(preds, expected_output)
+                 loss = jnp.mean(nloss)
+                 return loss
+             loss, grads = jax.value_and_grad(model_loss)(state.params)
+             if distributed_training:
+                 grads = jax.lax.pmean(grads, "device")
+             state = state.apply_gradients(grads=grads)
+             return state, loss
+
+         if distributed_training:
+             train_step = jax.pmap(axis_name="device")(train_step)
+         else:
+             train_step = jax.jit(train_step)
+
+         return train_step
+
+     def _define_compute_metrics(self):
+         model = self.model
+         loss_fn = self.loss_fn
+
+         @jax.jit
+         def compute_metrics(state: SimpleTrainState, batch):
+             preds = model.apply(state.params, batch['image'])
+             expected_output = batch['label']
+             loss = jnp.mean(loss_fn(preds, expected_output))
+             metric_updates = state.metrics.single_from_model_output(
+                 loss=loss, logits=preds, labels=expected_output)
+             metrics = state.metrics.merge(metric_updates)
+             state = state.replace(metrics=metrics)
+             return state
+         return compute_metrics
+
+     def summary(self):
+         input_vars = self.get_input_ones()
+         print(self.model.tabulate(jax.random.key(0), **input_vars,
+                                   console_kwargs={"width": 200, "force_jupyter": True, }))
+
+     def config(self):
+         return {
+             "model": self.model,
+             "state": self.state,
+             "name": self.name,
+             "input_shapes": self.input_shapes
+         }
+
+     def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
+         summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
+         summary_writer.hparams({
+             **self.config(),
+             "steps_per_epoch": steps_per_epoch,
+             "epochs": epochs,
+             "batch_size": batch_size
+         })
+         return summary_writer
+
+     def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
+         train_ds = iter(data['train']())
+         if 'test' in data:
+             test_ds = data['test']
+         else:
+             test_ds = None
+         train_step = self._define_train_step(**train_step_args)
+         compute_metrics = self._define_compute_metrics()
+         state = self.state
+         device_count = jax.local_device_count()
+         # train_ds = flax.jax_utils.prefetch_to_device(train_ds, jax.devices())
+
+         summary_writer = self.init_tensorboard(
+             data['global_batch_size'], steps_per_epoch, epochs)
+
+         while self.latest_epoch <= epochs:
+             self.latest_epoch += 1
+             current_epoch = self.latest_epoch
+             print(f"\nEpoch {current_epoch}/{epochs}")
+             start_time = time.time()
+             epoch_loss = 0
+
+             with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
+                 for i in range(steps_per_epoch):
+                     batch = next(train_ds)
+                     if self.distributed_training and device_count > 1:
+                         batch = jax.tree.map(lambda x: x.reshape(
+                             (device_count, -1, *x.shape[1:])), batch)
+
+                     state, loss = train_step(state, batch)
+                     loss = jnp.mean(loss)
+
+                     epoch_loss += loss
+                     if i % 100 == 0:
+                         pbar.set_postfix(loss=f'{loss:.4f}')
+                         pbar.update(100)
+                         current_step = current_epoch*steps_per_epoch + i
+                         summary_writer.scalar(
+                             'Train Loss', loss, step=current_step)
+                         if self.wandb is not None:
+                             self.wandb.log({"train/loss": loss})
+
+             print(f"\n\tEpoch done")
+             end_time = time.time()
+             self.state = state
+             total_time = end_time - start_time
+             avg_time_per_step = total_time / steps_per_epoch
+             avg_loss = epoch_loss / steps_per_epoch
+             if avg_loss < self.best_loss:
+                 self.best_loss = avg_loss
+                 self.best_state = state
+                 self.save(current_epoch)
+
+             # Compute Metrics
+             metrics_str = ''
+
+             print(
+                 f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss} {metrics_str}")
+
+         self.save(epochs)
+         return self.state
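
The fit loop above shards each batch for the jax.pmap-compiled train step by folding the device count into the leading axis. A minimal sketch of that sharding on a toy batch (not the trainer's data pipeline):

    import jax
    import jax.numpy as jnp

    device_count = jax.local_device_count()
    batch = {"image": jnp.ones((device_count * 4, 32, 32, 3))}   # global batch divisible by device count

    sharded = jax.tree_util.tree_map(
        lambda x: x.reshape((device_count, -1, *x.shape[1:])), batch)
    assert sharded["image"].shape == (device_count, 4, 32, 32, 3)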
flaxdiff-0.1.1.dist-info/METADATA → flaxdiff-0.1.4.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.1
+ Version: 0.1.4
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
flaxdiff-0.1.1.dist-info/RECORD → flaxdiff-0.1.4.dist-info/RECORD RENAMED
@@ -1,10 +1,11 @@
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
- flaxdiff/models/attention.py,sha256=enyqoZP4NMbIn07UdnduxvohtfpbsYW-n7nALE3K_s4,18369
- flaxdiff/models/common.py,sha256=WUCbuqSa8jEWAUt0UbEStTlpt5j1Mw8oZmZXYj5VwWQ,241
+ flaxdiff/models/attention.py,sha256=SL9cvINjmabW1LPvXLAFZNHv-FF1Ez_d3J7n5uHBTyQ,15301
+ flaxdiff/models/common.py,sha256=CjC4iRLjkF3oQ0f6rAqfiLaiHllZGtCOwN3rXDUndbE,274
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
- flaxdiff/models/simple_unet.py,sha256=EExRXSo0nvpiDUF_3lPKp4eQVGBa05PSskNs1ER0sqU,19273
+ flaxdiff/models/simple_unet.py,sha256=WlLry6v18syHBzcN8zAJ-zIVtq6ItMEIBWbeCcX0MLU,18693
+ flaxdiff/models/simple_vit.py,sha256=vTu2CQRoSOxetBHTrnCWddm-vxrZDkMe8EpdNxtpJMk,4015
  flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
  flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
  flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
@@ -23,8 +24,9 @@ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,60
  flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
  flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
  flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
- flaxdiff/trainer/__init__.py,sha256=iXnrIugF2g2ZLgW3HxZZBzgsoxJx7bWvLxqVmWpmAbo,8536
- flaxdiff-0.1.1.dist-info/METADATA,sha256=ZcNAw19k8s40DKgBILh3CriHkieOXuwhUbUJjx_YW8U,19229
- flaxdiff-0.1.1.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
- flaxdiff-0.1.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
- flaxdiff-0.1.1.dist-info/RECORD,,
+ flaxdiff/trainer/__init__.py,sha256=kwzkm-BD97hffFIXZUP1Hb3_D85fZ4SRNO7bviEwHU8,7591
+ flaxdiff/trainer/simple_trainer.py,sha256=jafxr-yZ6FXn0Qi-iTSnlf275QWnIO4GnSvNAeB3H-Q,11651
+ flaxdiff-0.1.4.dist-info/METADATA,sha256=G8OijdrrYWuKyAfCNtD_dKwdfBmdME56vpR-EYIZKXg,19229
+ flaxdiff-0.1.4.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+ flaxdiff-0.1.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+ flaxdiff-0.1.4.dist-info/RECORD,,