da4ml 0.1.2.tar.gz → 0.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of da4ml might be problematic.

Files changed (65)
  1. da4ml-0.2.0/.clang-format +191 -0
  2. {da4ml-0.1.2 → da4ml-0.2.0}/.gitignore +8 -0
  3. da4ml-0.2.0/PKG-INFO +65 -0
  4. da4ml-0.2.0/README.md +42 -0
  5. {da4ml-0.1.2 → da4ml-0.2.0}/pyproject.toml +2 -2
  6. da4ml-0.2.0/src/da4ml/__init__.py +17 -0
  7. {da4ml-0.1.2 → da4ml-0.2.0}/src/da4ml/_version.py +2 -2
  8. da4ml-0.2.0/src/da4ml/cmvm/__init__.py +4 -0
  9. da4ml-0.2.0/src/da4ml/cmvm/api.py +257 -0
  10. da4ml-0.2.0/src/da4ml/cmvm/core/__init__.py +222 -0
  11. da4ml-0.2.0/src/da4ml/cmvm/core/indexers.py +83 -0
  12. da4ml-0.2.0/src/da4ml/cmvm/core/state_opr.py +284 -0
  13. da4ml-0.2.0/src/da4ml/cmvm/types.py +569 -0
  14. da4ml-0.2.0/src/da4ml/cmvm/util/__init__.py +7 -0
  15. da4ml-0.2.0/src/da4ml/cmvm/util/bit_decompose.py +86 -0
  16. da4ml-0.2.0/src/da4ml/cmvm/util/mat_decompose.py +121 -0
  17. da4ml-0.2.0/src/da4ml/codegen/__init__.py +11 -0
  18. da4ml-0.2.0/src/da4ml/codegen/cpp/__init__.py +3 -0
  19. da4ml-0.2.0/src/da4ml/codegen/cpp/cpp_codegen.py +148 -0
  20. da4ml-0.2.0/src/da4ml/codegen/cpp/source/vitis.h +30 -0
  21. da4ml-0.2.0/src/da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  22. da4ml-0.2.0/src/da4ml/codegen/verilog/__init__.py +13 -0
  23. da4ml-0.2.0/src/da4ml/codegen/verilog/comb.py +146 -0
  24. da4ml-0.2.0/src/da4ml/codegen/verilog/io_wrapper.py +255 -0
  25. da4ml-0.2.0/src/da4ml/codegen/verilog/pipeline.py +49 -0
  26. da4ml-0.2.0/src/da4ml/codegen/verilog/source/build_binder.mk +27 -0
  27. da4ml-0.2.0/src/da4ml/codegen/verilog/source/build_prj.tcl +75 -0
  28. da4ml-0.2.0/src/da4ml/codegen/verilog/source/ioutils.hh +117 -0
  29. da4ml-0.2.0/src/da4ml/codegen/verilog/source/shift_adder.v +56 -0
  30. da4ml-0.2.0/src/da4ml/codegen/verilog/source/template.xdc +29 -0
  31. da4ml-0.2.0/src/da4ml/codegen/verilog/verilog_model.py +265 -0
  32. da4ml-0.2.0/src/da4ml/trace/__init__.py +6 -0
  33. da4ml-0.2.0/src/da4ml/trace/fixed_variable.py +358 -0
  34. da4ml-0.2.0/src/da4ml/trace/fixed_variable_array.py +177 -0
  35. da4ml-0.2.0/src/da4ml/trace/ops/__init__.py +55 -0
  36. da4ml-0.2.0/src/da4ml/trace/ops/conv_utils.py +104 -0
  37. da4ml-0.2.0/src/da4ml/trace/ops/einsum_utils.py +299 -0
  38. da4ml-0.2.0/src/da4ml/trace/pipeline.py +155 -0
  39. da4ml-0.2.0/src/da4ml/trace/tracer.py +120 -0
  40. da4ml-0.2.0/src/da4ml.egg-info/PKG-INFO +65 -0
  41. da4ml-0.2.0/src/da4ml.egg-info/SOURCES.txt +46 -0
  42. da4ml-0.2.0/src/da4ml.egg-info/requires.txt +2 -0
  43. da4ml-0.1.2/PKG-INFO +0 -122
  44. da4ml-0.1.2/README.md +0 -99
  45. da4ml-0.1.2/src/da4ml/__init__.py +0 -17
  46. da4ml-0.1.2/src/da4ml/cmvm/__init__.py +0 -35
  47. da4ml-0.1.2/src/da4ml/cmvm/api.py +0 -91
  48. da4ml-0.1.2/src/da4ml/cmvm/balanced_reduction.py +0 -46
  49. da4ml-0.1.2/src/da4ml/cmvm/cmvm.py +0 -328
  50. da4ml-0.1.2/src/da4ml/cmvm/codegen.py +0 -159
  51. da4ml-0.1.2/src/da4ml/cmvm/csd.py +0 -73
  52. da4ml-0.1.2/src/da4ml/cmvm/fixed_variable.py +0 -205
  53. da4ml-0.1.2/src/da4ml/cmvm/graph_compile.py +0 -85
  54. da4ml-0.1.2/src/da4ml/cmvm/nb_fixed_precision.py +0 -98
  55. da4ml-0.1.2/src/da4ml/cmvm/scoring.py +0 -55
  56. da4ml-0.1.2/src/da4ml/cmvm/utils.py +0 -5
  57. da4ml-0.1.2/src/da4ml.egg-info/PKG-INFO +0 -122
  58. da4ml-0.1.2/src/da4ml.egg-info/SOURCES.txt +0 -24
  59. da4ml-0.1.2/src/da4ml.egg-info/requires.txt +0 -2
  60. {da4ml-0.1.2 → da4ml-0.2.0}/.github/workflows/python-publish.yml +0 -0
  61. {da4ml-0.1.2 → da4ml-0.2.0}/.pre-commit-config.yaml +0 -0
  62. {da4ml-0.1.2 → da4ml-0.2.0}/LICENSE +0 -0
  63. {da4ml-0.1.2 → da4ml-0.2.0}/setup.cfg +0 -0
  64. {da4ml-0.1.2 → da4ml-0.2.0}/src/da4ml.egg-info/dependency_links.txt +0 -0
  65. {da4ml-0.1.2 → da4ml-0.2.0}/src/da4ml.egg-info/top_level.txt +0 -0
da4ml-0.2.0/.clang-format ADDED
@@ -0,0 +1,191 @@
+ ---
+ Language: Cpp
+ # BasedOnStyle: LLVM
+ AccessModifierOffset: -2
+ AlignAfterOpenBracket: Align
+ AlignArrayOfStructures: None
+ AlignConsecutiveMacros: None
+ AlignConsecutiveAssignments: None
+ AlignConsecutiveBitFields: None
+ AlignConsecutiveDeclarations: None
+ AlignEscapedNewlines: Right
+ AlignOperands: Align
+ AlignTrailingComments: true
+ AllowAllArgumentsOnNextLine: true
+ AllowAllParametersOfDeclarationOnNextLine: true
+ AllowShortEnumsOnASingleLine: true
+ AllowShortBlocksOnASingleLine: Never
+ AllowShortCaseLabelsOnASingleLine: false
+ AllowShortFunctionsOnASingleLine: All
+ AllowShortLambdasOnASingleLine: All
+ AllowShortIfStatementsOnASingleLine: Never
+ AllowShortLoopsOnASingleLine: false
+ AlwaysBreakAfterDefinitionReturnType: None
+ AlwaysBreakAfterReturnType: None
+ AlwaysBreakBeforeMultilineStrings: false
+ AlwaysBreakTemplateDeclarations: MultiLine
+ AttributeMacros:
+   - __capability
+ BinPackArguments: true
+ BinPackParameters: true
+ BraceWrapping:
+   AfterCaseLabel: false
+   AfterClass: false
+   AfterControlStatement: Never
+   AfterEnum: false
+   AfterFunction: false
+   AfterNamespace: false
+   AfterObjCDeclaration: false
+   AfterStruct: false
+   AfterUnion: false
+   AfterExternBlock: false
+   BeforeCatch: false
+   BeforeElse: false
+   BeforeLambdaBody: false
+   BeforeWhile: false
+   IndentBraces: false
+   SplitEmptyFunction: true
+   SplitEmptyRecord: true
+   SplitEmptyNamespace: true
+ BreakBeforeBinaryOperators: None
+ BreakBeforeConceptDeclarations: true
+ BreakBeforeBraces: Attach
+ BreakBeforeInheritanceComma: false
+ BreakInheritanceList: BeforeColon
+ BreakBeforeTernaryOperators: true
+ BreakConstructorInitializersBeforeComma: false
+ BreakConstructorInitializers: BeforeColon
+ BreakAfterJavaFieldAnnotations: false
+ BreakStringLiterals: true
+ ColumnLimit: 125
+ CommentPragmas: '^ IWYU pragma:'
+ QualifierAlignment: Leave
+ CompactNamespaces: false
+ ConstructorInitializerIndentWidth: 4
+ ContinuationIndentWidth: 4
+ Cpp11BracedListStyle: true
+ DeriveLineEnding: true
+ DerivePointerAlignment: false
+ DisableFormat: false
+ EmptyLineAfterAccessModifier: Never
+ EmptyLineBeforeAccessModifier: LogicalBlock
+ ExperimentalAutoDetectBinPacking: false
+ PackConstructorInitializers: BinPack
+ BasedOnStyle: ''
+ ConstructorInitializerAllOnOneLineOrOnePerLine: false
+ AllowAllConstructorInitializersOnNextLine: true
+ FixNamespaceComments: true
+ ForEachMacros:
+   - foreach
+   - Q_FOREACH
+   - BOOST_FOREACH
+ IfMacros:
+   - KJ_IF_MAYBE
+ IncludeBlocks: Preserve
+ IncludeCategories:
+   - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
+     Priority: 2
+     SortPriority: 0
+     CaseSensitive: false
+   - Regex: '^(<|"(gtest|gmock|isl|json)/)'
+     Priority: 3
+     SortPriority: 0
+     CaseSensitive: false
+   - Regex: '.*'
+     Priority: 1
+     SortPriority: 0
+     CaseSensitive: false
+ IncludeIsMainRegex: '(Test)?$'
+ IncludeIsMainSourceRegex: ''
+ IndentAccessModifiers: false
+ IndentCaseLabels: false
+ IndentCaseBlocks: false
+ IndentGotoLabels: true
+ IndentPPDirectives: None
+ IndentExternBlock: AfterExternBlock
+ IndentRequires: false
+ IndentWidth: 4
+ IndentWrappedFunctionNames: false
+ InsertTrailingCommas: None
+ JavaScriptQuotes: Leave
+ JavaScriptWrapImports: true
+ KeepEmptyLinesAtTheStartOfBlocks: true
+ LambdaBodyIndentation: Signature
+ MacroBlockBegin: ''
+ MacroBlockEnd: ''
+ MaxEmptyLinesToKeep: 1
+ NamespaceIndentation: None
+ ObjCBinPackProtocolList: Auto
+ ObjCBlockIndentWidth: 2
+ ObjCBreakBeforeNestedBlockParam: true
+ ObjCSpaceAfterProperty: false
+ ObjCSpaceBeforeProtocolList: true
+ PenaltyBreakAssignment: 2
+ PenaltyBreakBeforeFirstCallParameter: 19
+ PenaltyBreakComment: 300
+ PenaltyBreakFirstLessLess: 120
+ PenaltyBreakOpenParenthesis: 0
+ PenaltyBreakString: 1000
+ PenaltyBreakTemplateDeclaration: 10
+ PenaltyExcessCharacter: 1000000
+ PenaltyReturnTypeOnItsOwnLine: 60
+ PenaltyIndentedWhitespace: 0
+ PointerAlignment: Right
+ PPIndentWidth: -1
+ ReferenceAlignment: Pointer
+ ReflowComments: true
+ RemoveBracesLLVM: false
+ SeparateDefinitionBlocks: Leave
+ ShortNamespaceLines: 1
+ SortIncludes: CaseSensitive
+ SortJavaStaticImport: Before
+ SortUsingDeclarations: true
+ SpaceAfterCStyleCast: false
+ SpaceAfterLogicalNot: false
+ SpaceAfterTemplateKeyword: true
+ SpaceBeforeAssignmentOperators: true
+ SpaceBeforeCaseColon: false
+ SpaceBeforeCpp11BracedList: false
+ SpaceBeforeCtorInitializerColon: true
+ SpaceBeforeInheritanceColon: true
+ SpaceBeforeParens: ControlStatements
+ SpaceBeforeParensOptions:
+   AfterControlStatements: true
+   AfterForeachMacros: true
+   AfterFunctionDefinitionName: false
+   AfterFunctionDeclarationName: false
+   AfterIfMacros: true
+   AfterOverloadedOperator: false
+   BeforeNonEmptyParentheses: false
+ SpaceAroundPointerQualifiers: Default
+ SpaceBeforeRangeBasedForLoopColon: true
+ SpaceInEmptyBlock: false
+ SpaceInEmptyParentheses: false
+ SpacesBeforeTrailingComments: 1
+ SpacesInAngles: Never
+ SpacesInConditionalStatement: false
+ SpacesInContainerLiterals: true
+ SpacesInCStyleCastParentheses: false
+ SpacesInLineCommentPrefix:
+   Minimum: 1
+   Maximum: -1
+ SpacesInParentheses: false
+ SpacesInSquareBrackets: false
+ SpaceBeforeSquareBrackets: false
+ BitFieldColonSpacing: Both
+ Standard: Latest
+ StatementAttributeLikeMacros:
+   - Q_EMIT
+ StatementMacros:
+   - Q_UNUSED
+   - QT_REQUIRE_VERSION
+ TabWidth: 8
+ UseCRLF: false
+ UseTab: Never
+ WhitespaceSensitiveMacros:
+   - STRINGIZE
+   - PP_STRINGIZE
+   - BOOST_PP_STRINGIZE
+   - NS_SWIFT_NAME
+   - CF_SWIFT_NAME
+ ...
{da4ml-0.1.2 → da4ml-0.2.0}/.gitignore
@@ -1,3 +1,11 @@
+ # test files
+ _*_codegen/
+
+ # Verilator generated files
+ obj_dir/
+ *.vcd
+ *.vpd
+
  # Notebooks
  *.ipynb

da4ml-0.2.0/PKG-INFO ADDED
@@ -0,0 +1,65 @@
+ Metadata-Version: 2.4
+ Name: da4ml
+ Version: 0.2.0
+ Summary: Digital Arithmetic for Machine Learning
+ Author-email: Chang Sun <chsun@cern.ch>
+ License: GNU Lesser General Public License v3 (LGPLv3)
+ Project-URL: repository, https://github.com/calad0i/da4ml
+ Keywords: CMVM,distributed arithmetic,hls4ml,MCM,subexpression elimination
+ Classifier: Development Status :: 4 - Beta
+ Classifier: License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: llvmlite>=0.44
+ Requires-Dist: numba>=0.61
+ Dynamic: license-file
+
+ # da4ml: Distributed Arithmetic for Machine Learning
+
+ This project performs Constant Matrix-Vector Multiplication (CMVM) with Distributed Arithmetic (DA) for Machine Learning (ML) on Field-Programmable Gate Arrays (FPGAs).
+
+ CMVM optimization is done through greedy common-subexpression elimination (CSE) of two-term subexpressions, optionally under Delay Constraints (DC). The optimization runs in JIT-compiled Python (Numba), and the optimized list of operations is emitted as traced Python code.
+
+ At the moment, the project only generates Vitis HLS C++ code for the FPGA implementation of the optimized CMVM kernel. HDL code generation is planned for the future. Currently, the major use of this repository is through the `distributed_arithmetic` strategy in the [`hls4ml`](https://github.com/fastmachinelearning/hls4ml/) project.
+
+
+ ## Installation
+
+ The project is available on PyPI and can be installed with pip:
+
+ ```bash
+ pip install da4ml
+ ```
+
+ Note that `numba>=0.61` is required for the project to work, and the project does not work with `python<3.10`. If the project fails to compile, try upgrading `numba` and `llvmlite` to the latest versions.
+
+ ## `hls4ml`
+
+ The main entry point to this project is the `distributed_arithmetic` strategy in `hls4ml`:
+
+ ```python
+ model_hls = hls4ml.converters.convert_from_keras_model(
+     model,
+     hls_config={
+         'Model': {
+             ...
+             'Strategy': 'distributed_arithmetic',
+         },
+         ...
+     },
+     ...
+ )
+ ```
+
+ Currently, `Dense/Conv1D/Conv2D` layers are supported for both `io_parallel` and `io_stream` dataflows. However, note that distributed arithmetic implies `reuse_factor=1`, as the whole kernel is implemented in combinational logic.
+
+ ### Notice
+
+ Currently, only the `da4ml-v3` branch of `hls4ml` supports the `distributed_arithmetic` strategy. It is not yet merged into the `main` branch of `hls4ml`, so you need to install it from the GitHub repository.
da4ml-0.2.0/README.md ADDED
@@ -0,0 +1,42 @@
+ # da4ml: Distributed Arithmetic for Machine Learning
+
+ This project performs Constant Matrix-Vector Multiplication (CMVM) with Distributed Arithmetic (DA) for Machine Learning (ML) on Field-Programmable Gate Arrays (FPGAs).
+
+ CMVM optimization is done through greedy common-subexpression elimination (CSE) of two-term subexpressions, optionally under Delay Constraints (DC). The optimization runs in JIT-compiled Python (Numba), and the optimized list of operations is emitted as traced Python code.
+
+ At the moment, the project only generates Vitis HLS C++ code for the FPGA implementation of the optimized CMVM kernel. HDL code generation is planned for the future. Currently, the major use of this repository is through the `distributed_arithmetic` strategy in the [`hls4ml`](https://github.com/fastmachinelearning/hls4ml/) project.
+
+
+ ## Installation
+
+ The project is available on PyPI and can be installed with pip:
+
+ ```bash
+ pip install da4ml
+ ```
+
+ Note that `numba>=0.61` is required for the project to work, and the project does not work with `python<3.10`. If the project fails to compile, try upgrading `numba` and `llvmlite` to the latest versions.
+
+ ## `hls4ml`
+
+ The main entry point to this project is the `distributed_arithmetic` strategy in `hls4ml`:
+
+ ```python
+ model_hls = hls4ml.converters.convert_from_keras_model(
+     model,
+     hls_config={
+         'Model': {
+             ...
+             'Strategy': 'distributed_arithmetic',
+         },
+         ...
+     },
+     ...
+ )
+ ```
+
+ Currently, `Dense/Conv1D/Conv2D` layers are supported for both `io_parallel` and `io_stream` dataflows. However, note that distributed arithmetic implies `reuse_factor=1`, as the whole kernel is implemented in combinational logic.
+
+ ### Notice
+
+ Currently, only the `da4ml-v3` branch of `hls4ml` supports the `distributed_arithmetic` strategy. It is not yet merged into the `main` branch of `hls4ml`, so you need to install it from the GitHub repository, e.g. as shown below.
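The README's "Notice" stops short of a command. Assuming pip's standard VCS-install syntax applies to the branch it names (the branch name and repository come from the README; whether it installs cleanly this way is an assumption, not something this release states), the install would look like:

```bash
# Hypothetical: install the hls4ml branch named in the Notice above
pip install git+https://github.com/fastmachinelearning/hls4ml@da4ml-v3
```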
{da4ml-0.1.2 → da4ml-0.2.0}/pyproject.toml
@@ -29,7 +29,7 @@ classifiers = [
    "Programming Language :: Python :: 3.13",
  ]
  dynamic = [ "version" ]
- dependencies = [ "llvmlite>=0.43", "numba>=0.60" ]
+ dependencies = [ "llvmlite>=0.44", "numba>=0.61" ]
  urls.repository = "https://github.com/calad0i/da4ml"

  [tool.setuptools]
@@ -54,7 +54,7 @@ format.skip-magic-trailing-comma = false
  format.docstring-code-line-length = 130
  format.docstring-code-format = true
  lint.select = [ "E", "F", "F401", "I", "W" ]
- lint.ignore = [ "E501", "F403", "F405" ]
+ lint.ignore = [ "E501", "E741", "F403", "F405" ]
  lint.explicit-preview-rules = true
  lint.fixable = [ "ALL" ]
  lint.unfixable = [ ]
da4ml-0.2.0/src/da4ml/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # from .cmvm.api import cost, fn_from_kernel
+ # from .cmvm.cmvm import compile_kernel
+ # from .cmvm.codegen import PyCodegenBackend, VitisCodegenBackend
+ # from .cmvm.graph_compile import graph_compile_states
+ # from .cmvm.utils import DAState, OpCode, Score
+
+ # __all__ = [
+ #     'DAState',
+ #     'OpCode',
+ #     'Score',
+ #     'cost',
+ #     'compile_kernel',
+ #     'fn_from_kernel',
+ #     'graph_compile_states',
+ #     'PyCodegenBackend',
+ #     'VitisCodegenBackend',
+ # ]
{da4ml-0.1.2 → da4ml-0.2.0}/src/da4ml/_version.py
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.1.2'
- __version_tuple__ = version_tuple = (0, 1, 2)
+ __version__ = version = '0.2.0'
+ __version_tuple__ = version_tuple = (0, 2, 0)
da4ml-0.2.0/src/da4ml/cmvm/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .api import minimal_latency, solve
+ from .types import Op, QInterval, Solution
+
+ __all__ = ['minimal_latency', 'solve', 'QInterval', 'Op', 'Solution']
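This new `cmvm` package root exposes the release's main entry point. A minimal usage sketch follows; `solve` and its defaults come from `api.py` below, while the shape of `CascadedSolution` (a `.solutions` tuple of stages whose `ops` carry a `.cost`) is inferred from that code and should be treated as an assumption rather than documented API:

```python
# Sketch: optimize a small constant matrix with the new cmvm API.
import numpy as np
from da4ml.cmvm import solve

# Rows are indexed by input, matching the kernel.shape[0] usage in api.py
kernel = np.array([[1.0, 3.0], [-2.0, 5.0]])
csol = solve(kernel, method0='wmc', hard_dc=-1)

# Total adder cost across both cascaded stages, mirroring the
# aggregation used inside solve() itself
total_cost = sum(op.cost for sol in csol.solutions for op in sol.ops)
print(total_cost)
```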
da4ml-0.2.0/src/da4ml/cmvm/api.py ADDED
@@ -0,0 +1,257 @@
+ from math import ceil, log2
+
+ import numpy as np
+ from numba import jit, prange
+
+ from .core import _solve, create_state, to_solution
+ from .types import CascadedSolution, QInterval
+ from .util import kernel_decompose
+
+
+ @jit(cache=True)
+ def minimal_latency(
+     kernel: np.ndarray,
+     qintervals: list[QInterval],
+     latencies: list[float],
+     carry_size: int = -1,
+     adder_size: int = -1,
+ ):
+     """Fast latency calculation for a given kernel, QInterval, and input latencies.
+     When carry_size=-1 and the input latency is a constant `l`,
+     this is the same as `l + max(ceil(log2(max(#CSD bits for each column, 1))))`.
+
+     Parameters
+     ----------
+     kernel : np.ndarray
+         The input kernel matrix.
+     qintervals : list[QInterval]
+         List of QIntervals for each input.
+     latencies : list[float]
+         List of latencies for each input.
+     carry_size : int, optional
+         The size of the carry unit for latency computation, by default -1 (fixed latency for each addition operation)
+     adder_size : int, optional
+         The size of the adder unit for latency computation, by default -1 (fixed cost for each addition operation)
+
+     Returns
+     -------
+     float
+         The minimal latency for the given kernel, QInterval, and input latencies.
+     """
+
+     state = create_state(kernel, qintervals, latencies, no_stat_init=True)
+     sol = to_solution(state, adder_size=adder_size, carry_size=carry_size)
+     latencies = [sol.ops[i].latency if i >= 0 else 0.0 for i in sol.out_idxs]
+     return max(latencies)
+
+
+ @jit(cache=True)
+ def jit_solve(
+     kernel: np.ndarray,
+     method0: str = 'wmc',
+     method1: str = 'auto',
+     hard_dc: int = -1,
+     decompose_dc: int = -2,
+     qintervals: list[QInterval] | None = None,
+     latencies: list[float] | None = None,
+     adder_size: int = -1,
+     carry_size: int = -1,
+ ) -> CascadedSolution:
+     """Optimized implementation of a CMVM computation as a cascade of two matrices.
+
+     Parameters
+     ----------
+     kernel : np.ndarray
+         The input kernel matrix to be implemented.
+     method0 : str, optional
+         Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
+     method1 : str, optional
+         Optimization method for the second stage. When 'auto', it is selected based on hard_dc and method0, by default 'auto'
+     hard_dc : int, optional
+         Hard depth constraint (additional latency allowed beyond minimal latency), by default -1 (no constraint)
+     decompose_dc : int, optional
+         Decomposition depth constraint, by default -2 (automatic, derived from hard_dc and the kernel size)
+     qintervals : list[QInterval] | None, optional
+         List of quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
+     latencies : list[float] | None, optional
+         List of input latencies, by default None (0. for all inputs)
+     adder_size : int, optional
+         Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
+     carry_size : int, optional
+         Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)
+
+     Returns
+     -------
+     CascadedSolution
+         A solution containing the optimized implementation of the CMVM computation with cascaded stages.
+     """
+
+     if hard_dc < 0:
+         hard_dc = int(1e9)
+
+     if method1 == 'auto':
+         if hard_dc >= 6 or method0.endswith('dc'):
+             method1 = method0
+         else:
+             method1 = method0 + '-dc'
+         if hard_dc == 0 and not method0.endswith('dc'):
+             method0 = method0 + '-dc'
+
+     if qintervals is None:
+         _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+     else:
+         _qintervals = list(qintervals)
+     if latencies is None:
+         _inp_latencies = [0.0] * kernel.shape[0]
+     else:
+         _inp_latencies = [float(lat) for lat in latencies]
+     assert len(_qintervals) == kernel.shape[0]
+     assert len(_inp_latencies) == kernel.shape[0]
+
+     min_lat = minimal_latency(kernel, _qintervals, _inp_latencies, carry_size=carry_size, adder_size=adder_size)
+     latency_allowed = hard_dc + min_lat
+     if decompose_dc == -2:
+         decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
+     else:
+         decompose_dc = min(hard_dc, decompose_dc, ceil(log2(kernel.shape[0])))
+
+     while True:
+         if decompose_dc < 0 and hard_dc >= 0:
+             if method0 != 'dummy':
+                 method0, method1 = 'wmc-dc', 'wmc-dc'
+             else:
+                 method0, method1 = 'dummy', 'dummy'
+         mat0, mat1 = kernel_decompose(kernel, dc=decompose_dc)
+         sol0 = _solve(
+             mat0, method=method0, qintervals=_qintervals, latencies=_inp_latencies, adder_size=adder_size, carry_size=carry_size
+         )
+         latencies0 = [sol0.ops[i].latency if i >= 0 else 0.0 for i in sol0.out_idxs]
+         qintervals0 = [sol0.ops[i].qint if i >= 0 else QInterval(0.0, 0.0, np.inf) for i in sol0.out_idxs]
+         if max(latencies0) > latency_allowed:
+             if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
+                 decompose_dc -= 1
+                 continue
+         sol1 = _solve(
+             mat1, method=method1, qintervals=qintervals0, latencies=latencies0, adder_size=adder_size, carry_size=carry_size
+         )
+         latencies1 = [sol1.ops[i].latency if i >= 0 else 0.0 for i in sol1.out_idxs]
+         if max(latencies1) > latency_allowed:
+             # Prevent an infinite loop; shouldn't happen in practice
+             if not method0 == method1 == 'wmc-dc' or decompose_dc >= 0:
+                 decompose_dc -= 1
+                 continue
+         if sum([op.cost for op in sol1.ops]) * 4 > sum([op.cost for op in sol0.ops]) and decompose_dc > 0:
+             # If the second stage is too expensive, the decomposition is usually not worth it
+             decompose_dc -= 1
+             continue
+         break
+     if max(latencies1) > latency_allowed:
+         # May still happen when the latency depends on the bitwidths
+         print(f'Latency constraint not satisfied: {int(latency_allowed)} < {int(max(latencies1))}')
+     return CascadedSolution((sol0, sol1))
+
+
+ @jit(cache=True, parallel=True)
+ def solve(
+     kernel: np.ndarray,
+     method0: str = 'wmc',
+     method1: str = 'auto',
+     hard_dc: int = -1,
+     decompose_dc: int = -2,
+     qintervals: tuple[QInterval, ...] | None = None,
+     latencies: tuple[float, ...] | None = None,
+     adder_size: int = -1,
+     carry_size: int = -1,
+     search_all_decompose_dc: bool = True,
+ ) -> CascadedSolution:
+     """Solve the CMVM problem as a cascade of two matrices.
+
+     Parameters
+     ----------
+     kernel : np.ndarray
+         The input kernel matrix to be implemented.
+     method0 : str, optional
+         Optimization method for the first stage. Must be one of [`wmc`, `wmc-dc`, `wmc-pdc`, `mc`, `mc-dc`, `mc-pdc`].
+     method1 : str, optional
+         Optimization method for the second stage. When 'auto', it is selected based on hard_dc and method0, by default 'auto'
+     hard_dc : int, optional
+         Hard depth constraint (additional latency allowed beyond minimal latency), by default -1 (no constraint)
+     decompose_dc : int, optional
+         Decomposition depth constraint, by default -2 (automatic; only used when search_all_decompose_dc is False)
+     qintervals : tuple[QInterval, ...] | None, optional
+         Quantization intervals for each input, by default None ([-128, 127, 1] for all inputs)
+     latencies : tuple[float, ...] | None, optional
+         Input latencies, by default None (0. for all inputs)
+     adder_size : int, optional
+         Size of the adder unit for latency computation, by default -1 (fixed cost for each addition)
+     carry_size : int, optional
+         Size of the carry unit for latency computation, by default -1 (fixed latency for each addition)
+     search_all_decompose_dc : bool, optional
+         If True, search over all possible decomposition depth constraints. If False, use the provided decompose_dc value.
+         Default is True.
+
+     Returns
+     -------
+     CascadedSolution
+         A solution containing the optimized implementation of the CMVM computation with cascaded stages.
+     """
+
+     if qintervals is None:
+         _qintervals = [QInterval(-128.0, 127.0, 1.0)] * kernel.shape[0]
+     else:
+         _qintervals = list(qintervals)
+     if latencies is None:
+         _latencies = [0.0] * kernel.shape[0]
+     else:
+         _latencies = [float(lat) for lat in latencies]
+
+     if not search_all_decompose_dc:
+         return jit_solve(
+             kernel,
+             method0=method0,
+             method1=method1,
+             hard_dc=hard_dc,
+             decompose_dc=decompose_dc,
+             qintervals=_qintervals,
+             latencies=_latencies,
+             adder_size=adder_size,
+             carry_size=carry_size,
+         )
+
+     if hard_dc < 0:
+         hard_dc = int(1e9)
+
+     max_decompose_dc = min(hard_dc, ceil(log2(kernel.shape[0])))
+     try_decompose_dcs = list(range(-1, max_decompose_dc + 1))
+
+     costs = np.empty(len(try_decompose_dcs), dtype=np.float64)
+
+     for i in prange(len(try_decompose_dcs)):
+         decompose_dc = try_decompose_dcs[i]
+         _csol = jit_solve(
+             kernel,
+             method0=method0,
+             method1=method1,
+             hard_dc=hard_dc,
+             decompose_dc=decompose_dc,
+             qintervals=_qintervals,
+             latencies=_latencies,
+             adder_size=adder_size,
+             carry_size=carry_size,
+         )
+         _cost = sum([sum([op.cost for op in sol.ops]) for sol in _csol.solutions])
+         costs[i] = _cost
+
+     decompose_dc = try_decompose_dcs[np.argmin(costs)]
+     csol = jit_solve(
+         kernel,
+         method0=method0,
+         method1=method1,
+         hard_dc=hard_dc,
+         decompose_dc=decompose_dc,
+         qintervals=_qintervals,
+         latencies=_latencies,
+         adder_size=adder_size,
+         carry_size=carry_size,
+     )
+     return csol
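To make the latency formula in `minimal_latency`'s docstring concrete: with `carry_size=-1` and a constant input latency `l`, the result reduces to `l` plus the depth of a balanced binary adder tree over each output column's nonzero CSD digits. A hand-worked sketch, with the digit counts supplied by hand for illustration rather than computed from a kernel:

```python
# Worked illustration of the docstring formula
# l + max(ceil(log2(max(#CSD bits per column, 1))))
from math import ceil, log2

l = 1.0                          # constant input latency
csd_bits_per_column = [5, 3, 1]  # nonzero CSD digits feeding each output column
depth = max(ceil(log2(max(n, 1))) for n in csd_bits_per_column)
latency = l + depth              # 1.0 + ceil(log2(5)) = 1.0 + 3 = 4.0
print(latency)
```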