diffsynth-engine 0.3.6.dev8__py3-none-any.whl → 0.3.6.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
4
  from contextlib import contextmanager
5
+ from diffsynth_engine.utils.platform import DTYPE_FP8
5
6
 
6
7
 
7
8
  def enable_fp8_autocast(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False):
@@ -51,7 +52,7 @@ def enable_fp8_linear(module: nn.Module):
51
52
  def _enable_fp8_linear(module: nn.Module):
52
53
  if isinstance(module, nn.Linear) and torch.is_floating_point(module.weight.data):
53
54
  # avoid conversion for int weights like GGUF
54
- module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
55
+ module.weight.data = module.weight.data.to(DTYPE_FP8)
55
56
  for submodule in module.children():
56
57
  _enable_fp8_linear(submodule)
57
58
 
@@ -71,8 +72,16 @@ def fp8_inference(enabled=True):
71
72
  ) -> torch.Tensor:
72
73
  device = input.device
73
74
  origin_dtype = input.dtype
74
- input = input.to(torch.float8_e4m3fn)
75
- weight = weight.to(torch.float8_e4m3fn)
75
+ scale_a = 1.0
76
+ # For float8_e4m3fnuz, the maximum representable value (240) is roughly half that of e4m3fn (448).
77
+ # To avoid overflow and ensure numerical compatibility during FP8 computation,
78
+ # we scale down the input by 2.0 in advance.
79
+ # The division is compensated afterwards by passing scale_a to torch._scaled_mm.
80
+ if DTYPE_FP8 == torch.float8_e4m3fnuz:
81
+ scale_a = 2.0
82
+ input = input / scale_a
83
+ input = input.to(DTYPE_FP8)
84
+ weight = weight.to(DTYPE_FP8)
76
85
 
77
86
  if len(input.shape) > 2:
78
87
  origin_shape = input.shape
@@ -80,7 +89,7 @@ def fp8_inference(enabled=True):
80
89
  result = torch._scaled_mm(
81
90
  input,
82
91
  weight.T,
83
- scale_a=torch.tensor(1.0).to(device=device),
92
+ scale_a=torch.tensor(scale_a).to(device=device),
84
93
  scale_b=torch.tensor(1.0).to(device=device),
85
94
  bias=bias,
86
95
  out_dtype=origin_dtype,
@@ -91,7 +100,7 @@ def fp8_inference(enabled=True):
91
100
  result = torch._scaled_mm(
92
101
  input,
93
102
  weight.T,
94
- scale_a=torch.tensor(1.0).to(device=device),
103
+ scale_a=torch.tensor(scale_a).to(device=device),
95
104
  scale_b=torch.tensor(1.0).to(device=device),
96
105
  bias=bias,
97
106
  out_dtype=origin_dtype,
@@ -1,7 +1,15 @@
1
+ # cross-platform definitions and utilities
1
2
  import torch
2
3
  import gc
3
4
 
4
- # 存放跨平台的工具类
5
+
6
+ # data type
7
+ # AMD GPUs with a gfx94* architecture (ROCm builds) only support float8_e4m3fnuz
8
+ # https://onnx.ai/onnx/technical/float8.html
9
+ if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
10
+ DTYPE_FP8 = torch.float8_e4m3fnuz
11
+ else:
12
+ DTYPE_FP8 = torch.float8_e4m3fn
5
13
 
6
14
 
7
15
  def empty_cache():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.3.6.dev8
3
+ Version: 0.3.6.dev9
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -129,7 +129,7 @@ diffsynth_engine/utils/constants.py,sha256=L7sIxGNMfCvcZG66ul7GIT6fDctkcwhePAjMj
129
129
  diffsynth_engine/utils/download.py,sha256=NCgfL9tUca-sOhT41k6w4o__Ktbw-1aDwFTR4JDkT28,5639
130
130
  diffsynth_engine/utils/env.py,sha256=43x-kBjt5zI2cwZ9G4BOeTbedi2k6TuBzHGOBeFbFvU,280
131
131
  diffsynth_engine/utils/flag.py,sha256=6zQLnoEaU69pBEyhavCgydQfP0khw5ppCU7sue4yRqg,1370
132
- diffsynth_engine/utils/fp8_linear.py,sha256=qu6Hzi7dqmDFgtoP-Uf0p-GDKW03AK9338YeLuzw2nw,3589
132
+ diffsynth_engine/utils/fp8_linear.py,sha256=NosnWMoAr_IpFcLn-OYbAx-vXySphjxutDZqmXLNjJI,4064
133
133
  diffsynth_engine/utils/gguf.py,sha256=ZWvw46V4g4uVyAR_oCq-4K5nPdKVrYk3u47uXMgA9lU,14092
134
134
  diffsynth_engine/utils/image.py,sha256=_46CVs1Qe7GdZNulWWJISnR_Y6FotC2tZGLKtr04gIE,562
135
135
  diffsynth_engine/utils/loader.py,sha256=Z5v1WNDWFY0OrVubB70j5VU3zeaAfEK_j8c1KrGI4yM,1240
@@ -138,11 +138,11 @@ diffsynth_engine/utils/logging.py,sha256=XB0xTT8PBN6btkOjFtOvjlrOCRVgDGT8PFAp1vm
138
138
  diffsynth_engine/utils/offload.py,sha256=jUR4u7J60o4KZIRxHhMCwaeDkiXJvBa0KJkYKKT6mrg,1587
139
139
  diffsynth_engine/utils/onnx.py,sha256=jeWUudJHnESjuiEAHyUZYUZz7dCj34O9aGjHCe8yjWo,1149
140
140
  diffsynth_engine/utils/parallel.py,sha256=2WISMBTTmW0v2qPvpms421-B59v3bYlS6YrLq9BZ5Zo,16909
141
- diffsynth_engine/utils/platform.py,sha256=q9ifmdzoa66Cj9YKfwps21DsDdwA0JGpwroKQbG6shU,224
141
+ diffsynth_engine/utils/platform.py,sha256=2lXdw6YkqcRONCeT98n4cyg1Ii8Ybbyj2Ns72Se9tlk,496
142
142
  diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
143
143
  diffsynth_engine/utils/video.py,sha256=Ne0rd2lb59UT1q5EotpjlY7OT8F9oTCFDyo1ST77uoQ,1004
144
- diffsynth_engine-0.3.6.dev8.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
145
- diffsynth_engine-0.3.6.dev8.dist-info/METADATA,sha256=bZRH8-guipkEhR276zMywqxu72Ayr-c2XPttJedLZ2o,1068
146
- diffsynth_engine-0.3.6.dev8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
147
- diffsynth_engine-0.3.6.dev8.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
148
- diffsynth_engine-0.3.6.dev8.dist-info/RECORD,,
144
+ diffsynth_engine-0.3.6.dev9.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
145
+ diffsynth_engine-0.3.6.dev9.dist-info/METADATA,sha256=k33lHBGOXqN3YNxtge0TI6C3ICnOhGAxwXCEtfr3kTY,1068
146
+ diffsynth_engine-0.3.6.dev9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
147
+ diffsynth_engine-0.3.6.dev9.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
148
+ diffsynth_engine-0.3.6.dev9.dist-info/RECORD,,