flaxdiff-0.1.9.tar.gz → flaxdiff-0.1.11.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/PKG-INFO +1 -1
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/attention.py +17 -8
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/common.py +5 -3
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/simple_unet.py +28 -16
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/diffusion_trainer.py +4 -2
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/simple_trainer.py +6 -3
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/setup.py +1 -1
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/README.md +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.9 → flaxdiff-0.1.11}/setup.cfg +0 -0
flaxdiff/models/attention.py

@@ -23,6 +23,7 @@ class EfficientAttention(nn.Module):
     precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -114,6 +115,7 @@ class NormalAttention(nn.Module):
     precision: PrecisionLike = None
     use_bias: bool = True
     kernel_init: Callable = lambda : kernel_init(1.0)
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         inner_dim = self.dim_head * self.heads
@@ -157,7 +159,7 @@ class NormalAttention(nn.Module):
 
         hidden_states = nn.dot_product_attention(
             query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
             deterministic=True
         )
         proj = self.proj_attn(hidden_states)
@@ -237,6 +239,7 @@ class BasicTransformerBlock(nn.Module):
     use_flash_attention:bool = False
     use_cross_only:bool = False
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
 
     def setup(self):
         if self.use_flash_attention:
@@ -252,7 +255,8 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
         self.attention2 = attenBlock(
             query_dim=self.query_dim,
@@ -262,7 +266,8 @@ class BasicTransformerBlock(nn.Module):
             precision=self.precision,
             use_bias=self.use_bias,
             dtype=self.dtype,
-            kernel_init=self.kernel_init
+            kernel_init=self.kernel_init,
+            force_fp32_for_softmax=self.force_fp32_for_softmax
         )
 
         self.ff = FlaxFeedForward(dim=self.query_dim)
@@ -296,6 +301,8 @@ class TransformerBlock(nn.Module):
     use_flash_attention:bool = False
     use_self_and_cross:bool = True
     only_pure_attention:bool = False
+    force_fp32_for_softmax: bool = True
+    kernel_init: Callable = lambda : kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, context=None):
@@ -306,12 +313,12 @@ class TransformerBlock(nn.Module):
         if self.use_linear_attention:
             projected_x = nn.Dense(features=inner_dim,
                 use_bias=False, precision=self.precision,
-                kernel_init=kernel_init(
+                kernel_init=self.kernel_init(),
                 dtype=self.dtype, name=f'project_in')(normed_x)
         else:
             projected_x = nn.Conv(
                 features=inner_dim, kernel_size=(1, 1),
-                kernel_init=kernel_init(
+                kernel_init=self.kernel_init(),
                 strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                 precision=self.precision, name=f'project_in_conv',
             )(normed_x)
@@ -331,19 +338,21 @@ class TransformerBlock(nn.Module):
             dtype=self.dtype,
             use_flash_attention=self.use_flash_attention,
             use_cross_only=(not self.use_self_and_cross),
-            only_pure_attention=self.only_pure_attention
+            only_pure_attention=self.only_pure_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            kernel_init=self.kernel_init
         )(projected_x, context)
 
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=C, precision=self.precision,
                     dtype=self.dtype, use_bias=False,
-                    kernel_init=kernel_init(
+                    kernel_init=self.kernel_init(),
                     name=f'project_out')(projected_x)
             else:
                 projected_x = nn.Conv(
                     features=C, kernel_size=(1, 1),
-                    kernel_init=kernel_init(
+                    kernel_init=self.kernel_init(),
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
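The attention.py changes above thread a new force_fp32_for_softmax flag through every attention module down to Flax's nn.dot_product_attention. A minimal sketch of what the flag controls, assuming a recent Flax release in which nn.dot_product_attention accepts this argument (the tensors and shapes below are illustrative, not flaxdiff code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Toy bfloat16 query/key/value tensors: (batch, length, heads, head_dim)
q = jax.random.normal(jax.random.PRNGKey(0), (1, 16, 4, 32), dtype=jnp.bfloat16)
k = jax.random.normal(jax.random.PRNGKey(1), (1, 16, 4, 32), dtype=jnp.bfloat16)
v = jax.random.normal(jax.random.PRNGKey(2), (1, 16, 4, 32), dtype=jnp.bfloat16)

# Softmax computed in the low-precision compute dtype
out_lowp = nn.dot_product_attention(q, k, v, dtype=jnp.bfloat16)

# Softmax promoted to float32 for numerical stability, as the new flag requests
out_fp32 = nn.dot_product_attention(
    q, k, v, dtype=jnp.bfloat16, force_fp32_for_softmax=True)

# Compare the two results; only the softmax precision differs
print(jnp.max(jnp.abs(out_fp32.astype(jnp.float32) - out_lowp.astype(jnp.float32))))
```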
flaxdiff/models/common.py

@@ -267,15 +267,17 @@ class ResidualBlock(nn.Module):
     kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+    named_norms:bool=False
 
     def setup(self):
         if self.norm_groups > 0:
             norm = partial(nn.GroupNorm, self.norm_groups)
+            self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
+            self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
         else:
             norm = partial(nn.RMSNorm, 1e-5)
-
-
-        self.norm2 = norm()
+            self.norm1 = norm()
+            self.norm2 = norm()
 
     @nn.compact
     def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
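The new named_norms switch exists because Flax derives parameter-tree keys from submodule names: a module created in setup() is keyed by its attribute name unless an explicit name is given, so checkpoints saved when the norms were auto-named GroupNorm_0/GroupNorm_1 only restore if those names are reproduced. A small illustration of the difference (not flaxdiff code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class NamedNorm(nn.Module):
    def setup(self):
        # explicit name -> parameter key "GroupNorm_0", matching older checkpoints
        self.norm1 = nn.GroupNorm(num_groups=2, name="GroupNorm_0")

    def __call__(self, x):
        return self.norm1(x)

class UnnamedNorm(nn.Module):
    def setup(self):
        # no explicit name -> Flax keys the parameters by the attribute name "norm1"
        self.norm1 = nn.GroupNorm(num_groups=2)

    def __call__(self, x):
        return self.norm1(x)

x = jnp.ones((1, 8, 8, 4))
print(list(NamedNorm().init(jax.random.PRNGKey(0), x)["params"].keys()))    # ['GroupNorm_0']
print(list(UnnamedNorm().init(jax.random.PRNGKey(0), x)["params"].keys()))  # ['norm1']
```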
flaxdiff/models/simple_unet.py

@@ -19,15 +19,16 @@ class Unet(nn.Module):
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
+    named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
+    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)
 
     def setup(self):
         if self.norm_groups > 0:
             norm = partial(nn.GroupNorm, self.norm_groups)
+            self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
         else:
             norm = partial(nn.RMSNorm, 1e-5)
-
-            # self.last_up_norm = norm()
-        self.conv_out_norm = norm()
+            self.conv_out_norm = norm()
 
     @nn.compact
     def __call__(self, x, temb, textcontext):
@@ -49,7 +50,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -64,13 +65,14 @@ class Unet(nn.Module):
                     down_conv_type,
                     name=f"down_{i}_residual_{j}",
                     features=dim_in,
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(1.0),
                     kernel_size=(3, 3),
                     strides=(1, 1),
                     activation=self.activation,
                     norm_groups=self.norm_groups,
                     dtype=self.dtype,
-                    precision=self.precision
+                    precision=self.precision,
+                    named_norms=self.named_norms
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -80,6 +82,8 @@ class Unet(nn.Module):
                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
+                        force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                        kernel_init=self.kernel_init(1.0),
                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                 downs.append(x)
@@ -102,13 +106,14 @@ class Unet(nn.Module):
                 middle_conv_type,
                 name=f"middle_res1_{j}",
                 features=middle_dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
                 norm_groups=self.norm_groups,
                 dtype=self.dtype,
-                precision=self.precision
+                precision=self.precision,
+                named_norms=self.named_norms
             )(x, temb)
             if middle_attention is not None and j == self.num_middle_res_blocks - 1:   # Apply attention only on the last block
                 x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
@@ -119,18 +124,21 @@ class Unet(nn.Module):
                     use_self_and_cross=False,
                     precision=middle_attention.get("precision", self.precision),
                     only_pure_attention=middle_attention.get("only_pure_attention", True),
+                    force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                    kernel_init=self.kernel_init(1.0),
                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
                 features=middle_dim_out,
-                kernel_init=kernel_init(1.0),
+                kernel_init=self.kernel_init(1.0),
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 activation=self.activation,
                 norm_groups=self.norm_groups,
                 dtype=self.dtype,
-                precision=self.precision
+                precision=self.precision,
+                named_norms=self.named_norms
             )(x, temb)
 
         # Upscaling Blocks
@@ -145,13 +153,14 @@ class Unet(nn.Module):
                     up_conv_type,# if j == 0 else "separable",
                     name=f"up_{i}_residual_{j}",
                     features=dim_out,
-                    kernel_init=kernel_init(1.0),
+                    kernel_init=self.kernel_init(1.0),
                     kernel_size=kernel_size,
                     strides=(1, 1),
                     activation=self.activation,
                     norm_groups=self.norm_groups,
                     dtype=self.dtype,
-                    precision=self.precision
+                    precision=self.precision,
+                    named_norms=self.named_norms
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -161,6 +170,8 @@ class Unet(nn.Module):
                         use_self_and_cross=attention_config.get("use_self_and_cross", True),
                         precision=attention_config.get("precision", self.precision),
                         only_pure_attention=attention_config.get("only_pure_attention", True),
+                        force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                        kernel_init=self.kernel_init(1.0),
                         name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
@@ -179,7 +190,7 @@ class Unet(nn.Module):
             features=self.feature_depths[0],
             kernel_size=(3, 3),
             strides=(1, 1),
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
            dtype=self.dtype,
             precision=self.precision
         )(x)
@@ -190,13 +201,14 @@ class Unet(nn.Module):
             conv_type,
             name="final_residual",
             features=self.feature_depths[0],
-            kernel_init=kernel_init(1.0),
+            kernel_init=self.kernel_init(1.0),
             kernel_size=(3,3),
             strides=(1, 1),
             activation=self.activation,
             norm_groups=self.norm_groups,
             dtype=self.dtype,
-            precision=self.precision
+            precision=self.precision,
+            named_norms=self.named_norms
         )(x, temb)
 
         x = self.conv_out_norm(x)
@@ -208,7 +220,7 @@ class Unet(nn.Module):
             kernel_size=(3, 3),
             strides=(1, 1),
             # activation=jax.nn.mish
-            kernel_init=kernel_init(0.0),
+            kernel_init=self.kernel_init(0.0),
             dtype=self.dtype,
             precision=self.precision
         )(x)
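The Unet now also stores its kernel initializer as a configurable factory (kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)) and calls it per layer as self.kernel_init(1.0) or self.kernel_init(0.0). An illustrative sketch of that factory pattern; the kernel_init body below is an assumption for demonstration, not flaxdiff's actual implementation:

```python
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn

def kernel_init(scale, dtype=jnp.float32):
    # Hypothetical factory: returns a Flax initializer whose variance is set by `scale`.
    # The real flaxdiff kernel_init may use a different initializer or arguments.
    return nn.initializers.variance_scaling(
        scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)

class TinyBlock(nn.Module):
    # The factory is stored on the module; dtype is fixed via partial, scale per call site.
    kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)

    @nn.compact
    def __call__(self, x):
        return nn.Conv(features=8, kernel_size=(3, 3),
                       kernel_init=self.kernel_init(1.0))(x)

params = TinyBlock().init(jax.random.PRNGKey(0), jnp.ones((1, 16, 16, 3)))
```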
flaxdiff/trainer/diffusion_trainer.py

@@ -16,6 +16,7 @@ from flaxdiff.utils import RandomMarkovState
 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
 
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+from flax.training.dynamic_scale import DynamicScale
 
 class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
@@ -83,7 +84,8 @@ class DiffusionTrainer(SimpleTrainer):
             new_state = existing_state
 
         if param_transforms is not None:
-            params = param_transforms(params)
+            new_state['params'] = param_transforms(new_state['params'])
+            new_state['ema_params'] = param_transforms(new_state['ema_params'])
 
         state = TrainState.create(
             apply_fn=model.apply,
@@ -92,7 +94,7 @@ class DiffusionTrainer(SimpleTrainer):
             tx=optimizer,
             rngs=rngs,
             metrics=Metrics.empty(),
-            dynamic_scale =
+            dynamic_scale = DynamicScale() if use_dynamic_scale else None
         )
 
         if existing_best_state is not None:
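Two behavioural notes on the trainer change above: param_transforms is now applied to both the restored params and ema_params, and the train state gains a DynamicScale when use_dynamic_scale is set. A hypothetical param_transforms callable (not part of flaxdiff) might remap parameter keys so an older checkpoint lines up with the current module layout:

```python
import flax.traverse_util as traverse_util

def remap_old_checkpoint(params):
    # Hypothetical transform: flatten the pytree, rewrite selected path segments,
    # and unflatten. The actual mapping depends on how the checkpoint was saved.
    flat = traverse_util.flatten_dict(params)
    remapped = {
        tuple("conv_out_norm" if part == "GroupNorm_0" else part for part in path): leaf
        for path, leaf in flat.items()
    }
    return traverse_util.unflatten_dict(remapped)

# e.g. passed to the trainer so params and ema_params are rewritten consistently:
# trainer = DiffusionTrainer(..., param_transforms=remap_old_checkpoint)
```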
flaxdiff/trainer/simple_trainer.py

@@ -22,7 +22,7 @@ from jax.experimental.shard_map import shard_map
 from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
 from termcolor import colored
 from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
+from flax.training.dynamic_scale import DynamicScale
 from flaxdiff.utils import RandomMarkovState
 
 PROCESS_COLOR_MAP = {
@@ -68,7 +68,7 @@ class Metrics(metrics.Collection):
 # Define the TrainState
 class SimpleTrainState(train_state.TrainState):
     metrics: Metrics
-    dynamic_scale:
+    dynamic_scale: DynamicScale
 
 class SimpleTrainer:
     state: SimpleTrainState
@@ -177,13 +177,16 @@ class SimpleTrainer:
             params = model.init(subkey, **input_vars)
         else:
             params = existing_state['params']
+
+        if param_transforms is not None:
+            params = param_transforms(params)
 
         state = SimpleTrainState.create(
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
             metrics=Metrics.empty(),
-            dynamic_scale =
+            dynamic_scale = DynamicScale() if use_dynamic_scale else None
         )
         if existing_best_state is not None:
             best_state = state.replace(
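The dynamic_scale field added to SimpleTrainState holds Flax's standard DynamicScale helper for mixed-precision loss scaling. The general usage pattern inside a train step looks roughly like this (a sketch following the Flax documentation, not flaxdiff's actual train step; the loss function and batch layout are placeholders):

```python
import functools

import jax
import jax.numpy as jnp

def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, batch['x'])
        return jnp.mean((preds - batch['y']) ** 2)

    if state.dynamic_scale is not None:
        # DynamicScale scales the loss, unscales the grads, and reports finiteness
        grad_fn = state.dynamic_scale.value_and_grad(loss_fn)
        dynamic_scale, is_finite, loss, grads = grad_fn(state.params)
    else:
        loss, grads = jax.value_and_grad(loss_fn)(state.params)

    new_state = state.apply_gradients(grads=grads)

    if state.dynamic_scale is not None:
        # If any gradient overflowed, keep the previous params and optimizer state;
        # the returned DynamicScale carries the adjusted scale for the next step.
        keep = functools.partial(jnp.where, is_finite)
        new_state = new_state.replace(
            params=jax.tree_util.tree_map(keep, new_state.params, state.params),
            opt_state=jax.tree_util.tree_map(keep, new_state.opt_state, state.opt_state),
            dynamic_scale=dynamic_scale,
        )
    return new_state, loss
```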
setup.py

@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.9',
+    version='0.1.11',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
All other files listed above (+0 -0) are unchanged between 0.1.9 and 0.1.11.