diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
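
The listing above shows that 0.18.x adds several new pipelines (consistency models, Kandinsky 2.2, Shap-E, Stable Diffusion XL, LDM3D, paradigms) plus parallel schedulers. For orientation, here is a minimal sketch (not part of this diff) of trying one of the new pipelines, Stable Diffusion XL; the checkpoint id is an assumption, and SDXL in 0.18 additionally expects the `invisible-watermark` package (note the new `dummy_torch_and_transformers_and_invisible_watermark_objects.py` above).

```py
# Minimal sketch, not from the diff itself: load the newly added SDXL pipeline.
# The checkpoint id below is an assumption; use whichever SDXL checkpoint you have access to.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9",  # assumed checkpoint id
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Default output is a list of PIL images
image = pipe(prompt="an astronaut riding a horse", guidance_scale=7.5).images[0]
image.save("astronaut.png")
```

The new Shap-E files (#56-#59 above) are shown in full below.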
diffusers/pipelines/shap_e/camera.py
@@ -0,0 +1,147 @@
+# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Tuple
+
+import numpy as np
+import torch
+
+
+@dataclass
+class DifferentiableProjectiveCamera:
+    """
+    Implements a batch, differentiable, standard pinhole camera
+    """
+
+    origin: torch.Tensor  # [batch_size x 3]
+    x: torch.Tensor  # [batch_size x 3]
+    y: torch.Tensor  # [batch_size x 3]
+    z: torch.Tensor  # [batch_size x 3]
+    width: int
+    height: int
+    x_fov: float
+    y_fov: float
+    shape: Tuple[int]
+
+    def __post_init__(self):
+        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
+        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
+        assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2
+
+    def resolution(self):
+        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
+
+    def fov(self):
+        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
+
+    def get_image_coords(self) -> torch.Tensor:
+        """
+        :return: coords of shape (width * height, 2)
+        """
+        pixel_indices = torch.arange(self.height * self.width)
+        coords = torch.stack(
+            [
+                pixel_indices % self.width,
+                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
+            ],
+            axis=1,
+        )
+        return coords
+
+    @property
+    def camera_rays(self):
+        batch_size, *inner_shape = self.shape
+        inner_batch_size = int(np.prod(inner_shape))
+
+        coords = self.get_image_coords()
+        coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
+        rays = self.get_camera_rays(coords)
+
+        rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)
+
+        return rays
+
+    def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
+        batch_size, *shape, n_coords = coords.shape
+        assert n_coords == 2
+        assert batch_size == self.origin.shape[0]
+
+        flat = coords.view(batch_size, -1, 2)
+
+        res = self.resolution()
+        fov = self.fov()
+
+        fracs = (flat.float() / (res - 1)) * 2 - 1
+        fracs = fracs * torch.tan(fov / 2)
+
+        fracs = fracs.view(batch_size, -1, 2)
+        directions = (
+            self.z.view(batch_size, 1, 3)
+            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
+            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
+        )
+        directions = directions / directions.norm(dim=-1, keepdim=True)
+        rays = torch.stack(
+            [
+                torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
+                directions,
+            ],
+            dim=2,
+        )
+        return rays.view(batch_size, *shape, 2, 3)
+
+    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
+        """
+        Creates a new camera for the resized view assuming the aspect ratio does not change.
+        """
+        assert width * self.height == height * self.width, "The aspect ratio should not change."
+        return DifferentiableProjectiveCamera(
+            origin=self.origin,
+            x=self.x,
+            y=self.y,
+            z=self.z,
+            width=width,
+            height=height,
+            x_fov=self.x_fov,
+            y_fov=self.y_fov,
+        )
+
+
+def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
+    origins = []
+    xs = []
+    ys = []
+    zs = []
+    for theta in np.linspace(0, 2 * np.pi, num=20):
+        z = np.array([np.sin(theta), np.cos(theta), -0.5])
+        z /= np.sqrt(np.sum(z**2))
+        origin = -z * 4
+        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
+        y = np.cross(z, x)
+        origins.append(origin)
+        xs.append(x)
+        ys.append(y)
+        zs.append(z)
+    return DifferentiableProjectiveCamera(
+        origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
+        x=torch.from_numpy(np.stack(xs, axis=0)).float(),
+        y=torch.from_numpy(np.stack(ys, axis=0)).float(),
+        z=torch.from_numpy(np.stack(zs, axis=0)).float(),
+        width=size,
+        height=size,
+        x_fov=0.7,
+        y_fov=0.7,
+        shape=(1, len(xs)),
+    )
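
A short sketch (assuming the 0.18.x module layout shown in the file list above) of how these helpers fit together: `create_pan_cameras` builds 20 cameras panning around the object, and `camera_rays` yields one origin/direction pair per pixel.

```py
# Sketch, assuming diffusers 0.18.x: exercise the Shap-E camera helpers from this file.
from diffusers.pipelines.shap_e.camera import create_pan_cameras

cameras = create_pan_cameras(size=64)  # 20 pan cameras, 64x64 pixels each
rays = cameras.camera_rays             # stacked (origin, direction) per pixel
print(rays.shape)                      # torch.Size([1, 81920, 2, 3]) = 20 views * 64 * 64 pixels
```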
diffusers/pipelines/shap_e/pipeline_shap_e.py
@@ -0,0 +1,390 @@
+# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+
+from ...models import PriorTransformer
+from ...pipelines import DiffusionPipeline
+from ...schedulers import HeunDiscreteScheduler
+from ...utils import (
+    BaseOutput,
+    is_accelerate_available,
+    is_accelerate_version,
+    logging,
+    randn_tensor,
+    replace_example_docstring,
+)
+from .renderer import ShapERenderer
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import DiffusionPipeline
+        >>> from diffusers.utils import export_to_gif
+
+        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        >>> repo = "openai/shap-e"
+        >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
+        >>> pipe = pipe.to(device)
+
+        >>> guidance_scale = 15.0
+        >>> prompt = "a shark"
+
+        >>> images = pipe(
+        ...     prompt,
+        ...     guidance_scale=guidance_scale,
+        ...     num_inference_steps=64,
+        ...     frame_size=256,
+        ... ).images
+
+        >>> gif_path = export_to_gif(images[0], "shark_3d.gif")
+        ```
+"""
+
+
+@dataclass
+class ShapEPipelineOutput(BaseOutput):
+    """
+    Output class for ShapEPipeline.
+
+    Args:
+        images (`torch.FloatTensor`)
+            a list of images for 3D rendering
+    """
+
+    images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]]
+
+
+class ShapEPipeline(DiffusionPipeline):
+    """
+    Pipeline for generating the latent representation of a 3D asset and rendering it with the NeRF method, using Shap-E.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        prior ([`PriorTransformer`]):
+            The canonical unCLIP prior to approximate the image embedding from the text embedding.
+        text_encoder ([`CLIPTextModelWithProjection`]):
+            Frozen text-encoder.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        scheduler ([`HeunDiscreteScheduler`]):
+            A scheduler to be used in combination with the `prior` to generate the image embedding.
+        renderer ([`ShapERenderer`]):
+            The Shap-E renderer projects the generated latents into parameters of an MLP that is used to create 3D
+            objects with the NeRF rendering method.
+    """
+
+    def __init__(
+        self,
+        prior: PriorTransformer,
+        text_encoder: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        scheduler: HeunDiscreteScheduler,
+        renderer: ShapERenderer,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            prior=prior,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+            renderer=renderer,
+        )
+
+    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            if latents.shape != shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = latents.to(device)
+
+        latents = latents * scheduler.init_noise_sigma
+        return latents
+
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+        models have their state dicts saved to CPU and are then moved to a `torch.device('meta')` and loaded to GPU
+        only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        models = [self.text_encoder, self.prior]
+        for cpu_offloaded_model in models:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        if self.safety_checker is not None:
+            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
+            return self.device
+        for module in self.text_encoder.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+    ):
+        len(prompt) if isinstance(prompt, list) else 1
+
+        # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file
+        self.tokenizer.pad_token_id = 0
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+
+        text_encoder_output = self.text_encoder(text_input_ids.to(device))
+        prompt_embeds = text_encoder_output.text_embeds
+
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        # in Shap-E the prompt_embeds are normalized here and rescaled later
+        prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
+
+        if do_classifier_free_guidance:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        # Rescale the features to have unit variance
+        prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
+
+        return prompt_embeds
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: str,
+        num_images_per_prompt: int = 1,
+        num_inference_steps: int = 25,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        guidance_scale: float = 4.0,
+        frame_size: int = 64,
+        output_type: Optional[str] = "pil",  # pil, np, latent
+        return_dict: bool = True,
+    ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            num_inference_steps (`int`, *optional*, defaults to 25):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                latents tensor will be generated by sampling using the supplied random `generator`.
+            guidance_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            frame_size (`int`, *optional*, defaults to 64):
+                The width and height of each image frame of the generated 3D output.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"`
+                (`np.array`), or `"latent"`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+        Examples:
+
+        Returns:
+            [`ShapEPipelineOutput`] or `tuple`
+        """
+
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        device = self._execution_device
+
+        batch_size = batch_size * num_images_per_prompt
+
+        do_classifier_free_guidance = guidance_scale > 1.0
+        prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
+
+        # prior
+
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        num_embeddings = self.prior.config.num_embeddings
+        embedding_dim = self.prior.config.embedding_dim
+
+        latents = self.prepare_latents(
+            (batch_size, num_embeddings * embedding_dim),
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            self.scheduler,
+        )
+
+        # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
+        latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
+
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            noise_pred = self.prior(
+                scaled_model_input,
+                timestep=t,
+                proj_embedding=prompt_embeds,
+            ).predicted_image_embedding
+
+            # remove the variance
+            noise_pred, _ = noise_pred.split(
+                scaled_model_input.shape[2], dim=2
+            )  # batch_size, num_embeddings, embedding_dim
+
+            if do_classifier_free_guidance is not None:
+                noise_pred_uncond, noise_pred = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+            latents = self.scheduler.step(
+                noise_pred,
+                timestep=t,
+                sample=latents,
+            ).prev_sample
+
+        if output_type == "latent":
+            return ShapEPipelineOutput(images=latents)
+
+        images = []
+        for i, latent in enumerate(latents):
+            image = self.renderer.decode(
+                latent[None, :],
+                device,
+                size=frame_size,
+                ray_batch_size=4096,
+                n_coarse_samples=64,
+                n_fine_samples=128,
+            )
+            images.append(image)
+
+        images = torch.stack(images)
+
+        if output_type not in ["np", "pil"]:
+            raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
+
+        images = images.cpu().numpy()
+
+        if output_type == "pil":
+            images = [self.numpy_to_pil(image) for image in images]
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (images,)
+
+        return ShapEPipelineOutput(images=images)
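
Mirroring the EXAMPLE_DOC_STRING above, a small usage sketch: the `openai/shap-e` checkpoint id comes from that docstring, and batching over several prompts follows the `__call__` documentation; for each prompt the pipeline returns a list of rendered frames, which `export_to_gif` turns into an animation.

```py
# Sketch based on the example docstring above: render a few prompts and save one GIF per prompt.
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_gif

pipe = DiffusionPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompts = ["a shark", "a donut"]
frames = pipe(prompts, guidance_scale=15.0, num_inference_steps=64, frame_size=128).images

for prompt, prompt_frames in zip(prompts, frames):
    # prompt_frames is a list of PIL images (one per rendered view)
    export_to_gif(prompt_frames, f"{prompt.replace(' ', '_')}.gif")
```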