da4ml-0.3.2-py3-none-any.whl → da4ml-0.3.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


da4ml/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.3.2'
- __version_tuple__ = version_tuple = (0, 3, 2)
+ __version__ = version = '0.3.3'
+ __version_tuple__ = version_tuple = (0, 3, 3)

  __commit_id__ = commit_id = None
da4ml/codegen/cpp/cpp_codegen.py CHANGED
@@ -46,7 +46,7 @@ def ssa_gen(sol: Solution, print_latency: bool, typestr_fn: Callable[[bool | int
  match op.opcode:
  case -1:
  # Input marker
- val = f'inp[{ops[op.id0].id0}]'
+ val = f'inp[{op.id0}]'
  case 0 | 1:
  # Common a+/-b<<shift op
  ref1 = f'bit_shift<{op.data}>(v{op.id1})' if op.data != 0 else f'v{op.id1}'
da4ml/codegen/cpp/source/binder_util.hh CHANGED
@@ -9,31 +9,26 @@ constexpr bool _openmp = true;
  constexpr bool _openmp = false;
  #endif

- template <typename CONFIG_T, typename T> void _inference(T *c_inp, T *c_out, size_t n_samples)
- {
+ template <typename CONFIG_T, typename T> void _inference(T *c_inp, T *c_out, size_t n_samples) {
  typename CONFIG_T::inp_t in_fixed_buf[CONFIG_T::N_inp];
  typename CONFIG_T::out_t out_fixed_buf[CONFIG_T::N_out];

- for(size_t i = 0; i < n_samples; ++i)
- {
- size_t offset_in = i * CONFIG_T::N_inp;
- size_t offset_out = i * CONFIG_T::N_out;
- for(size_t j = 0; j < CONFIG_T::N_inp; ++j)
- {
- in_fixed_buf[j] = c_inp[offset_in + j];
- }
+ for (size_t i = 0; i < n_samples; ++i) {
+ size_t offset_in = i * CONFIG_T::N_inp;
+ size_t offset_out = i * CONFIG_T::N_out;
+ for (size_t j = 0; j < CONFIG_T::N_inp; ++j) {
+ in_fixed_buf[j] = c_inp[offset_in + j];
+ }

- CONFIG_T::f(in_fixed_buf, out_fixed_buf);
+ CONFIG_T::f(in_fixed_buf, out_fixed_buf);

- for(size_t j = 0; j < CONFIG_T::N_out; ++j)
- {
- c_out[offset_out + j] = out_fixed_buf[j];
- }
+ for (size_t j = 0; j < CONFIG_T::N_out; ++j) {
+ c_out[offset_out + j] = out_fixed_buf[j];
  }
+ }
  }

- template <typename CONFIG_T, typename T> void batch_inference(T *c_inp, T *c_out, size_t n_samples)
- {
+ template <typename CONFIG_T, typename T> void batch_inference(T *c_inp, T *c_out, size_t n_samples) {
  #ifdef _OPENMP
  size_t n_max_threads = omp_get_max_threads();
  size_t n_samples_per_thread = std::max<size_t>(n_samples / n_max_threads, 32);
@@ -41,15 +36,14 @@ template <typename CONFIG_T, typename T> void batch_inference(T *c_inp, T *c_out
  n_thread += (n_samples % n_samples_per_thread) ? 1 : 0;

  #pragma omp parallel for num_threads(n_thread) schedule(static)
- for(size_t i = 0; i < n_thread; ++i)
- {
- size_t start = i * n_samples_per_thread;
- size_t end = std::min<size_t>(start + n_samples_per_thread, n_samples);
- size_t n_samples_this_thread = end - start;
- size_t offset_in = start * CONFIG_T::N_inp;
- size_t offset_out = start * CONFIG_T::N_out;
- _inference<CONFIG_T, T>(&c_inp[offset_in], &c_out[offset_out], n_samples_this_thread);
- }
+ for (size_t i = 0; i < n_thread; ++i) {
+ size_t start = i * n_samples_per_thread;
+ size_t end = std::min<size_t>(start + n_samples_per_thread, n_samples);
+ size_t n_samples_this_thread = end - start;
+ size_t offset_in = start * CONFIG_T::N_inp;
+ size_t offset_out = start * CONFIG_T::N_out;
+ _inference<CONFIG_T, T>(&c_inp[offset_in], &c_out[offset_out], n_samples_this_thread);
+ }
  #else
  _inference<CONFIG_T, T>(c_inp, c_out, n_samples);
  #endif
da4ml/codegen/cpp/source/vitis_bitshift.hh CHANGED
@@ -1,14 +1,16 @@
  #pragma once
- #include "ap_types/ap_fixed.h"
+ #include "ap_fixed.h"

- template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_fixed<b, i + s> bit_shift(ap_fixed<b, i, Q, O, N> x) {
+ template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N>
+ ap_fixed<b, i + s> bit_shift(ap_fixed<b, i, Q, O, N> x) {
  #pragma HLS INLINE
  ap_fixed<b, i + s> r;
  r.range() = x.range();
  return r;
  };

- template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N> ap_ufixed<b, i + s> bit_shift(ap_ufixed<b, i, Q, O, N> x) {
+ template <int s, int b, int i, ap_q_mode Q, ap_o_mode O, int N>
+ ap_ufixed<b, i + s> bit_shift(ap_ufixed<b, i, Q, O, N> x) {
  #pragma HLS INLINE
  ap_ufixed<b, i + s> r;
  r.range() = x.range();
da4ml/converter/hgq2/parser.py CHANGED
@@ -6,8 +6,8 @@ import keras
  import numpy as np
  from keras import KerasTensor, Operation

- from ...trace import FixedVariableArray, HWConfig, comb_trace
- from ...trace.fixed_variable_array import FixedVariableArrayInput
+ from ...trace import FixedVariableArray, FixedVariableArrayInput, HWConfig, comb_trace
+ from ...trace.fixed_variable import FixedVariable
  from .replica import _registry


@@ -55,6 +55,8 @@ def replace_tensors(tensor_map: dict[KerasTensor, FixedVariableArray], obj: Any)
  def _flatten_arr(args: Any) -> FixedVariableArray:
  if isinstance(args, FixedVariableArray):
  return np.ravel(args) # type: ignore
+ if isinstance(args, FixedVariable):
+ return FixedVariableArray(np.array([args]))
  if not isinstance(args, Sequence):
  return None # type: ignore
  args = [_flatten_arr(a) for a in args]
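A side note on the `_flatten_arr` change above (a reading of the diff, not project documentation): a bare `FixedVariable` is now wrapped into a one-element `FixedVariableArray`, so callers that expect array semantics keep working. A tiny standalone sketch of the wrapping idea, where `_Var` is a hypothetical stand-in and not a da4ml type:

    import numpy as np

    class _Var:
        """Hypothetical stand-in for a single traced variable."""

    v = _Var()
    wrapped = np.array([v])  # object array of shape (1,), mirroring FixedVariableArray(np.array([args]))
    print(wrapped.shape, wrapped[0] is v)  # (1,) True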
da4ml/trace/fixed_variable.py CHANGED
@@ -265,6 +265,10 @@ class FixedVariable:
  def __sub__(self, other: 'FixedVariable|int|float|Decimal'):
  return self + (-other)

+ def __truediv__(self, other: 'int|float|Decimal'):
+ assert not isinstance(other, FixedVariable), 'Division by variable is not supported'
+ return self * (1 / other)
+

  def __mul__(self, other: 'FixedVariable|int|float|Decimal') -> 'FixedVariable':
  if other == 0:
  return FixedVariable(0, 0, 1, hwconf=self.hwconf, opr='const')
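What the new `__truediv__` appears to enable (an interpretation of the diff, not official da4ml documentation): dividing a traced variable by a compile-time constant is rewritten as multiplication by the reciprocal, and dividing by another variable is rejected. A self-contained sketch of the same pattern with a hypothetical stand-in class (`_Scaled` is not part of da4ml):

    from decimal import Decimal

    class _Scaled:
        """Hypothetical stand-in for a value that only supports scaling by constants."""

        def __init__(self, value: float):
            self.value = value

        def __mul__(self, other: 'int | float | Decimal') -> '_Scaled':
            return _Scaled(self.value * float(other))

        def __truediv__(self, other: 'int | float | Decimal') -> '_Scaled':
            # Same approach as the diff: refuse variable divisors, rewrite x / c as x * (1 / c).
            assert not isinstance(other, _Scaled), 'Division by variable is not supported'
            return self * (1 / other)

    x = _Scaled(3.0)
    print((x / 4).value)  # 0.75, i.e. 3.0 * 0.25

Reusing `__mul__` keeps constant division on the existing constant-multiplication path.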
da4ml/trace/fixed_variable_array.py CHANGED
@@ -164,6 +164,10 @@ class FixedVariableArray:
  latency: NDArray[np.floating] | float = 0.0,
  solver_options: dict[str, Any] | None = None,
  ):
+ mask = k + i + f <= 0
+ k = np.where(mask, 0, k)
+ i = np.where(mask, 0, i)
+ f = np.where(mask, 0, f)
  step = 2.0**-f
  _high = 2.0**i
  high, low = _high - step, -_high * k
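What the added guard appears to do (again an interpretation of the diff, not an official changelog entry): elements whose total width k + i + f is non-positive are clamped to k = i = f = 0 before the range computation, so their representable range collapses to a single value instead of becoming inconsistent. A standalone numpy sketch with made-up widths:

    import numpy as np

    # Hypothetical per-element widths: sign bit k, integer bits i, fractional bits f.
    k = np.array([1, 0, 1])
    i = np.array([3, -2, 0])
    f = np.array([4, 1, -2])

    mask = k + i + f <= 0  # True where the total bit width is non-positive
    k = np.where(mask, 0, k)
    i = np.where(mask, 0, i)
    f = np.where(mask, 0, f)

    # Same range computation as in the diff context above.
    step = 2.0**-f
    _high = 2.0**i
    high, low = _high - step, -_high * k
    print(high)  # [7.9375 0.     0.    ]
    print(low)   # [-8. -0. -0.]  -> clamped entries end up with high == low == 0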
da4ml/trace/ops/reduce_utils.py CHANGED
@@ -99,7 +99,7 @@ def reduce(operator: Callable[[T, T], T], x: TA, axis: int | Sequence[int] | Non
  r = _arr.reshape(target_shape) # type: ignore

  if isinstance(x, FixedVariableArray):
- ret = FixedVariableArray(r, solver_config)
- if ret.size == 1 and not keepdims:
- return ret.ravel()[0] # type: ignore
+ r = FixedVariableArray(r, solver_config)
+ if r.size == 1 and not keepdims:
+ return r.ravel()[0] # type: ignore
  return r if r.size > 1 or keepdims else r.ravel()[0] # type: ignore
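As far as the diff shows (an interpretation, not an authoritative changelog): in 0.3.2 the wrapped result was bound to a separate name `ret`, so when the reduction kept more than one element, or `keepdims` was set, the final `return` handed back the raw reshaped array rather than the `FixedVariableArray`; rebinding `r` keeps the wrapper on every return path. A condensed, self-contained sketch of the fixed tail, where `_Wrapped` and `_reduce_tail` are illustrative stand-ins:

    import numpy as np

    class _Wrapped(np.ndarray):
        """Hypothetical stand-in for FixedVariableArray; only makes the wrapping visible."""

    def _reduce_tail(r, keepdims):
        # 0.3.3 behaviour: rebind r so every return path below sees the wrapped value.
        r = r.view(_Wrapped)  # stands in for FixedVariableArray(r, solver_config)
        if r.size == 1 and not keepdims:
            return r.ravel()[0]
        return r if r.size > 1 or keepdims else r.ravel()[0]

    print(type(_reduce_tail(np.arange(4.0), keepdims=False)).__name__)  # _Wrapped, not ndarray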
da4ml-0.3.2.dist-info/METADATA → da4ml-0.3.3.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: da4ml
- Version: 0.3.2
+ Version: 0.3.3
  Summary: Digital Arithmetic for Machine Learning
  Author-email: Chang Sun <chsun@cern.ch>
  License: GNU Lesser General Public License v3 (LGPLv3)
da4ml-0.3.2.dist-info/RECORD → da4ml-0.3.3.dist-info/RECORD
@@ -1,5 +1,5 @@
  da4ml/__init__.py,sha256=IETRRvzsJvPMLu1kzzi8UN5FYaM5MhNaXH2A_ZKr2_w,469
- da4ml/_version.py,sha256=e8NqPtZ8fggRgk3GPrqZ_U_BDV8aSULw1u_Gn9NNbnk,704
+ da4ml/_version.py,sha256=lemL_4Kl75FgrO6lVuFrrtw6-Dcf9wtXBalKkXuzkO4,704
  da4ml/cmvm/__init__.py,sha256=4Tbt913k9zP0w8R1p6Oss06v5jrManbUhskyHl6e-U0,154
  da4ml/cmvm/api.py,sha256=JpecMt6g8zutGh_uWT61_0iX8TuXct7-jq7N7HMIsgA,9626
  da4ml/cmvm/types.py,sha256=O8BuBZ2SyucxoXt_KbulAuHNgim7Ls3M6Ovw8prLgXM,21340
@@ -11,11 +11,11 @@ da4ml/cmvm/util/bit_decompose.py,sha256=SUco70HRYf4r1JU6BXwcgabDrhm_yAmucae5FC67
  da4ml/cmvm/util/mat_decompose.py,sha256=eSJNlXwx_jxgqt5vLJrSLQaeq2ZXu8j9mC4d-eq883M,4094
  da4ml/codegen/__init__.py,sha256=Chdh3oO_vLR4saLbT9VxBPz_0wlEzxJldFSZaVUJo7U,331
  da4ml/codegen/cpp/__init__.py,sha256=SIePoi_T4iJph50OQUosAnaVuLCckukYjLxp91Y8xQs,134
- da4ml/codegen/cpp/cpp_codegen.py,sha256=ot293c8aHBx7wy1R7hnB9IVI22jYMO0476ghYKD8ECA,6162
+ da4ml/codegen/cpp/cpp_codegen.py,sha256=I3YcxK524_oJ7jebxOlRGuYbN2uCY5mpKACoQShqZxs,6153
  da4ml/codegen/cpp/hls_model.py,sha256=J5lnB8sAvMy0Bo5MSJOpgyUm1tzEJqBxgPTlOd38Gbg,8978
- da4ml/codegen/cpp/source/binder_util.hh,sha256=pBVmhXIDvdCr8n2wwYehc3Fpp60sWYrrZaDoP3x9JZE,1880
+ da4ml/codegen/cpp/source/binder_util.hh,sha256=ClECVxcEynE_9i4jWCV4y1dnadG3wFqLZfjxg4qHFQQ,1752
  da4ml/codegen/cpp/source/build_binder.mk,sha256=RLu4TP28aJsveyMOHxuDRGEJVoIPMo9T8WyPtqnmtbQ,584
- da4ml/codegen/cpp/source/vitis_bitshift.hh,sha256=yFpYCVJ8gof-EzPjkIWWZYmdFh_wk133Pxzs7f61IQo,774
+ da4ml/codegen/cpp/source/vitis_bitshift.hh,sha256=u8wjT_cRn7bXcbC5pH3-rS76ekRbwv-VWAAdaP52-dw,765
  da4ml/codegen/cpp/source/ap_types/ap_binary.h,sha256=yOcafu2IofstDqxn0wDq8vY3JIwZQ9H5z6IY1dEqMr0,2764
  da4ml/codegen/cpp/source/ap_types/ap_common.h,sha256=1hJY9uvKOdwRSSll5uehUISZR4tsSsQ1z4PNRUc44KU,10180
  da4ml/codegen/cpp/source/ap_types/ap_decl.h,sha256=z1HsH-2RSvSoofTZR7RHeqIfAnEYVuHcIu_ute9gjEg,6473
@@ -48,19 +48,19 @@ da4ml/codegen/verilog/source/shift_adder.v,sha256=qrpXBX9bhHI-o75v5zshOfq0giEATv
  da4ml/codegen/verilog/source/template.xdc,sha256=GlSRy8tw_orohSuUwUSNEYJLLkAAHttGTfLTcQqRQDg,1262
  da4ml/converter/__init__.py,sha256=x7J2PEXYZsVWffRAkucLxbwzzU404eaijMdLwdhBxtY,57
  da4ml/converter/hgq2/__init__.py,sha256=-gnT_7zXY-KQtPLxsqngwDKZ2TUIynn996pUjjB03B8,59
- da4ml/converter/hgq2/parser.py,sha256=O55QTrlkev0lvxiIweXlTGG9RPcfjdrJgpkZc-rwetg,5472
+ da4ml/converter/hgq2/parser.py,sha256=Yc5V-B_aEslqIXXJihRi3GMjF9vMkmUQ2_yHMGHMPVo,5573
  da4ml/converter/hgq2/replica.py,sha256=aKi6BF2x4s3VUF1Q-__GE4-is9eSC3H8TGFDT05vTWc,16292
  da4ml/trace/__init__.py,sha256=dv-rti3t8iE0RqeThfOb40mAg8FZB2WkkGQq3enJft0,282
- da4ml/trace/fixed_variable.py,sha256=samW_xChnERsMaXVQz7aKUQJsIrnSHu2ox4x9dMzhR0,20918
- da4ml/trace/fixed_variable_array.py,sha256=1gGSc-ZmRG59sUXvgdN7pulG4XhacAGmgSmzq7nAhJ4,12846
+ da4ml/trace/fixed_variable.py,sha256=7vaXFZToCVzPtUZcHv4aoqpqJp46SHUzSWTQijVT0os,21101
+ da4ml/trace/fixed_variable_array.py,sha256=mJj9aU-jLCPVkFXrTbcRQndtUKEuhVwiFUGVSGX7PHE,12975
  da4ml/trace/pipeline.py,sha256=AVeO9BNpQlo_WO6S1nQl7RxiHs5VFRR10tWMg_36C2o,5354
  da4ml/trace/tracer.py,sha256=xnaVO4oTWwasfiEBqqeY9o60Lek3eX65IIbvB7JtVKQ,6099
  da4ml/trace/ops/__init__.py,sha256=fz5Cg7ZQqPkZlUj4bIOKY6aaoA1fX_G22TeA8I1n4qY,2166
  da4ml/trace/ops/conv_utils.py,sha256=Yn73t4F6Tcs1hBwK08L1DPOin2HYVcng4PSkU4vuZFo,8245
  da4ml/trace/ops/einsum_utils.py,sha256=ODofbvR98FwKBTDZsJ0ObbMjU9_GjPu5AbGuWX6sdCY,11453
- da4ml/trace/ops/reduce_utils.py,sha256=9bi-fizhl1BPy9quQzaWMs83eCDSRMFag2PuvqlVFgI,3500
- da4ml-0.3.2.dist-info/licenses/LICENSE,sha256=46mU2C5kSwOnkqkw9XQAJlhBL2JAf1_uCD8lVcXyMRg,7652
- da4ml-0.3.2.dist-info/METADATA,sha256=zZnCaLH3ndDuURdIXAZD37A06L0ommMlBzfuL93lG-E,4055
- da4ml-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- da4ml-0.3.2.dist-info/top_level.txt,sha256=N0tnKVwRqFiffFdeAzCgFq71hUNySh5-ITbNd6-R58Q,6
- da4ml-0.3.2.dist-info/RECORD,,
+ da4ml/trace/ops/reduce_utils.py,sha256=vQjEUUbvnW8inAYJWHDzgy-PbgwIdHlH-uzPzSEvrSc,3494
+ da4ml-0.3.3.dist-info/licenses/LICENSE,sha256=46mU2C5kSwOnkqkw9XQAJlhBL2JAf1_uCD8lVcXyMRg,7652
+ da4ml-0.3.3.dist-info/METADATA,sha256=C3NAvObpQ5xNOmQQ-cE77AJMFevKJ0gCCO-BrlQpAeA,4055
+ da4ml-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ da4ml-0.3.3.dist-info/top_level.txt,sha256=N0tnKVwRqFiffFdeAzCgFq71hUNySh5-ITbNd6-R58Q,6
+ da4ml-0.3.3.dist-info/RECORD,,