flaxdiff 0.1.21__py3-none-any.whl → 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +33 -20
- flaxdiff/models/common.py +20 -15
- flaxdiff/models/simple_vit.py +83 -62
- {flaxdiff-0.1.21.dist-info → flaxdiff-0.1.23.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.21.dist-info → flaxdiff-0.1.23.dist-info}/RECORD +7 -7
- {flaxdiff-0.1.21.dist-info → flaxdiff-0.1.23.dist-info}/WHEEL +1 -1
- {flaxdiff-0.1.21.dist-info → flaxdiff-0.1.23.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -21,6 +21,7 @@ import urllib
|
|
21
21
|
|
22
22
|
import PIL.Image
|
23
23
|
import cv2
|
24
|
+
import traceback
|
24
25
|
|
25
26
|
USER_AGENT = get_datasets_user_agent()
|
26
27
|
|
@@ -43,7 +44,27 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
43
44
|
return image
|
44
45
|
|
45
46
|
|
46
|
-
def default_image_processor(
|
47
|
+
def default_image_processor(
|
48
|
+
image, image_shape,
|
49
|
+
min_image_shape=(128, 128),
|
50
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
51
|
+
downscale_interpolation=cv2.INTER_AREA,
|
52
|
+
):
|
53
|
+
image = np.array(image)
|
54
|
+
original_height, original_width = image.shape[:2]
|
55
|
+
# check if the image is too small
|
56
|
+
if min(original_height, original_width) < min(min_image_shape):
|
57
|
+
return None, original_height, original_width
|
58
|
+
# check if wrong aspect ratio
|
59
|
+
if max(original_height, original_width) / min(original_height, original_width) > 2.4:
|
60
|
+
return None, original_height, original_width
|
61
|
+
# check if the variance is too low
|
62
|
+
if np.std(image) < 1e-5:
|
63
|
+
return None, original_height, original_width
|
64
|
+
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
65
|
+
downscale = max(original_width, original_height) > max(image_shape)
|
66
|
+
interpolation = downscale_interpolation if downscale else upscale_interpolation
|
67
|
+
|
47
68
|
image = A.longest_max_size(image, max(
|
48
69
|
image_shape), interpolation=interpolation)
|
49
70
|
image = A.pad(
|
@@ -53,7 +74,7 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
|
|
53
74
|
border_mode=cv2.BORDER_CONSTANT,
|
54
75
|
value=[255, 255, 255],
|
55
76
|
)
|
56
|
-
return image
|
77
|
+
return image, original_height, original_width
|
57
78
|
|
58
79
|
|
59
80
|
def map_sample(
|
@@ -72,23 +93,13 @@ def map_sample(
|
|
72
93
|
if image is None:
|
73
94
|
return
|
74
95
|
|
75
|
-
image =
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
if max(original_height, original_width) / min(original_height, original_width) > 2.4:
|
82
|
-
return
|
83
|
-
# check if the variance is too low
|
84
|
-
if np.std(image) < 1e-4:
|
96
|
+
image, original_height, original_width = image_processor(
|
97
|
+
image, image_shape, min_image_shape=min_image_shape,
|
98
|
+
upscale_interpolation=upscale_interpolation,
|
99
|
+
downscale_interpolation=downscale_interpolation,)
|
100
|
+
|
101
|
+
if image is None:
|
85
102
|
return
|
86
|
-
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
87
|
-
downscale = max(original_width, original_height) > max(image_shape)
|
88
|
-
interpolation = downscale_interpolation if downscale else upscale_interpolation
|
89
|
-
|
90
|
-
image = image_processor(
|
91
|
-
image, image_shape, interpolation=interpolation)
|
92
103
|
|
93
104
|
data_queue.put({
|
94
105
|
"url": url,
|
@@ -98,7 +109,8 @@ def map_sample(
|
|
98
109
|
"original_width": original_width,
|
99
110
|
})
|
100
111
|
except Exception as e:
|
101
|
-
print(f"Error
|
112
|
+
print(f"Error maping sample {url}", e)
|
113
|
+
traceback.print_exc()
|
102
114
|
# error_queue.put_nowait({
|
103
115
|
# "url": url,
|
104
116
|
# "caption": caption,
|
@@ -122,7 +134,8 @@ def map_batch(
|
|
122
134
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
123
135
|
executor.map(map_sample_fn, batch["url"], batch['caption'])
|
124
136
|
except Exception as e:
|
125
|
-
print(f"Error
|
137
|
+
print(f"Error maping batch", e)
|
138
|
+
traceback.print_exc()
|
126
139
|
# error_queue.put_nowait({
|
127
140
|
# "batch": batch,
|
128
141
|
# "error": str(e)
|
flaxdiff/models/common.py
CHANGED
@@ -108,12 +108,13 @@ class FourierEmbedding(nn.Module):
|
|
108
108
|
class TimeProjection(nn.Module):
|
109
109
|
features:int
|
110
110
|
activation:Callable=jax.nn.gelu
|
111
|
+
kernel_init:Callable=partial(kernel_init, 1.0)
|
111
112
|
|
112
113
|
@nn.compact
|
113
114
|
def __call__(self, x):
|
114
|
-
x = nn.DenseGeneral(self.features, kernel_init=kernel_init(
|
115
|
+
x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x)
|
115
116
|
x = self.activation(x)
|
116
|
-
x = nn.DenseGeneral(self.features, kernel_init=kernel_init(
|
117
|
+
x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x)
|
117
118
|
x = self.activation(x)
|
118
119
|
return x
|
119
120
|
|
@@ -122,7 +123,7 @@ class SeparableConv(nn.Module):
|
|
122
123
|
kernel_size:tuple=(3, 3)
|
123
124
|
strides:tuple=(1, 1)
|
124
125
|
use_bias:bool=False
|
125
|
-
kernel_init:Callable=kernel_init
|
126
|
+
kernel_init:Callable=partial(kernel_init, 1.0)
|
126
127
|
padding:str="SAME"
|
127
128
|
dtype: Optional[Dtype] = None
|
128
129
|
precision: PrecisionLike = None
|
@@ -132,7 +133,7 @@ class SeparableConv(nn.Module):
|
|
132
133
|
in_features = x.shape[-1]
|
133
134
|
depthwise = nn.Conv(
|
134
135
|
features=in_features, kernel_size=self.kernel_size,
|
135
|
-
strides=self.strides, kernel_init=self.kernel_init,
|
136
|
+
strides=self.strides, kernel_init=self.kernel_init(),
|
136
137
|
feature_group_count=in_features, use_bias=self.use_bias,
|
137
138
|
padding=self.padding,
|
138
139
|
dtype=self.dtype,
|
@@ -140,7 +141,7 @@ class SeparableConv(nn.Module):
|
|
140
141
|
)(x)
|
141
142
|
pointwise = nn.Conv(
|
142
143
|
features=self.features, kernel_size=(1, 1),
|
143
|
-
strides=(1, 1), kernel_init=self.kernel_init,
|
144
|
+
strides=(1, 1), kernel_init=self.kernel_init(),
|
144
145
|
use_bias=self.use_bias,
|
145
146
|
dtype=self.dtype,
|
146
147
|
precision=self.precision
|
@@ -152,7 +153,7 @@ class ConvLayer(nn.Module):
|
|
152
153
|
features:int
|
153
154
|
kernel_size:tuple=(3, 3)
|
154
155
|
strides:tuple=(1, 1)
|
155
|
-
kernel_init:Callable=kernel_init
|
156
|
+
kernel_init:Callable=partial(kernel_init, 1.0)
|
156
157
|
dtype: Optional[Dtype] = None
|
157
158
|
precision: PrecisionLike = None
|
158
159
|
|
@@ -163,7 +164,7 @@ class ConvLayer(nn.Module):
|
|
163
164
|
features=self.features,
|
164
165
|
kernel_size=self.kernel_size,
|
165
166
|
strides=self.strides,
|
166
|
-
kernel_init=self.kernel_init,
|
167
|
+
kernel_init=self.kernel_init(),
|
167
168
|
dtype=self.dtype,
|
168
169
|
precision=self.precision
|
169
170
|
)
|
@@ -182,7 +183,7 @@ class ConvLayer(nn.Module):
|
|
182
183
|
features=self.features,
|
183
184
|
kernel_size=self.kernel_size,
|
184
185
|
strides=self.strides,
|
185
|
-
kernel_init=self.kernel_init,
|
186
|
+
kernel_init=self.kernel_init(),
|
186
187
|
dtype=self.dtype,
|
187
188
|
precision=self.precision
|
188
189
|
)
|
@@ -191,7 +192,7 @@ class ConvLayer(nn.Module):
|
|
191
192
|
features=self.features,
|
192
193
|
kernel_size=self.kernel_size,
|
193
194
|
strides=self.strides,
|
194
|
-
kernel_init=self.kernel_init,
|
195
|
+
kernel_init=self.kernel_init(),
|
195
196
|
dtype=self.dtype,
|
196
197
|
precision=self.precision
|
197
198
|
)
|
@@ -205,6 +206,7 @@ class Upsample(nn.Module):
|
|
205
206
|
activation:Callable=jax.nn.swish
|
206
207
|
dtype: Optional[Dtype] = None
|
207
208
|
precision: PrecisionLike = None
|
209
|
+
kernel_init:Callable=partial(kernel_init, 1.0)
|
208
210
|
|
209
211
|
@nn.compact
|
210
212
|
def __call__(self, x, residual=None):
|
@@ -218,7 +220,8 @@ class Upsample(nn.Module):
|
|
218
220
|
kernel_size=(3, 3),
|
219
221
|
strides=(1, 1),
|
220
222
|
dtype=self.dtype,
|
221
|
-
precision=self.precision
|
223
|
+
precision=self.precision,
|
224
|
+
kernel_init=self.kernel_init()
|
222
225
|
)(out)
|
223
226
|
if residual is not None:
|
224
227
|
out = jnp.concatenate([out, residual], axis=-1)
|
@@ -230,6 +233,7 @@ class Downsample(nn.Module):
|
|
230
233
|
activation:Callable=jax.nn.swish
|
231
234
|
dtype: Optional[Dtype] = None
|
232
235
|
precision: PrecisionLike = None
|
236
|
+
kernel_init:Callable=partial(kernel_init, 1.0)
|
233
237
|
|
234
238
|
@nn.compact
|
235
239
|
def __call__(self, x, residual=None):
|
@@ -239,7 +243,8 @@ class Downsample(nn.Module):
|
|
239
243
|
kernel_size=(3, 3),
|
240
244
|
strides=(2, 2),
|
241
245
|
dtype=self.dtype,
|
242
|
-
precision=self.precision
|
246
|
+
precision=self.precision,
|
247
|
+
kernel_init=self.kernel_init()
|
243
248
|
)(x)
|
244
249
|
if residual is not None:
|
245
250
|
if residual.shape[1] > out.shape[1]:
|
@@ -264,7 +269,7 @@ class ResidualBlock(nn.Module):
|
|
264
269
|
direction:str=None
|
265
270
|
res:int=2
|
266
271
|
norm_groups:int=8
|
267
|
-
kernel_init:Callable=kernel_init
|
272
|
+
kernel_init:Callable=partial(kernel_init, 1.0)
|
268
273
|
dtype: Optional[Dtype] = None
|
269
274
|
precision: PrecisionLike = None
|
270
275
|
named_norms:bool=False
|
@@ -291,7 +296,7 @@ class ResidualBlock(nn.Module):
|
|
291
296
|
features=self.features,
|
292
297
|
kernel_size=self.kernel_size,
|
293
298
|
strides=self.strides,
|
294
|
-
kernel_init=self.kernel_init,
|
299
|
+
kernel_init=self.kernel_init(),
|
295
300
|
name="conv1",
|
296
301
|
dtype=self.dtype,
|
297
302
|
precision=self.precision
|
@@ -316,7 +321,7 @@ class ResidualBlock(nn.Module):
|
|
316
321
|
features=self.features,
|
317
322
|
kernel_size=self.kernel_size,
|
318
323
|
strides=self.strides,
|
319
|
-
kernel_init=self.kernel_init,
|
324
|
+
kernel_init=self.kernel_init(),
|
320
325
|
name="conv2",
|
321
326
|
dtype=self.dtype,
|
322
327
|
precision=self.precision
|
@@ -328,7 +333,7 @@ class ResidualBlock(nn.Module):
|
|
328
333
|
features=self.features,
|
329
334
|
kernel_size=(1, 1),
|
330
335
|
strides=1,
|
331
|
-
kernel_init=self.kernel_init,
|
336
|
+
kernel_init=self.kernel_init(),
|
332
337
|
name="residual_conv",
|
333
338
|
dtype=self.dtype,
|
334
339
|
precision=self.precision
|
flaxdiff/models/simple_vit.py
CHANGED
@@ -3,9 +3,20 @@
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
5
|
from flax import linen as nn
|
6
|
-
from typing import Callable, Any
|
6
|
+
from typing import Callable, Any, Optional, Tuple
|
7
7
|
from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
|
8
8
|
from .attention import TransformerBlock
|
9
|
+
from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
|
10
|
+
import einops
|
11
|
+
from flax.typing import Dtype, PrecisionLike
|
12
|
+
from functools import partial
|
13
|
+
|
14
|
+
def unpatchify(x, channels=3):
|
15
|
+
patch_size = int((x.shape[2] // channels) ** 0.5)
|
16
|
+
h = w = int(x.shape[1] ** .5)
|
17
|
+
assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
|
18
|
+
x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
|
19
|
+
return x
|
9
20
|
|
10
21
|
class PatchEmbedding(nn.Module):
|
11
22
|
patch_size: int
|
@@ -37,39 +48,28 @@ class PositionalEncoding(nn.Module):
|
|
37
48
|
(1, self.max_len, self.embedding_dim))
|
38
49
|
return x + pe[:, :x.shape[1], :]
|
39
50
|
|
40
|
-
class
|
41
|
-
|
42
|
-
num_heads: int
|
43
|
-
dropout_rate: float = 0.1
|
44
|
-
dtype: Any = jnp.float32
|
45
|
-
precision: Any = jax.lax.Precision.HIGH
|
46
|
-
use_projection: bool = False
|
47
|
-
|
48
|
-
@nn.compact
|
49
|
-
def __call__(self, x, context=None):
|
50
|
-
for _ in range(self.num_layers):
|
51
|
-
x = TransformerBlock(
|
52
|
-
heads=self.num_heads,
|
53
|
-
dim_head=x.shape[-1] // self.num_heads,
|
54
|
-
dropout_rate=self.dropout_rate,
|
55
|
-
dtype=self.dtype,
|
56
|
-
precision=self.precision,
|
57
|
-
use_self_and_cross=True,
|
58
|
-
use_projection=self.use_projection,
|
59
|
-
)(x, context)
|
60
|
-
return x
|
61
|
-
|
62
|
-
class VisionTransformer(nn.Module):
|
51
|
+
class UViT(nn.Module):
|
52
|
+
output_channels:int=3
|
63
53
|
patch_size: int = 16
|
64
|
-
|
54
|
+
emb_features:int=768,
|
65
55
|
num_layers: int = 12
|
66
56
|
num_heads: int = 12
|
67
|
-
emb_features: int = 256
|
68
57
|
dropout_rate: float = 0.1
|
69
58
|
dtype: Any = jnp.float32
|
70
59
|
precision: Any = jax.lax.Precision.HIGH
|
71
60
|
use_projection: bool = False
|
72
|
-
|
61
|
+
activation:Callable = jax.nn.swish
|
62
|
+
norm_groups:int=8
|
63
|
+
dtype: Optional[Dtype] = None
|
64
|
+
precision: PrecisionLike = None
|
65
|
+
kernel_init: Callable = partial(kernel_init, 1.0)
|
66
|
+
|
67
|
+
def setup(self):
|
68
|
+
if self.norm_groups > 0:
|
69
|
+
self.norm = partial(nn.GroupNorm, self.norm_groups)
|
70
|
+
else:
|
71
|
+
self.norm = partial(nn.RMSNorm, 1e-5)
|
72
|
+
|
73
73
|
@nn.compact
|
74
74
|
def __call__(self, x, temb, textcontext=None):
|
75
75
|
# Time embedding
|
@@ -77,44 +77,65 @@ class VisionTransformer(nn.Module):
|
|
77
77
|
temb = TimeProjection(features=self.emb_features)(temb)
|
78
78
|
|
79
79
|
# Patch embedding
|
80
|
-
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.
|
80
|
+
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
|
81
81
|
dtype=self.dtype, precision=self.precision)(x)
|
82
|
+
num_patches = x.shape[1]
|
82
83
|
|
83
|
-
|
84
|
-
|
84
|
+
context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
|
85
|
+
dtype=self.dtype, precision=self.precision)(textcontext)
|
86
|
+
num_text_tokens = textcontext.shape[1]
|
87
|
+
|
88
|
+
# print(f'Shape of x after patch embedding: {x.shape}, numPatches: {num_patches}, temb: {temb.shape}, context_emb: {context_emb.shape}')
|
85
89
|
|
86
|
-
num_patches = x.shape[1]
|
87
|
-
|
88
90
|
# Add time embedding
|
89
91
|
temb = jnp.expand_dims(temb, axis=1)
|
90
|
-
x = jnp.concatenate([x, temb], axis=1)
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
92
|
+
x = jnp.concatenate([x, temb, context_emb], axis=1)
|
93
|
+
# print(f'Shape of x after time embedding: {x.shape}')
|
94
|
+
|
95
|
+
# Add positional encoding
|
96
|
+
x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)
|
97
|
+
|
98
|
+
# print(f'Shape of x after positional encoding: {x.shape}')
|
99
|
+
|
100
|
+
skips = []
|
101
|
+
# In blocks
|
102
|
+
for i in range(self.num_layers // 2):
|
103
|
+
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
104
|
+
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
105
|
+
use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
|
106
|
+
only_pure_attention=False,
|
107
|
+
kernel_init=self.kernel_init())(x)
|
108
|
+
skips.append(x)
|
109
|
+
|
110
|
+
# Middle block
|
111
|
+
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
112
|
+
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
113
|
+
use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True,
|
114
|
+
only_pure_attention=False,
|
115
|
+
kernel_init=self.kernel_init())(x)
|
116
|
+
|
117
|
+
# # Out blocks
|
118
|
+
for i in range(self.num_layers // 2):
|
119
|
+
skip = jnp.concatenate([x, skips.pop()], axis=-1)
|
120
|
+
skip = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
|
121
|
+
dtype=self.dtype, precision=self.precision)(skip)
|
122
|
+
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
|
123
|
+
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
|
124
|
+
use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
|
125
|
+
only_pure_attention=False,
|
126
|
+
kernel_init=self.kernel_init())(skip)
|
127
|
+
|
128
|
+
# print(f'Shape of x after transformer blocks: {x.shape}')
|
129
|
+
x = self.norm()(x)
|
130
|
+
|
131
|
+
# print(f'Shape of x after norm: {x.shape}')
|
132
|
+
|
133
|
+
patch_dim = self.patch_size ** 2 * self.output_channels
|
134
|
+
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
|
135
|
+
# print(f'Shape of x after patch dense layer: {x.shape}, patch_dim: {patch_dim}')
|
136
|
+
x = x[:, 1 + num_text_tokens:, :]
|
137
|
+
x = unpatchify(x, channels=self.output_channels)
|
138
|
+
# print(f'Shape of x after final dense layer: {x.shape}')
|
139
|
+
x = nn.Dense(features=self.output_channels, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
|
140
|
+
|
120
141
|
return x
|
@@ -1,13 +1,13 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=LIK_O1C3yDPvvAEOWvsJrVeBopVqjg2IOMTbiSIvH6M,11025
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
|
-
flaxdiff/models/common.py,sha256=
|
7
|
+
flaxdiff/models/common.py,sha256=lzfOHB-7Bjx83ZZzywXy2mjhwP2UOMKR11vdVSKvsCo,11068
|
8
8
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
9
9
|
flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
|
10
|
-
flaxdiff/models/simple_vit.py,sha256=
|
10
|
+
flaxdiff/models/simple_vit.py,sha256=tXO0WIozq2C4j2GqphSM9mMAYEwj9fKr7rfm4G6vf4A,6403
|
11
11
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
12
12
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
|
13
13
|
flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.23.dist-info/METADATA,sha256=iTgk4DY-kuALF-Y1U-a433c7pYRhSy-bXPQrtXh6d54,22083
|
38
|
+
flaxdiff-0.1.23.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
39
|
+
flaxdiff-0.1.23.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.23.dist-info/RECORD,,
|
File without changes
|