flaxdiff 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +8 -65
- flaxdiff/models/autoencoder/diffusers.py +1 -1
- flaxdiff/models/common.py +14 -4
- flaxdiff/models/simple_unet.py +20 -10
- flaxdiff/models/simple_vit.py +13 -16
- flaxdiff/trainer/diffusion_trainer.py +41 -11
- flaxdiff/trainer/simple_trainer.py +80 -60
- {flaxdiff-0.1.7.dist-info → flaxdiff-0.1.9.dist-info}/METADATA +18 -1
- {flaxdiff-0.1.7.dist-info → flaxdiff-0.1.9.dist-info}/RECORD +11 -11
- {flaxdiff-0.1.7.dist-info → flaxdiff-0.1.9.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.7.dist-info → flaxdiff-0.1.9.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
@@ -156,71 +156,14 @@ class NormalAttention(nn.Module):
         value = self.value(context)

         hidden_states = nn.dot_product_attention(
-            query, key, value, dtype=self.dtype, broadcast_dropout=False,
+            query, key, value, dtype=self.dtype, broadcast_dropout=False,
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
+            deterministic=True
         )
         proj = self.proj_attn(hidden_states)
         proj = proj.reshape(orig_x_shape)
         return proj

-class BasicTransformerBlock(nn.Module):
-    # Has self and cross attention
-    query_dim: int
-    heads: int = 4
-    dim_head: int = 64
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-    use_bias: bool = True
-    kernel_init: Callable = lambda : kernel_init(1.0)
-    use_flash_attention:bool = False
-    use_cross_only:bool = False
-
-    def setup(self):
-        if self.use_flash_attention:
-            attenBlock = EfficientAttention
-        else:
-            attenBlock = NormalAttention
-
-        self.attention1 = attenBlock(
-            query_dim=self.query_dim,
-            heads=self.heads,
-            dim_head=self.dim_head,
-            name=f'Attention1',
-            precision=self.precision,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init
-        )
-        self.attention2 = attenBlock(
-            query_dim=self.query_dim,
-            heads=self.heads,
-            dim_head=self.dim_head,
-            name=f'Attention2',
-            precision=self.precision,
-            use_bias=self.use_bias,
-            dtype=self.dtype,
-            kernel_init=self.kernel_init
-        )
-
-        self.ff = FlaxFeedForward(dim=self.query_dim)
-        self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-        self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-
-    @nn.compact
-    def __call__(self, hidden_states, context=None):
-        # self attention
-        if not self.use_cross_only:
-            print("Using self attention")
-            hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))
-
-        # cross attention
-        hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)
-
-        # feed forward
-        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))
-
-        return hidden_states
-
 class FlaxGEGLU(nn.Module):
     r"""
     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
@@ -246,7 +189,7 @@ class FlaxGEGLU(nn.Module):

     def __call__(self, hidden_states):
         hidden_states = self.proj(hidden_states)
-        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis
+        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
         return hidden_linear * nn.gelu(hidden_gelu)

 class FlaxFeedForward(nn.Module):
@@ -330,7 +273,7 @@ class BasicTransformerBlock(nn.Module):
     @nn.compact
     def __call__(self, hidden_states, context=None):
         if self.only_pure_attention:
-            return self.attention2(
+            return self.attention2(hidden_states, context)

         # self attention
         if not self.use_cross_only:
@@ -350,14 +293,14 @@ class TransformerBlock(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     use_projection: bool = False
-    use_flash_attention:bool =
-    use_self_and_cross:bool =
+    use_flash_attention:bool = False
+    use_self_and_cross:bool = True
     only_pure_attention:bool = False

     @nn.compact
     def __call__(self, x, context=None):
         inner_dim = self.heads * self.dim_head
-
+        C = x.shape[-1]
         normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
         if self.use_projection == True:
             if self.use_linear_attention:
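The `axis=-1` fix in FlaxGEGLU above is load-bearing: the projection doubles the feature width, and the gate must be split off along that last axis. A minimal, self-contained sketch of the gated-GELU pattern the fix restores (the `GEGLU` module name here is illustrative, not flaxdiff's class):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class GEGLU(nn.Module):
    """Gated GELU: project to twice the width, gate one half with GELU of the other."""
    dim_out: int

    @nn.compact
    def __call__(self, x):
        proj = nn.Dense(self.dim_out * 2)(x)
        # Split along the feature axis; the default axis=0 would split the
        # batch dimension of a (batch, seq, features) input instead.
        hidden_linear, hidden_gelu = jnp.split(proj, 2, axis=-1)
        return hidden_linear * nn.gelu(hidden_gelu)

x = jnp.ones((1, 16, 64))
module = GEGLU(dim_out=128)
y = module.apply(module.init(jax.random.PRNGKey(0), x), x)  # (1, 16, 128)
```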
flaxdiff/models/autoencoder/diffusers.py
CHANGED
@@ -78,7 +78,7 @@ class StableDiffusionVAE(AutoEncoder):
             log_std = jnp.clip(log_std, -30, 20)
             std = jnp.exp(0.5 * log_std)
             latents = mean + std * jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
-            print("Sampled")
+            # print("Sampled")
         else:
             # return the mean
             latents, _ = jnp.split(latents, 2, axis=-1)
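For context, the silenced `print` sits in the VAE's stochastic encode path, which is the standard reparameterization trick over the encoder's concatenated (mean, log-variance) moments. A hedged sketch of that computation (`sample_latents` is a hypothetical helper, not flaxdiff API):

```python
import jax
import jax.numpy as jnp

def sample_latents(moments: jax.Array, rngkey: jax.Array) -> jax.Array:
    # The encoder packs mean and log-variance along the channel axis.
    mean, log_std = jnp.split(moments, 2, axis=-1)
    log_std = jnp.clip(log_std, -30, 20)   # keep exp() numerically safe
    std = jnp.exp(0.5 * log_std)
    noise = jax.random.normal(rngkey, mean.shape, dtype=mean.dtype)
    return mean + std * noise

latents = sample_latents(jnp.zeros((1, 16, 16, 8)), jax.random.PRNGKey(0))
```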
flaxdiff/models/common.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Optional, Any, Callable, Sequence, Union
 from flax.typing import Dtype, PrecisionLike
 from typing import Dict, Callable, Sequence, Any, Union
 import einops
+from functools import partial

 # Kernel initializer to use
 def kernel_init(scale, dtype=jnp.float32):
@@ -266,12 +267,21 @@ class ResidualBlock(nn.Module):
     kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+
+    def setup(self):
+        if self.norm_groups > 0:
+            norm = partial(nn.GroupNorm, self.norm_groups)
+        else:
+            norm = partial(nn.RMSNorm, 1e-5)
+
+        self.norm1 = norm()
+        self.norm2 = norm()

     @nn.compact
     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
         residual = x
-
-        out = nn.RMSNorm()(x)
+        out = self.norm1(x)
+        # out = nn.RMSNorm()(x)
         out = self.activation(out)

         out = ConvLayer(
@@ -295,8 +305,8 @@ class ResidualBlock(nn.Module):
         # out = out * (1 + scale) + shift
         out = out + temb

-
-        out = nn.RMSNorm()(out)
+        out = self.norm2(out)
+        # out = nn.RMSNorm()(out)
         out = self.activation(out)

         out = ConvLayer(
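The new `setup` replaces the hard-coded `nn.RMSNorm()` calls with a configurable norm factory: `norm_groups > 0` selects `GroupNorm`, anything else falls back to `RMSNorm`. A runnable sketch of the pattern, assuming a module with a `norm_groups` field like `ResidualBlock` (the `NormDemo` name is illustrative):

```python
from functools import partial
import jax
import jax.numpy as jnp
from flax import linen as nn

class NormDemo(nn.Module):
    norm_groups: int = 8  # > 0 -> GroupNorm(num_groups); else RMSNorm(epsilon=1e-5)

    def setup(self):
        if self.norm_groups > 0:
            norm = partial(nn.GroupNorm, self.norm_groups)
        else:
            norm = partial(nn.RMSNorm, 1e-5)
        self.norm1 = norm()
        self.norm2 = norm()

    def __call__(self, x):
        return self.norm2(self.norm1(x))

x = jnp.ones((2, 16, 16, 32))
m = NormDemo()
y = m.apply(m.init(jax.random.PRNGKey(0), x), x)
```

Binding the choice once in `setup` also keeps both norm call sites in sync.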
flaxdiff/models/simple_unet.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Dict, Callable, Sequence, Any, Union, Optional
 import einops
 from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
 from .attention import TransformerBlock
+from functools import partial

 class Unet(nn.Module):
     output_channels:int=3
@@ -19,6 +20,15 @@ class Unet(nn.Module):
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None

+    def setup(self):
+        if self.norm_groups > 0:
+            norm = partial(nn.GroupNorm, self.norm_groups)
+        else:
+            norm = partial(nn.RMSNorm, 1e-5)
+
+        # self.last_up_norm = norm()
+        self.conv_out_norm = norm()
+
     @nn.compact
     def __call__(self, x, temb, textcontext):
         # print("embedding features", self.emb_features)
@@ -65,11 +75,11 @@ class Unet(nn.Module):
             if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                 x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                      dim_head=dim_in // attention_config['heads'],
-                                     use_flash_attention=attention_config.get("flash_attention",
+                                     use_flash_attention=attention_config.get("flash_attention", False),
                                      use_projection=attention_config.get("use_projection", False),
                                      use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                      precision=attention_config.get("precision", self.precision),
-                                     only_pure_attention=True,
+                                     only_pure_attention=attention_config.get("only_pure_attention", True),
                                      name=f"down_{i}_attention_{j}")(x, textcontext)
             # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
             downs.append(x)
@@ -103,12 +113,12 @@ class Unet(nn.Module):
             if middle_attention is not None and j == self.num_middle_res_blocks - 1:   # Apply attention only on the last block
                 x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
                                      dim_head=middle_dim_out // middle_attention['heads'],
-                                     use_flash_attention=middle_attention.get("flash_attention",
+                                     use_flash_attention=middle_attention.get("flash_attention", False),
                                      use_linear_attention=False,
                                      use_projection=middle_attention.get("use_projection", False),
                                      use_self_and_cross=False,
-                                     precision=
-                                     only_pure_attention=True,
+                                     precision=middle_attention.get("precision", self.precision),
+                                     only_pure_attention=middle_attention.get("only_pure_attention", True),
                                      name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
@@ -146,11 +156,11 @@ class Unet(nn.Module):
             if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                 x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                      dim_head=dim_out // attention_config['heads'],
-                                     use_flash_attention=attention_config.get("flash_attention",
+                                     use_flash_attention=attention_config.get("flash_attention", False),
                                      use_projection=attention_config.get("use_projection", False),
                                      use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                      precision=attention_config.get("precision", self.precision),
-                                     only_pure_attention=True,
+                                     only_pure_attention=attention_config.get("only_pure_attention", True),
                                      name=f"up_{i}_attention_{j}")(x, textcontext)
             # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
@@ -163,13 +173,13 @@ class Unet(nn.Module):
             precision=self.precision
         )(x)

-        # x =
+        # x = self.last_up_norm(x)
         x = ConvLayer(
             conv_type,
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(
+            kernel_init=kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -189,7 +199,7 @@ class Unet(nn.Module):
             precision=self.precision
         )(x, temb)

-        x =
+        x = self.conv_out_norm(x)
         x = self.activation(x)

         noise_out = ConvLayer(
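All three attention call sites now read their knobs from the config dict with explicit defaults, so `only_pure_attention` becomes overridable instead of hard-coded. A sketch of a per-level config that exercises those defaults; the keys mirror the `.get(...)` calls in the diff, while the overall list-of-dicts shape is an assumption about flaxdiff's config format:

```python
# One entry per feature level; None disables attention at that level.
attention_configs = [
    None,
    {
        "heads": 8,
        "flash_attention": False,        # -> use_flash_attention
        "use_projection": False,
        "use_self_and_cross": True,
        "only_pure_attention": True,     # previously always True, now a choice
    },
]
cfg = attention_configs[1]
use_flash = cfg.get("flash_attention", False)
only_pure = cfg.get("only_pure_attention", True)
```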
flaxdiff/models/simple_vit.py
CHANGED
@@ -4,7 +4,7 @@ import jax
 import jax.numpy as jnp
 from flax import linen as nn
 from typing import Callable, Any
-from .
+from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
 from .attention import TransformerBlock

 class PatchEmbedding(nn.Module):
@@ -40,22 +40,23 @@ class PositionalEncoding(nn.Module):
 class TransformerEncoder(nn.Module):
     num_layers: int
     num_heads: int
-    mlp_dim: int
     dropout_rate: float = 0.1
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
+    use_projection: bool = False

     @nn.compact
-    def __call__(self, x,
+    def __call__(self, x, context=None):
         for _ in range(self.num_layers):
             x = TransformerBlock(
                 heads=self.num_heads,
                 dim_head=x.shape[-1] // self.num_heads,
-                mlp_dim=self.mlp_dim,
                 dropout_rate=self.dropout_rate,
                 dtype=self.dtype,
-                precision=self.precision
-
+                precision=self.precision,
+                use_self_and_cross=True,
+                use_projection=self.use_projection,
+            )(x, context)
         return x

 class VisionTransformer(nn.Module):
@@ -63,11 +64,11 @@ class VisionTransformer(nn.Module):
     embedding_dim: int = 768
     num_layers: int = 12
     num_heads: int = 12
-    mlp_dim: int = 3072
     emb_features: int = 256
     dropout_rate: float = 0.1
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
+    use_projection: bool = False

     @nn.compact
     def __call__(self, x, temb, textcontext=None):
@@ -81,27 +82,23 @@ class VisionTransformer(nn.Module):

         # Add positional encoding
         x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x)
+
+        num_patches = x.shape[1]

         # Add time embedding
         temb = jnp.expand_dims(temb, axis=1)
         x = jnp.concatenate([x, temb], axis=1)

-        # Add text context
-        if textcontext is not None:
-            x = jnp.concatenate([x, textcontext], axis=1)
-
         # Transformer encoder
         x = TransformerEncoder(
             num_layers=self.num_layers,
             num_heads=self.num_heads,
-            mlp_dim=self.mlp_dim,
             dropout_rate=self.dropout_rate,
             dtype=self.dtype,
-            precision=self.precision
-
+            precision=self.precision,
+            use_projection=self.use_projection
+        )(x, textcontext)

-        # Extract the image tokens (exclude time and text embeddings)
-        num_patches = (x.shape[1] - 1 - (0 if textcontext is None else textcontext.shape[1]))
         x = x[:, :num_patches, :]

         # Reshape to image dimensions
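The ViT change reroutes text conditioning: instead of concatenating text tokens into the sequence and slicing them off afterwards, the encoder now receives `textcontext` as a cross-attention `context`. The patch-count bookkeeping becomes simpler too; a minimal sketch with illustrative shapes:

```python
import jax.numpy as jnp

patches = jnp.ones((2, 64, 768))   # (batch, num_patches, embedding_dim)
temb = jnp.ones((2, 1, 768))       # time embedding as one extra token

num_patches = patches.shape[1]     # recorded before appending the time token
x = jnp.concatenate([patches, temb], axis=1)   # (2, 65, 768)
# ... TransformerEncoder(x, context=textcontext) runs here; text enters via
# cross-attention rather than as extra sequence tokens ...
x = x[:, :num_patches, :]          # recover the patch tokens: (2, 64, 768)
```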
flaxdiff/trainer/diffusion_trainer.py
CHANGED
@@ -29,6 +29,8 @@ class TrainState(SimpleTrainState):
         )
         return self.replace(ema_params=new_ema_params)

+from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+
 class DiffusionTrainer(SimpleTrainer):
     noise_schedule: NoiseScheduler
     model_output_transform: DiffusionPredictionTransform
@@ -40,7 +42,7 @@ class DiffusionTrainer(SimpleTrainer):
         optimizer: optax.GradientTransformation,
         noise_schedule: NoiseScheduler,
         rngs: jax.random.PRNGKey,
-        unconditional_prob: float = 0.
+        unconditional_prob: float = 0.12,
         name: str = "Diffusion",
         model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
         autoencoder: AutoEncoder = None,
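`unconditional_prob` is the fraction of training samples whose conditioning is dropped, which is what later enables classifier-free guidance at sampling time. A hedged sketch of that dropout; the masking details and the `drop_conditioning` helper are illustrative, not flaxdiff's exact code:

```python
import jax
import jax.numpy as jnp

def drop_conditioning(rng, text_emb, null_emb, unconditional_prob=0.12):
    # With probability unconditional_prob per sample, swap in the null embedding.
    mask = jax.random.bernoulli(rng, unconditional_prob, (text_emb.shape[0],))
    return jnp.where(mask[:, None, None], null_emb, text_emb)

text_emb = jnp.ones((4, 77, 768))
null_emb = jnp.zeros((1, 77, 768))
out = drop_conditioning(jax.random.PRNGKey(0), text_emb, null_emb)
```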
@@ -67,7 +69,8 @@ class DiffusionTrainer(SimpleTrainer):
         existing_state: dict = None,
         existing_best_state: dict = None,
         model: nn.Module = None,
-        param_transforms: Callable = None
+        param_transforms: Callable = None,
+        use_dynamic_scale: bool = False
     ) -> Tuple[TrainState, TrainState]:
         print("Generating states for DiffusionTrainer")
         rngs, subkey = jax.random.split(rngs)
@@ -88,7 +91,8 @@ class DiffusionTrainer(SimpleTrainer):
             ema_params=new_state['ema_params'],
             tx=optimizer,
             rngs=rngs,
-            metrics=Metrics.empty()
+            metrics=Metrics.empty(),
+            dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
         )

         if existing_best_state is not None:
@@ -125,14 +129,14 @@ class DiffusionTrainer(SimpleTrainer):
             local_rng_state = RandomMarkovState(subkey)

             images = batch['image']
+            images = jnp.array(images, dtype=jnp.float32)
+            # normalize image
+            images = (images - 127.5) / 127.5

             if autoencoder is not None:
                 # Convert the images to latent space
                 local_rng_state, rngs = local_rng_state.get_random_key()
                 images = autoencoder.encode(images, rngs)
-            else:
-                # normalize image
-                images = (images - 127.5) / 127.5

             output = text_embedder(
                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
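The reordering above is subtle but important: pixels are now normalized to [-1, 1] before any autoencoder encode, rather than only in the pixel-space branch. The arithmetic for uint8 inputs, taken directly from the changed lines:

```python
import numpy as np
import jax.numpy as jnp

batch = {'image': np.full((4, 128, 128, 3), 255, dtype=np.uint8)}
images = jnp.array(batch['image'], dtype=jnp.float32)
images = (images - 127.5) / 127.5   # [0, 255] -> [-1.0, 1.0]
```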
@@ -163,12 +167,39 @@ class DiffusionTrainer(SimpleTrainer):
                 loss = nloss
                 return loss

-
+
+            if train_state.dynamic_scale is not None:
+                # dynamic scale takes care of averaging gradients across replicas
+                grad_fn = train_state.dynamic_scale.value_and_grad(
+                    model_loss, axis_name="data"
+                )
+                dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
+                train_state = train_state.replace(dynamic_scale=dynamic_scale)
+            else:
+                grad_fn = jax.value_and_grad(model_loss)
+                loss, grads = grad_fn(train_state.params)
+                if distributed_training:
+                    grads = jax.lax.pmean(grads, "data")
+
+            new_state = train_state.apply_gradients(grads=grads)
+
+            if train_state.dynamic_scale:
+                # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
+                # params should be restored (= skip this step).
+                select_fn = functools.partial(jnp.where, is_fin)
+                new_state = train_state.replace(
+                    opt_state=jax.tree_util.tree_map(
+                        select_fn, new_state.opt_state, train_state.opt_state
+                    ),
+                    params=jax.tree_util.tree_map(
+                        select_fn, new_state.params, train_state.params
+                    ),
+                )
+
+            train_state = new_state.apply_ema(self.ema_decay)
+
             if distributed_training:
-                grads = jax.lax.pmean(grads, "data")
                 loss = jax.lax.pmean(loss, "data")
-            train_state = train_state.apply_gradients(grads=grads)
-            train_state = train_state.apply_ema(self.ema_decay)
             return train_state, loss, rng_state

         if distributed_training:
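The new branch implements mixed-precision loss scaling with `flax.training.dynamic_scale`: scale the loss before differentiation, unscale the grads, and skip the optimizer update whenever the grads are non-finite. A self-contained sketch of the same pattern under a toy loss (names like `loss_fn` and the sgd optimizer are illustrative):

```python
import functools
from typing import Optional

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from flax.training import dynamic_scale as dynamic_scale_lib

class TrainState(train_state.TrainState):
    dynamic_scale: Optional[dynamic_scale_lib.DynamicScale] = None

def loss_fn(params):
    # Stand-in loss; the trainer computes the diffusion loss here.
    return jnp.sum(params['w'] ** 2)

def train_step(state: TrainState):
    if state.dynamic_scale is not None:
        # Scales the loss before differentiation, unscales the grads,
        # and reports whether every gradient entry is finite.
        grad_fn = state.dynamic_scale.value_and_grad(loss_fn)
        dyn_scale, is_fin, loss, grads = grad_fn(state.params)
        state = state.replace(dynamic_scale=dyn_scale)
    else:
        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        is_fin = jnp.bool_(True)

    new_state = state.apply_gradients(grads=grads)
    if state.dynamic_scale is not None:
        # Non-finite grads: keep the previous params and optimizer state,
        # effectively skipping this step while the scale shrinks.
        select = functools.partial(jnp.where, is_fin)
        new_state = new_state.replace(
            opt_state=jax.tree_util.tree_map(select, new_state.opt_state, state.opt_state),
            params=jax.tree_util.tree_map(select, new_state.params, state.params),
        )
    return new_state, loss

state = TrainState.create(
    apply_fn=None, params={'w': jnp.ones((4,))}, tx=optax.sgd(1e-2),
    dynamic_scale=dynamic_scale_lib.DynamicScale(),
)
state, loss = train_step(state)
```

Note that `DynamicScale.value_and_grad` already averages gradients across replicas when given an `axis_name`, which is why the explicit `pmean` lives only in the unscaled branch of the diff.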
@@ -199,4 +230,3 @@ def boolean_string(s):
     if type(s) == bool:
         return s
     return s == 'True'
-
flaxdiff/trainer/simple_trainer.py
CHANGED
@@ -39,23 +39,23 @@ PROCESS_COLOR_MAP = {
 def _build_global_shape_and_sharding(
     local_shape: tuple[int, ...], global_mesh: Mesh
 ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
-
-
-
+    sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+    global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+    return global_shape, sharding


 def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
-
-
-
-
-
-
-
-
-
-
-
+    """Put local sharded array into local devices"""
+    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+    try:
+        local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+    except ValueError as array_split_error:
+        raise ValueError(
+            f"Unable to put to devices shape {array.shape} with "
+            f"local device count {len(global_mesh.local_devices)} "
+        ) from array_split_error
+    local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+    return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)

 def convert_to_global_tree(global_mesh, pytree):
     return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
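These helpers assemble one logically global batch from per-process shards. A single-host sketch of the same steps, with illustrative shapes (on one process, `global_shape` equals the local shape):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('data',))
array = np.arange(len(jax.devices()) * 4, dtype=np.float32).reshape(len(jax.devices()), 4)

# Same recipe as form_global_array: one shard per local device along axis 0,
# stitched back into a single global jax.Array sharded over the 'data' axis.
sharding = NamedSharding(mesh, P(mesh.axis_names))
global_shape = (jax.process_count() * array.shape[0],) + array.shape[1:]
shards = np.split(array, len(mesh.local_devices), axis=0)
buffers = jax.device_put(shards, mesh.local_devices)
global_array = jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)
```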
@@ -67,12 +67,8 @@ class Metrics(metrics.Collection):

 # Define the TrainState
 class SimpleTrainState(train_state.TrainState):
-    rngs: jax.random.PRNGKey
     metrics: Metrics
-
-    def get_random_key(self):
-        rngs, subkey = jax.random.split(self.rngs)
-        return self.replace(rngs=rngs), subkey
+    dynamic_scale: flax.training.dynamic_scale.DynamicScale

 class SimpleTrainer:
     state: SimpleTrainState
@@ -88,20 +84,22 @@ class SimpleTrainer:
         rngs: jax.random.PRNGKey,
         train_state: SimpleTrainState = None,
         name: str = "Simple",
-        load_from_checkpoint:
+        load_from_checkpoint: str = None,
         checkpoint_suffix: str = "",
-        checkpoint_id: str = None,
         loss_fn=optax.l2_loss,
         param_transforms: Callable = None,
         wandb_config: Dict[str, Any] = None,
         distributed_training: bool = None,
         checkpoint_base_path: str = "./checkpoints",
+        checkpoint_step: int = None,
+        use_dynamic_scale: bool = False,
     ):
         if distributed_training is None or distributed_training is True:
             # Auto-detect if we are running on multiple devices
             distributed_training = jax.device_count() > 1
             self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
-
+        else:
+            self.mesh = None

         self.distributed_training = distributed_training
         self.model = model
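The constructor now always leaves `self.mesh` defined: a 1-D `'data'` mesh when more than one device is visible, `None` otherwise. The detection itself is two lines, lifted straight from the hunk above:

```python
import jax
from jax.sharding import Mesh

distributed_training = jax.device_count() > 1
mesh = Mesh(jax.devices(), 'data') if distributed_training else None
```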
@@ -112,7 +110,6 @@ class SimpleTrainer:


         if wandb_config is not None and jax.process_index() == 0:
-            import wandb
             run = wandb.init(**wandb_config)
             self.wandb = run

@@ -126,11 +123,6 @@ class SimpleTrainer:
             self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
             self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
             self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
-
-        if checkpoint_id is None:
-            self.checkpoint_id = name.replace(' ', '_').replace('-', '_').lower()
-        else:
-            self.checkpoint_id = checkpoint_id

         # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
         async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
@@ -140,12 +132,12 @@ class SimpleTrainer:
         self.checkpointer = orbax.checkpoint.CheckpointManager(
             self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)

-        if load_from_checkpoint:
-            latest_epoch, old_state, old_best_state, rngstate = self.load()
+        if load_from_checkpoint is not None:
+            latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
         else:
-            latest_epoch, old_state, old_best_state, rngstate = 0, None, None, None
+            latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None

-        self.
+        self.latest_step = latest_step

         if rngstate:
             self.rngstate = RandomMarkovState(**rngstate)
@@ -156,7 +148,7 @@ class SimpleTrainer:

         if train_state == None:
             state, best_state = self.generate_states(
-                optimizer, subkey, old_state, old_best_state, model, param_transforms
+                optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
             )
             self.init_state(state, best_state)
         else:
@@ -174,7 +166,8 @@ class SimpleTrainer:
         existing_state: dict = None,
         existing_best_state: dict = None,
         model: nn.Module = None,
-        param_transforms: Callable = None
+        param_transforms: Callable = None,
+        use_dynamic_scale: bool = False
     ) -> Tuple[SimpleTrainState, SimpleTrainState]:
         print("Generating states for SimpleTrainer")
         rngs, subkey = jax.random.split(rngs)
@@ -189,7 +182,8 @@ class SimpleTrainer:
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
-            metrics=Metrics.empty()
+            metrics=Metrics.empty(),
+            dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
         )
         if existing_best_state is not None:
             best_state = state.replace(
@@ -222,7 +216,7 @@ class SimpleTrainer:
         return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)

     def checkpoint_path(self):
-        path = os.path.join(self.checkpoint_base_path, self.
+        path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
         if not os.path.exists(path):
             os.makedirs(path)
         return path
@@ -234,31 +228,46 @@ class SimpleTrainer:
             os.makedirs(path)
         return path

-    def load(self):
-
-
-
+    def load(self, checkpoint_path=None, checkpoint_step=None):
+        if checkpoint_path is None:
+            checkpointer = self.checkpointer
+        else:
+            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+            options = orbax.checkpoint.CheckpointManagerOptions(
+                max_to_keep=4, create=False)
+            checkpointer = orbax.checkpoint.CheckpointManager(
+                checkpoint_path, checkpointer, options)
+
+        if checkpoint_step is None:
+            step = checkpointer.latest_step()
+        else:
+            step = checkpoint_step
+
+        print("Loading model from checkpoint at step ", step)
+        ckpt = checkpointer.restore(step)
         state = ckpt['state']
         best_state = ckpt['best_state']
         rngstate = ckpt['rngs']
         # Convert the state to a TrainState
         self.best_loss = ckpt['best_loss']
+        current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
         print(
-            f"Loaded model from checkpoint at epoch {
-        return
+            f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
+        return current_epoch, step, state, best_state, rngstate

-    def save(self, epoch=0):
-        print(f"Saving model at epoch {epoch}")
+    def save(self, epoch=0, step=0):
+        print(f"Saving model at epoch {epoch} step {step}")
         ckpt = {
             # 'model': self.model,
             'rngs': self.get_rngstate(),
             'state': self.get_state(),
             'best_state': self.get_best_state(),
             'best_loss': np.array(self.best_loss),
+            'epoch': epoch,
         }
         try:
             save_args = orbax_utils.save_args_from_target(ckpt)
-            self.checkpointer.save(
+            self.checkpointer.save(step, ckpt, save_kwargs={
                 'save_args': save_args}, force=True)
             self.checkpointer.wait_until_finished()
             pass
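`load` can now open an arbitrary checkpoint directory at an arbitrary step, not just the trainer's own manager at the latest step. A condensed sketch of that path using the same legacy orbax API the trainer uses (the `load_checkpoint` wrapper is hypothetical):

```python
import orbax.checkpoint

def load_checkpoint(checkpoint_path: str, checkpoint_step: int = None):
    # Mirrors the trainer's load(): open a manager over the given directory,
    # pick an explicit step or fall back to the newest one.
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=False)
    manager = orbax.checkpoint.CheckpointManager(checkpoint_path, checkpointer, options)
    step = checkpoint_step if checkpoint_step is not None else manager.latest_step()
    ckpt = manager.restore(step)
    # Older checkpoints stored only an epoch counter; fall back to the step.
    return ckpt.get('epoch', step), step, ckpt
```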
@@ -350,9 +359,10 @@ class SimpleTrainer:
         else:
             global_device_indexes = 0

-        def train_loop(
+        def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state):
             epoch_loss = 0
-
+            current_epoch = current_step // steps_per_epoch
+            last_save_time = time.time()
             for i in range(steps_per_epoch):
                 batch = next(train_ds)
                 if self.distributed_training and global_device_count > 1:
@@ -363,36 +373,46 @@ class SimpleTrainer:
                 if self.distributed_training:
                     loss = jax.experimental.multihost_utils.process_allgather(loss)
                     loss = jnp.mean(loss) # Just to make sure its a scaler value
+
+                if loss <= 1e-6:
+                    # If the loss is too low, we can assume the model has diverged
+                    print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
+                    # Exit the training loop
+                    exit(1)

                 epoch_loss += loss
-
-                if
-                if
+                current_step += 1
+                if i % 100 == 0:
+                    if pbar is not None:
                         pbar.set_postfix(loss=f'{loss:.4f}')
                         pbar.update(100)
-                    current_step = current_epoch*steps_per_epoch + i
                     if self.wandb is not None:
                         self.wandb.log({
                             "train/step" : current_step,
                             "train/loss": loss,
                         }, step=current_step)
+                    # Save the model every 40 minutes
+                    if time.time() - last_save_time > 40 * 60:
+                        print(f"Saving model after 40 minutes at step {current_step}")
+                        self.save(current_epoch, current_step)
+                        last_save_time = time.time()
             print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
             return epoch_loss, current_step, train_state, rng_state

-        while self.
-            current_epoch = self.
-            self.latest_epoch += 1
+        while self.latest_step < epochs * steps_per_epoch:
+            current_epoch = self.latest_step // steps_per_epoch
             print(f"\nEpoch {current_epoch}/{epochs}")
             start_time = time.time()
             epoch_loss = 0

             if process_index == 0:
                 with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
-                    epoch_loss, current_step, train_state, rng_state = train_loop(
+                    epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, pbar, train_state, rng_state)
             else:
-                epoch_loss, current_step, train_state, rng_state = train_loop(
-                print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP
-
+                epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, None, train_state, rng_state)
+            print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
+            self.latest_step = current_step
             end_time = time.time()
             self.state = train_state
             self.rngstate = rng_state
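The loop's bookkeeping is now step-based: the epoch is derived from the global step, which makes resuming mid-training trivial. The arithmetic, with illustrative numbers:

```python
steps_per_epoch = 1000
epochs = 5
latest_step = 2500                               # e.g. restored from a checkpoint
current_epoch = latest_step // steps_per_epoch   # -> 2, resume in epoch 2
total_steps = epochs * steps_per_epoch           # loop runs while latest_step < total_steps
```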
@@ -402,7 +422,7 @@ class SimpleTrainer:
             if avg_loss < self.best_loss:
                 self.best_loss = avg_loss
                 self.best_state = train_state
-                self.save(current_epoch)
+                self.save(current_epoch, current_step)

             if process_index == 0:
                 if self.wandb is not None:
@@ -415,4 +435,4 @@ class SimpleTrainer:
                 }, step=current_step)
                 print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
         self.save(epochs)
-        return self.state
+        return self.state
{flaxdiff-0.1.7.dist-info → flaxdiff-0.1.9.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.7
+Version: 0.1.9
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
@@ -234,6 +234,23 @@ plotImages(samples, dpi=300)

 ## Gallery

+### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
+Model trained on Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M on TPU-v4-32:
+`a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful landscape with a river with mountains, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a beautiful forest with a river and sunlight, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden, a big mansion with a garden`
+
+**Params**:
+`Dataset: Laion-Aesthetics 12M + CC12M + MS COCO + 1M aesthetic 6+ subset of COYO-700M`
+`Batch size: 256`
+`Image Size: 128`
+`Training Epochs: 5`
+`Steps per epoch: 74573`
+`Model Configurations: feature_depths=[128, 256, 512, 1024]`
+
+`Training Noise Schedule: EDMNoiseScheduler`
+`Inference Noise Schedule: KarrasEDMPredictor`
+
+![EulerA with CFG](images/medium_epoch5.png)
+
 ### Images generated by Euler Ancestral Sampler in 200 Steps [text2image with CFG]
 Images generated by the following prompts using classifier free guidance with guidance factor = 2:
 `'water tulip, a water lily, a water lily, a water lily, a photo of a marigold, a water lily, a water lily, a photo of a lotus, a photo of a lotus, a photo of a lotus, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose, a photo of a rose'`
{flaxdiff-0.1.7.dist-info → flaxdiff-0.1.9.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=
-flaxdiff/models/common.py,sha256=
+flaxdiff/models/attention.py,sha256=YyVI3dTAMB8cS8VWHgtIigr2YY-MYfFTlaNDfjNJOCk,12596
+flaxdiff/models/common.py,sha256=nh32GIfgT_vVab4DEFiRAns5WGKbv6L5xNhzzfKKyBs,10590
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=
-flaxdiff/models/simple_vit.py,sha256=
+flaxdiff/models/simple_unet.py,sha256=_elSWNaB3EG-DwnrdIPVPF4OkU0xaa2IJk6OVITOwWM,9691
+flaxdiff/models/simple_vit.py,sha256=xD23i1b7WEvoH4tUMsLyCe9ebDcv-PpaV0Nso38Jlb8,3887
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
-flaxdiff/models/autoencoder/diffusers.py,sha256=
+flaxdiff/models/autoencoder/diffusers.py,sha256=l4teVksXd9XCCQWcVn9eB820xJyLT8hpg1CXQ_aHZ6M,3611
 flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
 flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
 flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
@@ -30,9 +30,9 @@ flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k
 flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
 flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
-flaxdiff/trainer/diffusion_trainer.py,sha256=
-flaxdiff/trainer/simple_trainer.py,sha256=
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff/trainer/diffusion_trainer.py,sha256=z-ERdPt8mB6drXXlLjbGpbPreDIQlGmJFPRJhaoEZ1M,9242
+flaxdiff/trainer/simple_trainer.py,sha256=Dv2F7e2PQS_2b972iRr66odCcPPdJ9cZAD5t9LguOiw,18110
+flaxdiff-0.1.9.dist-info/METADATA,sha256=HhZlM5rBZrOSpNhS8KpeBCoXSmbsHy8ZAKY7gj10P0c,22082
+flaxdiff-0.1.9.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.9.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.9.dist-info/RECORD,,
{flaxdiff-0.1.7.dist-info → flaxdiff-0.1.9.dist-info}/WHEEL
File without changes
{flaxdiff-0.1.7.dist-info → flaxdiff-0.1.9.dist-info}/top_level.txt
File without changes