diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
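
Version 0.18 is dominated by new pipelines: consistency models, Kandinsky 2.2, Shap-E, Stable Diffusion XL, LDM3D, paradigms, and text-to-video img2img, plus parallel DDIM/DDPM schedulers. Since the file reproduced below is the new Shap-E renderer, here is a minimal, hedged sketch of how the accompanying ShapEPipeline is typically driven; the `openai/shap-e` checkpoint name and the exact argument values are assumptions for illustration, not taken from this diff:

    import torch
    from diffusers import ShapEPipeline          # new in 0.18 (see pipelines/shap_e above)
    from diffusers.utils import export_to_gif    # helper shipped alongside the Shap-E pipelines

    # Assumed Hub checkpoint; the renderer below ships as a component of this pipeline.
    pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16).to("cuda")

    # frame_size is forwarded to ShapERenderer.decode(size=...) to set the rendered resolution.
    views = pipe("a shark", guidance_scale=15.0, num_inference_steps=64, frame_size=64).images
    export_to_gif(views[0], "shark_3d.gif")      # each prompt yields a list of panned views
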
diffusers/pipelines/shap_e/renderer.py (new file, +709 lines)
@@ -0,0 +1,709 @@
+ # Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...models import ModelMixin
+ from ...utils import BaseOutput
+ from .camera import create_pan_cameras
+
+
+ def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
+     r"""
+     Sample from the given discrete probability distribution with replacement.
+
+     The i-th bin is assumed to have mass pmf[i].
+
+     Args:
+         pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all()
+         n_samples: number of samples
+
+     Return:
+         indices sampled with replacement
+     """
+
+     *shape, support_size, last_dim = pmf.shape
+     assert last_dim == 1
+
+     cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)
+     inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))
+
+     return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)
+
+
+ def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
+     """
+     Concatenate x and its positional encodings, following NeRF.
+
+     Reference: https://arxiv.org/pdf/2210.04628.pdf
+     """
+     if min_deg == max_deg:
+         return x
+
+     scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device)
+     *shape, dim = x.shape
+     xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
+     assert xb.shape[-1] == dim * (max_deg - min_deg)
+     emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()
+     return torch.cat([x, emb], dim=-1)
+
+
+ def encode_position(position):
+     return posenc_nerf(position, min_deg=0, max_deg=15)
+
+
+ def encode_direction(position, direction=None):
+     if direction is None:
+         return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
+     else:
+         return posenc_nerf(direction, min_deg=0, max_deg=8)
+
+
+ def _sanitize_name(x: str) -> str:
+     return x.replace(".", "__")
+
+
+ def integrate_samples(volume_range, ts, density, channels):
+     r"""
+     Function integrating the model output.
+
+     Args:
+         volume_range: Specifies the integral range [t0, t1]
+         ts: timesteps
+         density: torch.Tensor [batch_size, *shape, n_samples, 1]
+         channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
+     returns:
+         channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density
+         *transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume
+     )
+     """
+
+     # 1. Calculate the weights
+     _, _, dt = volume_range.partition(ts)
+     ddensity = density * dt
+
+     mass = torch.cumsum(ddensity, dim=-2)
+     transmittance = torch.exp(-mass[..., -1, :])
+
+     alphas = 1.0 - torch.exp(-ddensity)
+     Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
+     # This is the probability of light hitting and reflecting off of
+     # something at depth [..., i, :].
+     weights = alphas * Ts
+
+     # 2. Integrate channels
+     channels = torch.sum(channels * weights, dim=-2)
+
+     return channels, weights, transmittance
+
+
+ class VoidNeRFModel(nn.Module):
+     """
+     Implements the default empty space model where all queries are rendered as background.
+     """
+
+     def __init__(self, background, channel_scale=255.0):
+         super().__init__()
+         background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale)
+
+         self.register_buffer("background", background)
+
+     def forward(self, position):
+         background = self.background[None].to(position.device)
+
+         shape = position.shape[:-1]
+         ones = [1] * (len(shape) - 1)
+         n_channels = background.shape[-1]
+         background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels])
+
+         return background
+
+
+ @dataclass
+ class VolumeRange:
+     t0: torch.Tensor
+     t1: torch.Tensor
+     intersected: torch.Tensor
+
+     def __post_init__(self):
+         assert self.t0.shape == self.t1.shape == self.intersected.shape
+
+     def partition(self, ts):
+         """
+         Partitions t0 and t1 into n_samples intervals.
+
+         Args:
+             ts: [batch_size, *shape, n_samples, 1]
+
+         Return:
+
+             lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size,
+             *shape, n_samples, 1]
+
+         where
+             ts \\in [lower, upper] deltas = upper - lower
+         """
+
+         mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
+         lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
+         upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
+         delta = upper - lower
+         assert lower.shape == upper.shape == delta.shape == ts.shape
+         return lower, upper, delta
+
+
+ class BoundingBoxVolume(nn.Module):
+     """
+     Axis-aligned bounding box defined by the two opposite corners.
+     """
+
+     def __init__(
+         self,
+         *,
+         bbox_min,
+         bbox_max,
+         min_dist: float = 0.0,
+         min_t_range: float = 1e-3,
+     ):
+         """
+         Args:
+             bbox_min: the left/bottommost corner of the bounding box
+             bbox_max: the other corner of the bounding box
+             min_dist: all rays should start at least this distance away from the origin.
+         """
+         super().__init__()
+
+         self.min_dist = min_dist
+         self.min_t_range = min_t_range
+
+         self.bbox_min = torch.tensor(bbox_min)
+         self.bbox_max = torch.tensor(bbox_max)
+         self.bbox = torch.stack([self.bbox_min, self.bbox_max])
+         assert self.bbox.shape == (2, 3)
+         assert min_dist >= 0.0
+         assert min_t_range > 0.0
+
+     def intersect(
+         self,
+         origin: torch.Tensor,
+         direction: torch.Tensor,
+         t0_lower: Optional[torch.Tensor] = None,
+         epsilon=1e-6,
+     ):
+         """
+         Args:
+             origin: [batch_size, *shape, 3]
+             direction: [batch_size, *shape, 3]
+             t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
+             params: Optional meta parameters in case Volume is parametric
+             epsilon: to stabilize calculations
+
+         Return:
+             A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with
+             the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to
+             be on the boundary of the volume.
+         """
+
+         batch_size, *shape, _ = origin.shape
+         ones = [1] * len(shape)
+         bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device)
+
+         def _safe_divide(a, b, epsilon=1e-6):
+             return a / torch.where(b < 0, b - epsilon, b + epsilon)
+
+         ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
+
+         # Cases to think about:
+         #
+         # 1. t1 <= t0: the ray does not pass through the AABB.
+         # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
+         # 3. t0 <= 0 <= t1: the ray starts from inside the BB
+         # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
+         #
+         # 1 and 4 are clearly handled from t0 < t1 below.
+         # Making t0 at least min_dist (>= 0) takes care of 2 and 3.
+         t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
+         t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
+         assert t0.shape == t1.shape == (batch_size, *shape, 1)
+         if t0_lower is not None:
+             assert t0.shape == t0_lower.shape
+             t0 = torch.maximum(t0, t0_lower)
+
+         intersected = t0 + self.min_t_range < t1
+         t0 = torch.where(intersected, t0, torch.zeros_like(t0))
+         t1 = torch.where(intersected, t1, torch.ones_like(t1))
+
+         return VolumeRange(t0=t0, t1=t1, intersected=intersected)
+
+
+ class StratifiedRaySampler(nn.Module):
+     """
+     Instead of fixed intervals, a sample is drawn uniformly at random from each interval.
+     """
+
+     def __init__(self, depth_mode: str = "linear"):
+         """
+         :param depth_mode: linear samples ts linearly in depth. harmonic ensures
+             closer points are sampled more densely.
+         """
+         self.depth_mode = depth_mode
+         assert self.depth_mode in ("linear", "geometric", "harmonic")
+
+     def sample(
+         self,
+         t0: torch.Tensor,
+         t1: torch.Tensor,
+         n_samples: int,
+         epsilon: float = 1e-3,
+     ) -> torch.Tensor:
+         """
+         Args:
+             t0: start time has shape [batch_size, *shape, 1]
+             t1: finish time has shape [batch_size, *shape, 1]
+             n_samples: number of ts to sample
+         Return:
+             sampled ts of shape [batch_size, *shape, n_samples, 1]
+         """
+         ones = [1] * (len(t0.shape) - 1)
+         ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)
+
+         if self.depth_mode == "linear":
+             ts = t0 * (1.0 - ts) + t1 * ts
+         elif self.depth_mode == "geometric":
+             ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
+         elif self.depth_mode == "harmonic":
+             # The original NeRF recommends this interpolation scheme for
+             # spherical scenes, but there could be some weird edge cases when
+             # the observer crosses from the inner to outer volume.
+             ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)
+
+         mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
+         upper = torch.cat([mids, t1], dim=-1)
+         lower = torch.cat([t0, mids], dim=-1)
+         # yiyi notes: add a random seed here for testing, don't forget to remove
+         torch.manual_seed(0)
+         t_rand = torch.rand_like(ts)
+
+         ts = lower + (upper - lower) * t_rand
+         return ts.unsqueeze(-1)
+
+
+ class ImportanceRaySampler(nn.Module):
+     """
+     Given the initial estimate of densities, this samples more from regions/bins expected to have objects.
+     """
+
+     def __init__(
+         self,
+         volume_range: VolumeRange,
+         ts: torch.Tensor,
+         weights: torch.Tensor,
+         blur_pool: bool = False,
+         alpha: float = 1e-5,
+     ):
+         """
+         Args:
+             volume_range: the range in which a ray intersects the given volume.
+             ts: earlier samples from the coarse rendering step
+             weights: discretized version of density * transmittance
+             blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
+             alpha: small value to add to weights.
+         """
+         self.volume_range = volume_range
+         self.ts = ts.clone().detach()
+         self.weights = weights.clone().detach()
+         self.blur_pool = blur_pool
+         self.alpha = alpha
+
+     @torch.no_grad()
+     def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
+         """
+         Args:
+             t0: start time has shape [batch_size, *shape, 1]
+             t1: finish time has shape [batch_size, *shape, 1]
+             n_samples: number of ts to sample
+         Return:
+             sampled ts of shape [batch_size, *shape, n_samples, 1]
+         """
+         lower, upper, _ = self.volume_range.partition(self.ts)
+
+         batch_size, *shape, n_coarse_samples, _ = self.ts.shape
+
+         weights = self.weights
+         if self.blur_pool:
+             padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
+             maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
+             weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
+         weights = weights + self.alpha
+         pmf = weights / weights.sum(dim=-2, keepdim=True)
+         inds = sample_pmf(pmf, n_samples)
+         assert inds.shape == (batch_size, *shape, n_samples, 1)
+         assert (inds >= 0).all() and (inds < n_coarse_samples).all()
+
+         t_rand = torch.rand(inds.shape, device=inds.device)
+         lower_ = torch.gather(lower, -2, inds)
+         upper_ = torch.gather(upper, -2, inds)
+
+         ts = lower_ + (upper_ - lower_) * t_rand
+         ts = torch.sort(ts, dim=-2).values
+         return ts
+
+
+ @dataclass
+ class MLPNeRFModelOutput(BaseOutput):
+     density: torch.Tensor
+     signed_distance: torch.Tensor
+     channels: torch.Tensor
+     ts: torch.Tensor
+
+
+ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         d_hidden: int = 256,
+         n_output: int = 12,
+         n_hidden_layers: int = 6,
+         act_fn: str = "swish",
+         insert_direction_at: int = 4,
+     ):
+         super().__init__()
+
+         # Instantiate the MLP
+
+         # Find out the dimension of encoded position and direction
+         dummy = torch.eye(1, 3)
+         d_posenc_pos = encode_position(position=dummy).shape[-1]
+         d_posenc_dir = encode_direction(position=dummy).shape[-1]
+
+         mlp_widths = [d_hidden] * n_hidden_layers
+         input_widths = [d_posenc_pos] + mlp_widths
+         output_widths = mlp_widths + [n_output]
+
+         if insert_direction_at is not None:
+             input_widths[insert_direction_at] += d_posenc_dir
+
+         self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)])
+
+         if act_fn == "swish":
+             # self.activation = swish
+             # yiyi testing:
+             self.activation = lambda x: F.silu(x)
+         else:
+             raise ValueError(f"Unsupported activation function {act_fn}")
+
+         self.sdf_activation = torch.tanh
+         self.density_activation = torch.nn.functional.relu
+         self.channel_activation = torch.sigmoid
+
+     def map_indices_to_keys(self, output):
+         h_map = {
+             "sdf": (0, 1),
+             "density_coarse": (1, 2),
+             "density_fine": (2, 3),
+             "stf": (3, 6),
+             "nerf_coarse": (6, 9),
+             "nerf_fine": (9, 12),
+         }
+
+         mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()}
+
+         return mapped_output
+
+     def forward(self, *, position, direction, ts, nerf_level="coarse"):
+         h = encode_position(position)
+
+         h_preact = h
+         h_directionless = None
+         for i, layer in enumerate(self.mlp):
+             if i == self.config.insert_direction_at:  # 4 in the config
+                 h_directionless = h_preact
+                 h_direction = encode_direction(position, direction=direction)
+                 h = torch.cat([h, h_direction], dim=-1)
+
+             h = layer(h)
+
+             h_preact = h
+
+             if i < len(self.mlp) - 1:
+                 h = self.activation(h)
+
+         h_final = h
+         if h_directionless is None:
+             h_directionless = h_preact
+
+         activation = self.map_indices_to_keys(h_final)
+
+         if nerf_level == "coarse":
+             h_density = activation["density_coarse"]
+             h_channels = activation["nerf_coarse"]
+         else:
+             h_density = activation["density_fine"]
+             h_channels = activation["nerf_fine"]
+
+         density = self.density_activation(h_density)
+         signed_distance = self.sdf_activation(activation["sdf"])
+         channels = self.channel_activation(h_channels)
+
+         # yiyi notes: I think signed_distance is not used
+         return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts)
+
+
+ class ChannelsProj(nn.Module):
+     def __init__(
+         self,
+         *,
+         vectors: int,
+         channels: int,
+         d_latent: int,
+     ):
+         super().__init__()
+         self.proj = nn.Linear(d_latent, vectors * channels)
+         self.norm = nn.LayerNorm(channels)
+         self.d_latent = d_latent
+         self.vectors = vectors
+         self.channels = channels
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_bvd = x
+         w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
+         b_vc = self.proj.bias.view(1, self.vectors, self.channels)
+         h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
+         h = self.norm(h)
+
+         h = h + b_vc
+         return h
+
+
+ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
+     """
+     project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP).
+
+     For more details, see the original paper:
+     """
+
+     @register_to_config
+     def __init__(
+         self,
+         *,
+         param_names: Tuple[str] = (
+             "nerstf.mlp.0.weight",
+             "nerstf.mlp.1.weight",
+             "nerstf.mlp.2.weight",
+             "nerstf.mlp.3.weight",
+         ),
+         param_shapes: Tuple[Tuple[int]] = (
+             (256, 93),
+             (256, 256),
+             (256, 256),
+             (256, 256),
+         ),
+         d_latent: int = 1024,
+     ):
+         super().__init__()
+
+         # check inputs
+         if len(param_names) != len(param_shapes):
+             raise ValueError("Must provide same number of `param_names` as `param_shapes`")
+         self.projections = nn.ModuleDict({})
+         for k, (vectors, channels) in zip(param_names, param_shapes):
+             self.projections[_sanitize_name(k)] = ChannelsProj(
+                 vectors=vectors,
+                 channels=channels,
+                 d_latent=d_latent,
+             )
+
+     def forward(self, x: torch.Tensor):
+         out = {}
+         start = 0
+         for k, shape in zip(self.config.param_names, self.config.param_shapes):
+             vectors, _ = shape
+             end = start + vectors
+             x_bvd = x[:, start:end]
+             out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
+             start = end
+         return out
+
+
+ class ShapERenderer(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         *,
+         param_names: Tuple[str] = (
+             "nerstf.mlp.0.weight",
+             "nerstf.mlp.1.weight",
+             "nerstf.mlp.2.weight",
+             "nerstf.mlp.3.weight",
+         ),
+         param_shapes: Tuple[Tuple[int]] = (
+             (256, 93),
+             (256, 256),
+             (256, 256),
+             (256, 256),
+         ),
+         d_latent: int = 1024,
+         d_hidden: int = 256,
+         n_output: int = 12,
+         n_hidden_layers: int = 6,
+         act_fn: str = "swish",
+         insert_direction_at: int = 4,
+         background: Tuple[float] = (
+             255.0,
+             255.0,
+             255.0,
+         ),
+     ):
+         super().__init__()
+
+         self.params_proj = ShapEParamsProjModel(
+             param_names=param_names,
+             param_shapes=param_shapes,
+             d_latent=d_latent,
+         )
+         self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
+         self.void = VoidNeRFModel(background=background, channel_scale=255.0)
+         self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
+
+     @torch.no_grad()
+     def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
+         """
+         Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below
+         with some abuse of notations)
+
+         C(r) := sum(
+             transmittance(t[i]) * integrate(
+                 lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
+             ) for i in range(len(parts))
+         ) + transmittance(t[-1]) * void_model(t[-1]).channels
+
+         where
+
+         1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through
+         the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are
+         obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t
+         where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the
+         shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and
+         transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1],
+         math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
+
+         args:
+             rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples:
+             number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including
+
+         :return: A tuple of
+             - `channels`
+             - A importance samplers for additional fine-grained rendering
+             - raw model output
+         """
+         origin, direction = rays[..., 0, :], rays[..., 1, :]
+
+         # Integrate over [t[i], t[i + 1]]
+
+         # 1 Intersect the rays with the current volume and sample ts to integrate along.
+         vrange = self.volume.intersect(origin, direction, t0_lower=None)
+         ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
+         ts = ts.to(rays.dtype)
+
+         if prev_model_out is not None:
+             # Append the previous ts now before fprop because previous
+             # rendering used a different model and we can't reuse the output.
+             ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values
+
+         batch_size, *_shape, _t0_dim = vrange.t0.shape
+         _, *ts_shape, _ts_dim = ts.shape
+
+         # 2. Get the points along the ray and query the model
+         directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
+         positions = origin.unsqueeze(-2) + ts * directions
+
+         directions = directions.to(self.mlp.dtype)
+         positions = positions.to(self.mlp.dtype)
+
+         optional_directions = directions if render_with_direction else None
+
+         model_out = self.mlp(
+             position=positions,
+             direction=optional_directions,
+             ts=ts,
+             nerf_level="coarse" if prev_model_out is None else "fine",
+         )
+
+         # 3. Integrate the model results
+         channels, weights, transmittance = integrate_samples(
+             vrange, model_out.ts, model_out.density, model_out.channels
+         )
+
+         # 4. Clean up results that do not intersect with the volume.
+         transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
+         channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))
+         # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
+         channels = channels + transmittance * self.void(origin)
+
+         weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)
+
+         return channels, weighted_sampler, model_out
+
+     @torch.no_grad()
+     def decode(
+         self,
+         latents,
+         device,
+         size: int = 64,
+         ray_batch_size: int = 4096,
+         n_coarse_samples=64,
+         n_fine_samples=128,
+     ):
+         # project the the paramters from the generated latents
+         projected_params = self.params_proj(latents)
+
+         # update the mlp layers of the renderer
+         for name, param in self.mlp.state_dict().items():
+             if f"nerstf.{name}" in projected_params.keys():
+                 param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))
+
+         # create cameras object
+         camera = create_pan_cameras(size)
+         rays = camera.camera_rays
+         rays = rays.to(device)
+         n_batches = rays.shape[1] // ray_batch_size
+
+         coarse_sampler = StratifiedRaySampler()
+
+         images = []
+
+         for idx in range(n_batches):
+             rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]
+
+             # render rays with coarse, stratified samples.
+             _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
+             # Then, render with additional importance-weighted ray samples.
+             channels, _, _ = self.render_rays(
+                 rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
+             )
+
+             images.append(channels)
+
+         images = torch.cat(images, dim=1)
+         images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
+
+         return images
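
For orientation, the sketch below exercises the new ShapERenderer.decode directly with its default config and a random latent; the weights are randomly initialized here (so the rendering is meaningless noise), whereas in practice the Shap-E pipelines load this module from a pretrained checkpoint. The latent shape follows the default param_shapes/d_latent above (4 blocks of 256 vectors, each of width 1024):

    import torch
    from diffusers.pipelines.shap_e.renderer import ShapERenderer

    renderer = ShapERenderer()  # default config as registered above

    # [batch, total vectors, d_latent]: 4 * 256 = 1024 vectors, each of dimension 1024.
    latents = torch.randn(1, 1024, 1024)

    # decode() projects the latent into the NeRSTF MLP weights, builds panning cameras,
    # and renders every ray batch twice: coarse stratified pass, then fine importance pass.
    views = renderer.decode(
        latents,
        device=torch.device("cpu"),
        size=32,
        ray_batch_size=4096,
        n_coarse_samples=32,
        n_fine_samples=64,
    )
    print(views.shape)  # roughly [n_views, 32, 32, 3], channel values in [0, 1]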