diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,59 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel
5
+
6
+ from ...utils import logging
7
+
8
+
9
logger = logging.get_logger(__name__)


class IFSafetyChecker(PreTrainedModel):
    """CLIP-based checker that blanks images flagged as NSFW or watermarked.

    A CLIP vision tower produces one embedding per image. Two single-output
    linear heads score that embedding: ``p_head`` for NSFW content and
    ``w_head`` for watermarked content.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        vision_cfg = config.vision_config
        self.vision_model = CLIPVisionModelWithProjection(vision_cfg)

        # Both heads map the projected CLIP embedding to a single score.
        self.p_head = nn.Linear(vision_cfg.projection_dim, 1)
        self.w_head = nn.Linear(vision_cfg.projection_dim, 1)

    @staticmethod
    def _blank_flagged(images, flags, message):
        """Log ``message`` once if any flag is set, then zero every flagged image.

        ``images`` is modified in place. Each flagged entry is replaced by
        ``np.zeros`` of that entry's shape.
        """
        if any(flags):
            logger.warning(message)

        for position, is_flagged in enumerate(flags):
            if is_flagged:
                images[position] = np.zeros(images[position].shape)

    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        """Score a batch of images and blank those flagged as NSFW or watermarked.

        Args:
            clip_input: Pixel values passed positionally to the CLIP vision model.
            images: Indexable, item-assignable batch of images (e.g. a numpy
                array). Flagged entries are overwritten in place with zeros.
            p_threshold: An image is flagged as NSFW when its raw ``p_head``
                output is strictly greater than this value. No sigmoid is applied.
            w_threshold: An image is flagged as watermarked when its raw
                ``w_head`` output is strictly greater than this value.

        Returns:
            ``(images, nsfw_detected, watermark_detected)``. The two flag
            values are Python lists of bools, one per image.
        """
        # Index 0 of the vision model output is used as the image embedding.
        embeds = self.vision_model(clip_input)[0]

        nsfw_detected = (self.p_head(embeds).flatten() > p_threshold).tolist()
        self._blank_flagged(
            images,
            nsfw_detected,
            "Potential NSFW content was detected in one or more images. A black image will be returned instead."
            " Try again with a different prompt and/or seed.",
        )

        watermark_detected = (self.w_head(embeds).flatten() > w_threshold).tolist()
        self._blank_flagged(
            images,
            watermark_detected,
            "Potential watermarked content was detected in one or more images. A black image will be returned instead."
            " Try again with a different prompt and/or seed.",
        )

        return images, nsfw_detected, watermark_detected
@@ -0,0 +1,579 @@
1
# Descending schedule of 27 integer timesteps, from 999 down to 0.
fast27_timesteps = [
    999, 800, 799, 600, 599, 500, 400, 399, 377, 355,
    333, 311, 288, 266, 244, 222, 200, 199, 177, 155,
    133, 111, 88, 66, 44, 22, 0,
]
30
+
31
# Descending schedule of 27 integer timesteps, from 999 down to 0.
smart27_timesteps = [
    999, 976, 952, 928, 905, 882, 858, 857, 810, 762,
    715, 714, 572, 429, 428, 286, 285, 238, 190, 143,
    142, 118, 95, 71, 47, 24, 0,
]
60
+
61
# Descending schedule of 50 integer timesteps, from 999 down to 0.
smart50_timesteps = [
    999, 988, 977, 966, 955, 944, 933, 922, 911, 900,
    899, 879, 859, 840, 820, 800, 799, 766, 733, 700,
    699, 650, 600, 599, 500, 499, 400, 399, 350, 300,
    299, 266, 233, 200, 199, 179, 159, 140, 120, 100,
    99, 88, 77, 66, 55, 44, 33, 22, 11, 0,
]
113
+
114
# Descending schedule of 100 integer timesteps, from 999 down to 0.
smart100_timesteps = [
    999, 995, 992, 989, 985, 981, 978, 975, 971, 967,
    964, 961, 957, 956, 951, 947, 942, 937, 933, 928,
    923, 919, 914, 913, 908, 903, 897, 892, 887, 881,
    876, 871, 870, 864, 858, 852, 846, 840, 834, 828,
    827, 820, 813, 806, 799, 792, 785, 784, 777, 770,
    763, 756, 749, 742, 741, 733, 724, 716, 707, 699,
    698, 688, 677, 666, 656, 655, 645, 634, 623, 613,
    612, 598, 584, 570, 569, 555, 541, 527, 526, 505,
    484, 483, 462, 440, 439, 396, 395, 352, 351, 308,
    307, 264, 263, 220, 219, 176, 132, 88, 44, 0,
]
216
+
217
# Descending schedule of 185 integer timesteps, from 999 down to 0.
smart185_timesteps = [
    999, 997, 995, 992, 990, 988, 986, 984, 981, 979,
    977, 975, 972, 970, 968, 966, 964, 961, 959, 957,
    956, 954, 951, 949, 946, 944, 941, 939, 936, 934,
    931, 929, 926, 924, 921, 919, 916, 914, 913, 910,
    907, 905, 902, 899, 896, 893, 891, 888, 885, 882,
    879, 877, 874, 871, 870, 867, 864, 861, 858, 855,
    852, 849, 846, 843, 840, 837, 834, 831, 828, 827,
    824, 821, 817, 814, 811, 808, 804, 801, 798, 795,
    791, 788, 785, 784, 780, 777, 774, 770, 766, 763,
    760, 756, 752, 749, 746, 742, 741, 737, 733, 730,
    726, 722, 718, 714, 710, 707, 703, 699, 698, 694,
    690, 685, 681, 677, 673, 669, 664, 660, 656, 655,
    650, 646, 641, 636, 632, 627, 622, 618, 613, 612,
    607, 602, 596, 591, 586, 580, 575, 570, 569, 563,
    557, 551, 545, 539, 533, 527, 526, 519, 512, 505,
    498, 491, 484, 483, 474, 466, 457, 449, 440, 439,
    428, 418, 407, 396, 395, 381, 366, 352, 351, 330,
    308, 307, 286, 264, 263, 242, 220, 219, 176, 175,
    132, 131, 88, 44, 0,
]
404
+
405
# Descending schedule of 27 integer timesteps, from 999 down to 0.
super27_timesteps = [
    999, 991, 982, 974, 966, 958, 950, 941, 933, 925,
    916, 908, 900, 899, 874, 850, 825, 800, 799, 700,
    600, 500, 400, 300, 200, 100, 0,
]
434
+
435
# Descending schedule of 40 integer timesteps, from 999 down to 0.
super40_timesteps = [
    999, 992, 985, 978, 971, 964, 957, 949, 942, 935,
    928, 921, 914, 907, 900, 899, 879, 859, 840, 820,
    800, 799, 766, 733, 700, 699, 650, 600, 599, 500,
    499, 400, 399, 300, 299, 200, 199, 100, 99, 0,
]
477
+
478
# Descending schedule of 100 integer timesteps, from 999 down to 0.
super100_timesteps = [
    999, 996, 992, 989, 985, 982, 979, 975, 972, 968,
    965, 961, 958, 955, 951, 948, 944, 941, 938, 934,
    931, 927, 924, 920, 917, 914, 910, 907, 903, 900,
    899, 891, 884, 876, 869, 861, 853, 846, 838, 830,
    823, 815, 808, 800, 799, 788, 777, 766, 755, 744,
    733, 722, 711, 700, 699, 688, 677, 666, 655, 644,
    633, 622, 611, 600, 599, 585, 571, 557, 542, 528,
    514, 500, 499, 485, 471, 457, 442, 428, 414, 400,
    399, 379, 359, 340, 320, 300, 299, 279, 259, 240,
    220, 200, 199, 166, 133, 100, 99, 66, 33, 0,
]
@@ -0,0 +1,46 @@
1
+ from typing import List
2
+
3
+ import PIL
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from ...configuration_utils import ConfigMixin
8
+ from ...models.modeling_utils import ModelMixin
9
+ from ...utils import PIL_INTERPOLATION
10
+
11
+
12
class IFWatermarker(ModelMixin, ConfigMixin):
    """Stamps a watermark image onto PIL images.

    The watermark pixels live in the ``watermark_image`` buffer, shape
    ``(62, 62, 4)``. It is initialised to zeros here. Real pixel values are
    presumably loaded from the checkpoint weights (TODO confirm). They are
    cast to uint8, so they are assumed to lie in 0..255.
    """

    def __init__(self):
        super().__init__()

        self.register_buffer("watermark_image", torch.zeros((62, 62, 4)))
        # Lazily-built PIL copy of ``watermark_image``, created on first use.
        self.watermark_image_as_pil = None

    def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None):
        """Paste the watermark near the bottom-right corner of each image, in place.

        Args:
            images: PIL images to stamp. Watermark size and position are computed
                from ``images[0]`` and reused for every image, so all images are
                assumed to share that size.
            sample_size: Reference resolution, defaulting to the image height.
                If the image's shorter side is below ``sample_size``, the layout
                is computed for the image scaled up to match it. Otherwise the
                actual size is used.

        Returns:
            The same ``images`` list. The images are modified in place by
            ``paste``. An empty list is returned unchanged.
        """
        # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287
        if not images:
            # Nothing to stamp; also avoids an IndexError on images[0].
            return images

        h = images[0].height
        w = images[0].width

        sample_size = sample_size or h

        coef = min(h / sample_size, w / sample_size)
        img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)

        # Scale the 62 px watermark and its 14 px margin relative to a 1024x1024 area.
        S1, S2 = 1024**2, img_w * img_h
        K = (S2 / S1) ** 0.5
        wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K)
        # Very small images (side below ~17 px) give wm_size == 0, and PIL cannot
        # resize to a zero-sized image. Clamp so they get a 1 px mark instead of crashing.
        wm_size = max(wm_size, 1)

        if self.watermark_image_as_pil is None:
            watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy()
            # A (H, W, 4) uint8 array is inferred as RGBA by Pillow. Passing
            # ``mode`` explicitly is deprecated in recent Pillow releases.
            watermark_image = Image.fromarray(watermark_image)
            self.watermark_image_as_pil = watermark_image

        wm_img = self.watermark_image_as_pil.resize(
            (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None
        )
        # The watermark's alpha channel is the paste mask. It is identical for
        # every image, so compute it once.
        alpha_mask = wm_img.split()[-1]

        for pil_img in images:
            pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=alpha_mask)

        return images