da4ml 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (50)
  1. da4ml/__init__.py +16 -16
  2. da4ml/_version.py +2 -2
  3. da4ml/cmvm/__init__.py +3 -34
  4. da4ml/cmvm/api.py +239 -73
  5. da4ml/cmvm/core/__init__.py +222 -0
  6. da4ml/cmvm/core/indexers.py +83 -0
  7. da4ml/cmvm/core/state_opr.py +284 -0
  8. da4ml/cmvm/types.py +569 -0
  9. da4ml/cmvm/util/__init__.py +7 -0
  10. da4ml/cmvm/util/bit_decompose.py +86 -0
  11. da4ml/cmvm/util/mat_decompose.py +121 -0
  12. da4ml/codegen/__init__.py +11 -0
  13. da4ml/codegen/cpp/__init__.py +3 -0
  14. da4ml/codegen/cpp/cpp_codegen.py +148 -0
  15. da4ml/codegen/cpp/source/vitis.h +30 -0
  16. da4ml/codegen/cpp/source/vitis_bridge.h +17 -0
  17. da4ml/codegen/verilog/__init__.py +13 -0
  18. da4ml/codegen/verilog/comb.py +146 -0
  19. da4ml/codegen/verilog/io_wrapper.py +255 -0
  20. da4ml/codegen/verilog/pipeline.py +49 -0
  21. da4ml/codegen/verilog/source/build_binder.mk +27 -0
  22. da4ml/codegen/verilog/source/build_prj.tcl +75 -0
  23. da4ml/codegen/verilog/source/ioutils.hh +117 -0
  24. da4ml/codegen/verilog/source/shift_adder.v +56 -0
  25. da4ml/codegen/verilog/source/template.xdc +29 -0
  26. da4ml/codegen/verilog/verilog_model.py +265 -0
  27. da4ml/trace/__init__.py +6 -0
  28. da4ml/trace/fixed_variable.py +358 -0
  29. da4ml/trace/fixed_variable_array.py +177 -0
  30. da4ml/trace/ops/__init__.py +55 -0
  31. da4ml/trace/ops/conv_utils.py +104 -0
  32. da4ml/trace/ops/einsum_utils.py +299 -0
  33. da4ml/trace/pipeline.py +155 -0
  34. da4ml/trace/tracer.py +120 -0
  35. da4ml-0.2.0.dist-info/METADATA +65 -0
  36. da4ml-0.2.0.dist-info/RECORD +39 -0
  37. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL +1 -1
  38. da4ml/cmvm/balanced_reduction.py +0 -46
  39. da4ml/cmvm/cmvm.py +0 -328
  40. da4ml/cmvm/codegen.py +0 -159
  41. da4ml/cmvm/csd.py +0 -73
  42. da4ml/cmvm/fixed_variable.py +0 -205
  43. da4ml/cmvm/graph_compile.py +0 -85
  44. da4ml/cmvm/nb_fixed_precision.py +0 -98
  45. da4ml/cmvm/scoring.py +0 -55
  46. da4ml/cmvm/utils.py +0 -5
  47. da4ml-0.1.2.dist-info/METADATA +0 -122
  48. da4ml-0.1.2.dist-info/RECORD +0 -18
  49. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
  50. {da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/top_level.txt +0 -0
da4ml/trace/pipeline.py ADDED
@@ -0,0 +1,155 @@
+ from math import ceil, floor
+
+ from ..cmvm.types import CascadedSolution, Op, Solution
+ from .fixed_variable import FixedVariable, HWConfig
+ from .tracer import comb_trace
+
+
+ def retime_pipeline(csol: CascadedSolution, verbose=True):
+     n_stages = len(csol[0])
+     cutoff_high = ceil(max(max(sol.out_latency) / (i + 1) for i, sol in enumerate(csol[0])))
+     cutoff_low = 0
+     adder_size, carry_size = csol[0][0].adder_size, csol[0][0].carry_size
+     best = csol
+     while cutoff_high - cutoff_low > 1:
+         cutoff = (cutoff_high + cutoff_low) // 2
+         _hwconf = HWConfig(adder_size, carry_size, cutoff)
+         inp = [FixedVariable(*qint, hwconf=_hwconf) for qint in csol.inp_qint]
+         try:
+             out = list(csol(inp))
+         except AssertionError:
+             cutoff_low = cutoff
+             continue
+         _sol = to_pipeline(comb_trace(inp, out), cutoff, retiming=False)
+         if len(_sol[0]) > n_stages:
+             cutoff_low = cutoff
+         else:
+             cutoff_high = cutoff
+             best = _sol
+     if verbose:
+         print(f'actual cutoff: {cutoff_high}')
+     return best
+
+
+ def to_pipeline(sol: Solution, latency_cutoff: int, retiming=True, verbose=True) -> CascadedSolution:
+     """Split the record into multiple stages based on the latency of the operations.
+     Only useful for HDL generation.
+
+     Parameters
+     ----------
+     sol : Solution
+         The solution to be split into multiple stages.
+     latency_cutoff : int
+         The latency cutoff for splitting the operations.
+     retiming : bool
+         Whether to retime the solution after splitting. Default is True.
+         If False, new stages are created when the propagation latency exceeds the cutoff.
+         If True, after the first round of splitting, the solution is retimed to balance the delay within each stage.
+     verbose : bool
+         Whether to print the actual cutoff used for splitting. Only used if retiming is True.
+         Default is True.
+
+     Returns
+     -------
+     CascadedSolution
+         The cascaded solution with multiple stages.
+     """
+     assert len(sol.ops) > 0, 'No operations in the record'
+     for i, op in enumerate(sol.ops):
+         if op.id1 != -1:
+             break
+
+     def get_stage(op: Op):
+         return floor(op.latency / (latency_cutoff + 1e-9)) if latency_cutoff > 0 else 0
+
+     opd: dict[int, list[Op]] = {}
+     out_idxd: dict[int, list[int]] = {}
+
+     locator: list[dict[int, int]] = []
+
+     ops = sol.ops.copy()
+     lat = max(ops[i].latency for i in sol.out_idxs)
+     for i in sol.out_idxs:
+         op_out = ops[i]
+         ops.append(Op(i, -1001, -1001, 0, op_out.qint, lat, 0.0))
+
+     for i, op in enumerate(ops):
+         stage = get_stage(op)
+         if op.opcode == -1:
+             # Copy from external buffer
+             opd.setdefault(stage, []).append(op)
+             locator.append({stage: len(opd[stage]) - 1})
+             continue
+         p0_stages = locator[op.id0].keys()
+         if stage not in p0_stages:
+             # Need to copy parent to later stages
+             p0_stage = max(p0_stages)
+             p0_idx = locator[op.id0][p0_stage]
+             for j in range(p0_stage, stage):
+                 op0 = ops[op.id0]
+                 latency = float(latency_cutoff * (j + 1))
+                 out_idxd.setdefault(j, []).append(locator[op.id0][j])
+                 _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op0.qint, latency, 0.0)
+                 opd.setdefault(j + 1, []).append(_copy_op)
+                 p0_idx = len(opd[j + 1]) - 1
+                 locator[op.id0][j + 1] = p0_idx
+         else:
+             p0_idx = locator[op.id0][stage]
+
+         if op.opcode in (0, 1):
+             p1_stages = locator[op.id1].keys()
+             if stage not in p1_stages:
+                 # Need to copy parent to later stages
+                 p1_stage = max(p1_stages)
+                 p1_idx = locator[op.id1][p1_stage]
+                 for j in range(p1_stage, stage):
+                     op1 = ops[op.id1]
+                     latency = float(latency_cutoff * (j + 1))
+                     out_idxd.setdefault(j, []).append(locator[op.id1][j])
+                     _copy_op = Op(len(out_idxd[j]) - 1, -1, -1, 0, op1.qint, latency, 0.0)
+                     opd.setdefault(j + 1, []).append(_copy_op)
+                     p1_idx = len(opd[j + 1]) - 1
+                     locator[op.id1][j + 1] = p1_idx
+             else:
+                 p1_idx = locator[op.id1][stage]
+         else:
+             p1_idx = op.id1
+
+         if p1_idx == -1001:
+             # Output to external buffer
+             out_idxd.setdefault(stage, []).append(p0_idx)
+         else:
+             _Op = Op(p0_idx, p1_idx, op.opcode, op.data, op.qint, op.latency, op.cost)
+             opd.setdefault(stage, []).append(_Op)
+             locator.append({stage: len(opd[stage]) - 1})
+     sols = []
+     max_stage = max(opd.keys())
+     for i, stage in enumerate(opd.keys()):
+         _ops = opd[stage]
+         _out_idx = out_idxd[stage]
+         n_in = sum(op.opcode == -1 for op in _ops)
+         n_out = len(_out_idx)
+
+         if i == max_stage:
+             out_shifts = sol.out_shifts
+             out_negs = sol.out_negs
+         else:
+             out_shifts = [0] * len(_out_idx)
+             out_negs = [False] * len(_out_idx)
+
+         _sol = Solution(
+             shape=(n_in, n_out),
+             inp_shift=[0] * n_in,
+             out_idxs=_out_idx,
+             out_shifts=out_shifts,
+             out_negs=out_negs,
+             ops=_ops,
+             carry_size=sol.carry_size,
+             adder_size=sol.adder_size,
+         )
+         sols.append(_sol)
+     csol = CascadedSolution(tuple(sols))
+
+     if retiming:
+         csol = retime_pipeline(csol, verbose=verbose)
+     return csol
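
A note on the `1e-9` in `get_stage` above: the epsilon keeps an operation whose latency is exactly a multiple of `latency_cutoff` in the lower stage rather than opening a new one. A minimal, self-contained illustration (the latency values here are made up):

```python
from math import floor

# Stage assignment as in get_stage() above: with latency_cutoff = 8,
# latencies in [0, 8] land in stage 0 and (8, 16] in stage 1.
latency_cutoff = 8
for latency in (0.0, 7.5, 8.0, 8.5, 23.0):
    stage = floor(latency / (latency_cutoff + 1e-9)) if latency_cutoff > 0 else 0
    print(f'latency {latency} -> stage {stage}')  # prints stages 0, 0, 0, 1, 2
```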
da4ml/trace/tracer.py ADDED
@@ -0,0 +1,120 @@
+ from collections.abc import Sequence
+ from decimal import Decimal
+ from math import log2
+ from typing import overload
+ from uuid import UUID
+
+ import numpy as np
+
+ from ..cmvm.types import Op, QInterval, Solution
+ from .fixed_variable import FixedVariable, _const_f
+ from .fixed_variable_array import FixedVariableArray
+
+
+ def _recursive_trace(v: FixedVariable, gathered: dict[UUID, FixedVariable]):
+     if v.id in gathered:
+         return
+     assert v._from is not None
+     for _v in v._from:
+         if _v.id not in gathered:
+             _recursive_trace(_v, gathered)
+     gathered[v.id] = v
+
+
+ def gather_variables(inputs: Sequence[FixedVariable], outputs: Sequence[FixedVariable]):
+     gathered = {v.id: v for v in inputs}
+     for o in outputs:
+         _recursive_trace(o, gathered)
+
+     variables = list(gathered.values())
+
+     N = len(variables)
+     _index = sorted(list(range(N)), key=lambda i: variables[i].latency * N + i)
+     variables = [variables[i] for i in _index]
+     index = {variables[i].id: i for i in range(N)}
+
+     return variables, index
+
+
+ def _comb_trace(inputs: Sequence[FixedVariable], outputs: Sequence[FixedVariable]):
+     variables, index = gather_variables(inputs, outputs)
+     ops: list[Op] = []
+     inp_uuids = {v.id: i for i, v in enumerate(inputs)}
+     for i, v in enumerate(variables):
+         if v.id in inp_uuids and v.opr != 'const':
+             id0 = inp_uuids[v.id]
+             ops.append(Op(id0, -1, -1, 0, v.unscaled.qint, v.latency, v.cost))
+             continue
+         if v.opr == 'new':
+             raise NotImplementedError('Operation "new" is only expected in the input list')
+         match v.opr:
+             case 'vadd':
+                 v0, v1 = v._from
+                 f0, f1 = v0._factor, v1._factor
+                 id0, id1 = index[v0.id], index[v1.id]
+                 sub = int(f1 < 0)
+                 data = int(log2(abs(f1 / f0)))
+                 assert id0 < i and id1 < i, f'{id0} {id1} {i} {v.id}'
+                 ops.append(Op(id0, id1, sub, data, v.unscaled.qint, v.latency, v.cost))
+             case 'cadd':
+                 v0 = v._from[0]
+                 f0 = v0._factor
+                 id0 = index[v0.id]
+                 assert v._data is not None, 'cadd must have data'
+                 qint = v.unscaled.qint
+                 data = int(v._data / Decimal(qint.step))
+                 assert id0 < i, f'{id0} {i} {v.id}'
+                 ops.append(Op(id0, -1, 4, data, qint, v.latency, v.cost))
+             case 'wrap':
+                 v0 = v._from[0]
+                 id0 = index[v0.id]
+                 assert id0 < i, f'{id0} {i} {v.id}'
+                 opcode = -3 if v._from[0]._factor < 0 else 3
+                 ops.append(Op(id0, -1, opcode, 0, v.unscaled.qint, v.latency, v.cost))
+             case 'relu':
+                 v0 = v._from[0]
+                 id0 = index[v0.id]
+                 assert id0 < i, f'{id0} {i} {v.id}'
+                 opcode = -2 if v._from[0]._factor < 0 else 2
+                 ops.append(Op(id0, -1, opcode, 0, v.unscaled.qint, v.latency, v.cost))
+             case 'const':
+                 qint = v.unscaled.qint
+                 assert qint.min == qint.max, f'const {v.id} {qint.min} {qint.max}'
+                 f = _const_f(qint.min)
+                 step = 2.0**-f
+                 qint = QInterval(qint.min, qint.min, step)
+                 data = qint.min / step
+                 ops.append(Op(-1, -1, 5, int(data), qint, v.latency, v.cost))
+             case _:
+                 raise NotImplementedError(f'Operation "{v.opr}" is not supported in tracing')
+     out_index = [index[v.id] for v in outputs]
+     return ops, out_index
+
+
+ @overload
+ def comb_trace(inputs: Sequence[FixedVariable], outputs: Sequence[FixedVariable]) -> Solution: ...
+
+
+ @overload
+ def comb_trace(inputs: FixedVariableArray, outputs: FixedVariableArray) -> Solution: ...
+
+
+ def comb_trace(inputs, outputs):
+     inputs, outputs = list(np.ravel(inputs)), list(np.ravel(outputs))
+     ops, out_index = _comb_trace(inputs, outputs)
+     shape = len(inputs), len(outputs)
+     inp_shift = [0] * shape[0]
+     out_sf = [v._factor for v in outputs]
+     out_shift = [int(log2(abs(sf))) for sf in out_sf]
+     out_neg = [sf < 0 for sf in out_sf]
+
+     return Solution(
+         shape,
+         inp_shift,
+         out_index,
+         out_shift,
+         out_neg,
+         ops,
+         outputs[0].hwconf.carry_size,
+         outputs[0].hwconf.adder_size,
+     )
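
Taken together with `pipeline.py` above, a hedged end-to-end sketch of the new tracing path. The `FixedVariable(min, max, step, hwconf=...)` and `HWConfig(adder_size, carry_size, cutoff)` call shapes mirror their use in `retime_pipeline`; that `+` on two traced variables records a `vadd` node and that the chosen interval values are sensible are assumptions, not confirmed by this diff:

```python
from da4ml.trace.fixed_variable import FixedVariable, HWConfig
from da4ml.trace.pipeline import to_pipeline
from da4ml.trace.tracer import comb_trace

# Illustrative hardware config: (adder_size, carry_size, cutoff), ordered as in retime_pipeline.
hwconf = HWConfig(2, 8, -1)

# Two quantized inputs over the interval [-8, 7] with step 1 (values are assumptions).
a = FixedVariable(-8.0, 7.0, 1.0, hwconf=hwconf)
b = FixedVariable(-8.0, 7.0, 1.0, hwconf=hwconf)

sol = comb_trace([a, b], [a + b])          # trace the combinational graph into a Solution
print(sol.shape)                           # (2, 1): two inputs, one output
csol = to_pipeline(sol, latency_cutoff=8)  # optionally split into pipeline stages
```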
da4ml-0.2.0.dist-info/METADATA ADDED
@@ -0,0 +1,65 @@
+ Metadata-Version: 2.4
+ Name: da4ml
+ Version: 0.2.0
+ Summary: Digital Arithmetic for Machine Learning
+ Author-email: Chang Sun <chsun@cern.ch>
+ License: GNU Lesser General Public License v3 (LGPLv3)
+ Project-URL: repository, https://github.com/calad0i/da4ml
+ Keywords: CMVM,distributed arithmetic,hls4ml,MCM,subexpression elimination
+ Classifier: Development Status :: 4 - Beta
+ Classifier: License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: llvmlite>=0.44
+ Requires-Dist: numba>=0.61
+ Dynamic: license-file
+
+ # da4ml: Distributed Arithmetic for Machine Learning
+
+ This project performs Constant Matrix-Vector Multiplication (CMVM) with Distributed Arithmetic (DA) for Machine Learning (ML) on Field Programmable Gate Arrays (FPGAs).
+
+ CMVM optimization is done through greedy common subexpression elimination (CSE) of two-term subexpressions, with optional Delay Constraints (DC). The optimization is done in jitted Python (Numba), and the list of optimized operations is generated as traced Python code.
+
+ At the moment, the project only generates Vitis HLS C++ code for the FPGA implementation of the optimized CMVM kernel. HDL code generation is planned for the future. Currently, the major use of this repository is through the `distributed_arithmetic` strategy in the [`hls4ml`](https://github.com/fastmachinelearning/hls4ml/) project.
+
+
+ ## Installation
+
+ The project is available on PyPI and can be installed with pip:
+
+ ```bash
+ pip install da4ml
+ ```
+
+ Notice that `numba>=0.61` is required for the project to work. The project does not work with `python<3.10`. If the project fails to compile, try upgrading `numba` and `llvmlite` to the latest versions.
+
+ ## `hls4ml`
+
+ The major use of this project is through the `distributed_arithmetic` strategy in `hls4ml`:
+
+ ```python
+ model_hls = hls4ml.converters.convert_from_keras_model(
+     model,
+     hls_config={
+         'Model': {
+             ...
+             'Strategy': 'distributed_arithmetic',
+         },
+         ...
+     },
+     ...
+ )
+ ```
+
+ Currently, `Dense/Conv1D/Conv2D` layers are supported for both `io_parallel` and `io_stream` dataflows. However, notice that distributed arithmetic implies `reuse_factor=1`, as the whole kernel is implemented in combinational logic.
+
+ ### Notice
+
+ Currently, only the `da4ml-v3` branch of `hls4ml` supports the `distributed_arithmetic` strategy. The `da4ml-v3` branch is not yet merged into the `main` branch of `hls4ml`, so you need to install it from the GitHub repository.
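
For the notice above, a hedged install command for that branch, using standard pip VCS syntax with the `hls4ml` repository URL linked earlier in this README (the exact command is not given in the package metadata):

```bash
pip install "git+https://github.com/fastmachinelearning/hls4ml@da4ml-v3"
```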
da4ml-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,39 @@
+ da4ml/__init__.py,sha256=IETRRvzsJvPMLu1kzzi8UN5FYaM5MhNaXH2A_ZKr2_w,469
+ da4ml/_version.py,sha256=iB5DfB5V6YB5Wo4JmvS-txT42QtmGaWcWp3udRT7zCI,511
+ da4ml/cmvm/__init__.py,sha256=4Tbt913k9zP0w8R1p6Oss06v5jrManbUhskyHl6e-U0,154
+ da4ml/cmvm/api.py,sha256=wwJoEnKtv86q2eyIV2yh_DXIpJPgIW826WYweU8YxBY,9891
+ da4ml/cmvm/types.py,sha256=7Z6fSMZU3TcrivWCCWzRuygrSF2SIHydF1jR04-WWZc,17855
+ da4ml/cmvm/core/__init__.py,sha256=FRqePEE3SXSSTSjSvByw6QDJkzHCeyvyeihxgxkXVJI,7721
+ da4ml/cmvm/core/indexers.py,sha256=QjXgvExS-B2abHTJPDG4NufMdMEflo1i6cUhFOgJpH4,2945
+ da4ml/cmvm/core/state_opr.py,sha256=wLqO8qVuM2-qCE5LDeYJDNkUruIPHy63obsv4-x-aR8,8661
+ da4ml/cmvm/util/__init__.py,sha256=DkBlUEKA_Gu7n576ja_xZlAQfToWmNL9VXU-jmj6a-g,145
+ da4ml/cmvm/util/bit_decompose.py,sha256=SUco70HRYf4r1JU6BXwcgabDrhm_yAmucae5FC67i4I,2216
+ da4ml/cmvm/util/mat_decompose.py,sha256=eSJNlXwx_jxgqt5vLJrSLQaeq2ZXu8j9mC4d-eq883M,4094
+ da4ml/codegen/__init__.py,sha256=g58EgubgPPoiwRTBduSzm6hAc-poPcK6egdoECfPx9o,329
+ da4ml/codegen/cpp/__init__.py,sha256=Tw4XeU_oJsyUkTrsfEPuZ-r0rGAo8E2NX5wn_VTA7NM,90
+ da4ml/codegen/cpp/cpp_codegen.py,sha256=gfnDKdmdrubLp5ANeRXvYGTezCMGMet-IQaDlPymaHM,4787
+ da4ml/codegen/cpp/source/vitis.h,sha256=ovEefBOfW5-PXuDdRObPGNokGGFHiixDCpPWeTN6aTo,765
+ da4ml/codegen/cpp/source/vitis_bridge.h,sha256=XvvGw3A4eAaXKi5jp50bMKUsNfd5iQ-HhUKtsty1uns,567
+ da4ml/codegen/verilog/__init__.py,sha256=obRTdtMWhPHsxFHg2ADoPd3iDBEX8nk_6HuCet5EDz0,356
+ da4ml/codegen/verilog/comb.py,sha256=EZONCceEvIKHHF8yLY-i2V_U_8THw_dJEQWujjCJ5iI,5592
+ da4ml/codegen/verilog/io_wrapper.py,sha256=W7c9_jKRXqrcCm3M13NP7OJdbWtc3FSJgwjZlMEBDL4,7953
+ da4ml/codegen/verilog/pipeline.py,sha256=oi8qaobHnTk7gpzpvctuLxKzRIaRJPRFb3BYnnVSSsw,1557
+ da4ml/codegen/verilog/verilog_model.py,sha256=ewZuu653ZkhTmf0vRfn6QHYewLoEPdAl4sNwptlV0-M,10585
+ da4ml/codegen/verilog/source/build_binder.mk,sha256=rQbI98itE_b1wIQ_0uCXfBzNmGK2XT4vWmRyCJNnPKk,960
+ da4ml/codegen/verilog/source/build_prj.tcl,sha256=DMdIDBOF5stkm9Bknk6MiDpIK9-n85h6qwn4B6sV22A,3175
+ da4ml/codegen/verilog/source/ioutils.hh,sha256=1o1-oIyQyYc9CU91bBxuitVzzcrNT8p4MTarFKiJoG4,3967
+ da4ml/codegen/verilog/source/shift_adder.v,sha256=l2ofym56Y-_PeeY9fwkcZeW9MzrTL_WxvSTvoWERJrU,1885
+ da4ml/codegen/verilog/source/template.xdc,sha256=ON8i-TK96Yo6FoZ66WzcVKELajTF5TBmbWFbEilna2U,1142
+ da4ml/trace/__init__.py,sha256=1br9bWeFb33t69k6h1XQ50iJhLCqrRuEHtqEawELp-c,230
+ da4ml/trace/fixed_variable.py,sha256=KPkYnJgk8bK3W_D6pwCHEndoIgyRn4wyPzTye5Pnh1g,11647
+ da4ml/trace/fixed_variable_array.py,sha256=zYuYaXYK-LVcAEUDQ-TxKpR2a_B30RwBIpODBlp2Aq8,6400
+ da4ml/trace/pipeline.py,sha256=dYduPBNUeyW2Ws392hZNGJEo0qI5ynpn-iC2n7UVahk,5687
+ da4ml/trace/tracer.py,sha256=y0o_KeXCmlUwTLOxwmqZZsZm0oA8-kfu41J-I4-6LXU,4385
+ da4ml/trace/ops/__init__.py,sha256=qz0DLPUyxBAu08RCN22kCkJj1EPKanC8ey8NB3_K8co,1640
+ da4ml/trace/ops/conv_utils.py,sha256=LtgP3iSZ3fNV6QkEVBzT7ixt-7WTdmBDrFTtQ_9D5aE,3638
+ da4ml/trace/ops/einsum_utils.py,sha256=miyMyzJwBLpLTEzXU4vErPE1Xk-ckZG0cjhd13MLAuA,11325
+ da4ml-0.2.0.dist-info/licenses/LICENSE,sha256=46mU2C5kSwOnkqkw9XQAJlhBL2JAf1_uCD8lVcXyMRg,7652
+ da4ml-0.2.0.dist-info/METADATA,sha256=LdaPbfaPUbMKwLjydibFiFT7qSTs6XuILpnEkN7qEgs,2849
+ da4ml-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ da4ml-0.2.0.dist-info/top_level.txt,sha256=N0tnKVwRqFiffFdeAzCgFq71hUNySh5-ITbNd6-R58Q,6
+ da4ml-0.2.0.dist-info/RECORD,,
{da4ml-0.1.2.dist-info → da4ml-0.2.0.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (77.0.3)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

da4ml/cmvm/balanced_reduction.py DELETED
@@ -1,46 +0,0 @@
- import numpy as np
-
- from .fixed_variable import FixedVariable
- from .scoring import py_scorer
-
-
- def _balanced_reduction(vars: list[FixedVariable]):
-     vars = vars.copy()
-
-     n = len(vars)
-     if n == 0:
-         return FixedVariable.from_const(0.0)
-     score_mat = np.full((len(vars), len(vars)), -np.inf, dtype=np.float32)
-     for i in range(len(vars)):
-         for j in range(i + 1, len(vars)):
-             score_mat[i, j] = py_scorer(vars[i], vars[j]) - 1000 * (vars[i]._depth + vars[j]._depth)
-
-     while n > 1:
-         idx = np.argmax(score_mat)
-         i, j = np.unravel_index(idx, score_mat.shape)
-         vars[i] = vars[i] + vars[j]
-         vars.pop(j)
-         score_mat[j : n - 1] = score_mat[j + 1 : n]
-         score_mat[:, j : n - 1] = score_mat[:, j + 1 : n]
-         score_mat = score_mat[: n - 1, : n - 1]
-         n -= 1
-         for k in range(n):
-             if k == i:
-                 continue
-             if k < i:
-                 score_mat[k, i] = py_scorer(vars[k], vars[i]) - 1000 * (vars[k]._depth + vars[i]._depth)
-             else:
-                 score_mat[i, k] = py_scorer(vars[i], vars[k]) - 1000 * (vars[i]._depth + vars[k]._depth)
-
-     return vars[0]
-
-
- def balanced_reduction(vars: list[FixedVariable], signed=True):
-     if not signed:
-         return _balanced_reduction(vars)
-     pos_vars = [v for v in vars if v._factor > 0]
-     neg_vars = [v for v in vars if v._factor < 0]
-     for v in neg_vars:
-         v._factor = -v._factor
-
-     return _balanced_reduction(pos_vars) - _balanced_reduction(neg_vars)