numbox 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
numbox/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.2.1'
1
+ __version__ = '0.2.3'
@@ -0,0 +1,125 @@
1
+ from hashlib import md5
2
+ from inspect import getfile, getmodule
3
+ from io import StringIO
4
+ from numba import njit, typeof
5
+ from numba.core.types import Type
6
+ from typing import Any, Callable, NamedTuple, Optional, Sequence, Union
7
+
8
+ from numbox.core.configurations import default_jit_options
9
+ from numbox.core.work.lowlevel_work_utils import ll_make_work
10
+ from numbox.utils.highlevel import cres
11
+
12
+
13
def _file_anchor():
    """ Marker function used only to locate this module.

    `make_graph` hands it to `inspect.getmodule` (to seed the exec namespace) and to
    `inspect.getfile` (as the file name of the generated code). It is not meant to be called.
    """
    raise NotImplementedError()
15
+
16
+
17
class End(NamedTuple):
    """ Specification of an end node: a graph node created with no sources and no derive function.

    `init_value` is interpolated into the generated source through an f-string (its `str()`
    form), so that form must be a valid expression inside the generated factory.
    """
    # node name; also used as the variable name of the node in the generated code
    name: str
    # initial data of the node
    init_value: Any
    # explicit type; when None, `get_ty` falls back to `numba.typeof(init_value)`
    ty: Optional[Union[type, Type]] = None
21
+
22
+
23
class Derived(NamedTuple):
    """ Specification of a derived node whose value is computed from `sources` by `derive`.

    `derive` is compiled for the signature `ty(*<types of sources>)`. Each source type is
    obtained with `get_ty`.
    """
    # node name; also used as the variable name of the node in the generated code
    name: str
    # initial data of the node (interpolated into the generated source via `str()`)
    init_value: Any
    # function receiving the sources' values, in the order of `sources`
    derive: Callable
    # nodes this one is derived from; each must be created before this node in the generated code
    sources: Sequence[Union['Derived', End]]
    # explicit type; when None, `get_ty` falls back to `numba.typeof(init_value)`
    ty: Optional[Union[type, Type]] = None
29
+
30
+
31
# Any node specification; a runtime union, so it can be used directly with `isinstance`.
SpecTy = Derived | End
32
+
33
+
34
def _input_line(input_: End, ns):
    """ Return the generated-code line that creates the end node `input_` with `ll_make_work`.

    When `input_.ty` is given, it is stored in `ns` under `<name>_ty` and that global name is
    passed as an extra trailing argument. Otherwise the call has four arguments only.
    """
    node_name = input_.name
    node_init = input_.init_value
    node_ty = input_.ty
    call_args = f'"{node_name}", {node_init}, (), None'
    if node_ty is None:
        return f"{node_name} = ll_make_work({call_args})"
    ty_global = f"{node_name}_ty"
    # the generated code refers to the type through this global of the exec namespace
    ns[ty_global] = node_ty
    return f"{node_name} = ll_make_work({call_args}, {ty_global})"
43
+
44
+
45
def get_ty(spec_):
    """ Return the numba type of the node described by `spec_`.

    An explicitly given `spec_.ty` is used whenever it is not None. Otherwise the type is
    inferred from `spec_.init_value` with `numba.typeof`.
    """
    # Compare with None instead of relying on truthiness: a valid type object may be falsy
    # (e.g. a sized numba type of length 0), and must not be silently replaced by the
    # inferred type. This matches the `is not None` check used in `_input_line`.
    explicit_ty = spec_.ty
    if explicit_ty is not None:
        return explicit_ty
    return typeof(spec_.init_value)
47
+
48
+
49
def _derived_cres(ty, sources: Sequence[SpecTy], derive, jit_options=None):
    """ Compile `derive` for the signature `ty(*<types of sources>)` using `cres`.

    Source types are obtained with `get_ty`, in the order of `sources`. `jit_options`
    (none by default) are forwarded to `cres`.
    """
    options = {} if jit_options is None else jit_options
    derive_sig = ty(*[get_ty(source) for source in sources])
    return cres(derive_sig, **options)(derive)
58
+
59
+
60
def _derived_line(derived_: Derived, ns: dict, _make_args: list, jit_options=None):
    """ Return the generated-code line that creates the derived node `derived_`.

    The derive function is compiled by `_derived_cres` and stored in `ns` under
    `<name>_derive`. That name is appended to `_make_args` so it becomes a parameter of
    the generated factory function.
    """
    name_ = derived_.name
    init_ = derived_.init_value
    # A trailing comma after every source name yields a valid tuple literal in all cases:
    # `()` for no sources, `(a, )` for one, `(a, b, )` for several. The previous
    # `"," not in ...` heuristic produced the invalid `(, )` when there were no sources.
    sources_ = "".join(f"{source.name}, " for source in derived_.sources)
    ty_ = get_ty(derived_)
    derive_ = _derived_cres(ty_, derived_.sources, derived_.derive, jit_options)
    derive_name = f"{name_}_derive"
    _make_args.append(derive_name)
    ns[derive_name] = derive_
    return f"""{name_} = ll_make_work("{name_}", {init_}, ({sources_}), {derive_name})"""
71
+
72
+
73
def _verify_access_nodes(
    all_inputs_: Sequence[End],
    all_derived_: Sequence[Derived],
    access_nodes: Sequence[SpecTy]
):
    """ Check that every requested access node is one of the specified input or derived nodes.

    Membership is tested with `in`, i.e. by identity or tuple equality of the specs.

    :raises AssertionError: for the first access node that is not among the specified nodes.
    """
    for access_node in access_nodes:
        # Explicit raise rather than `assert`, so the check is not stripped under `python -O`;
        # the exception type stays AssertionError for existing callers.
        if access_node not in all_inputs_ and access_node not in all_derived_:
            raise AssertionError(f"{access_node} cannot be reached")
80
+
81
+
82
def code_block_hash(code_txt: str):
    """ Return the MD5 hex digest of the UTF-8 encoded `code_txt`.

    Re-compile and re-save cache when source code has changed: the digest becomes part of the
    generated function's name. It is a content fingerprint, not a security measure.
    """
    # `usedforsecurity=False` keeps md5 available on FIPS-restricted Python builds, where a
    # plain md5 constructor may be refused; the digest value itself is unchanged.
    return md5(code_txt.encode("utf-8"), usedforsecurity=False).hexdigest()
85
+
86
+
87
def make_graph(
    all_inputs_: Sequence[End],
    all_derived_: Sequence[Derived],
    access_nodes: SpecTy | Sequence[SpecTy],
    jit_options: Optional[dict] = None
):
    """ Build the graph described by the node specifications and return its access node(s).

    The source of a jitted factory function `_make_<md5 hex>` is generated and executed.
    Its body creates every node with `ll_make_work`: end nodes first, then derived nodes, in
    the order given. Each derived node must therefore be listed after its sources. The
    factory returns the access nodes, and the compiled derive functions are passed to it as
    arguments. It is compiled with `default_jit_options` updated by `jit_options`, then
    called once.

    :param all_inputs_: end-node specifications.
    :param all_derived_: derived-node specifications.
    :param access_nodes: a single spec or a sequence of specs whose nodes are returned; each
        must be one of `all_inputs_` or `all_derived_`.
    :param jit_options: options merged over `default_jit_options`.
    :returns: a tuple of the access nodes. With exactly one access node, the node itself is
        returned, because the generated `(name)` is a parenthesized expression, not a tuple.
    :raises AssertionError: from `_verify_access_nodes` for an access node that is not specified.
    """
    if isinstance(access_nodes, SpecTy):
        access_nodes = (access_nodes,)
    _verify_access_nodes(all_inputs_, all_derived_, access_nodes)
    if jit_options is None:
        jit_options = {}
    jit_options = {**default_jit_options, **jit_options}
    ns = {
        **getmodule(_file_anchor).__dict__,
        "jit_options": jit_options,
        "ll_make_work": ll_make_work,
        "njit": njit,
    }
    make_args = []
    body = StringIO()
    for input_ in all_inputs_:
        body.write(f"\n\t{_input_line(input_, ns)}")
    for derived_ in all_derived_:
        body.write(f"\n\t{_derived_line(derived_, ns, make_args, jit_options)}")
    access_names = ", ".join([node.name for node in access_nodes])
    body.write(f"\n\taccess_tuple = ({access_names})")
    body.write("\n\treturn access_tuple")
    body_txt = body.getvalue()
    # Hash the complete body, access lines included. Graphs with the same nodes but different
    # access nodes then get distinct factory names, which (presumably, via numba's caching)
    # keeps them from sharing cached code.
    # NOTE(review): explicit `ty` values reach the factory as `<name>_ty` globals and are not
    # part of the hashed text.
    hash_ = code_block_hash(body_txt)
    make_params = ", ".join(make_args)
    make_name = f"_make_{hash_}"
    code_txt = f"""
@njit(**jit_options)
def {make_name}({make_params}):""" + body_txt + f"""
return_node_ = {make_name}({make_params})
"""
    code = compile(code_txt, getfile(_file_anchor), mode="exec")
    exec(code, ns)
    return ns["return_node_"]
numbox/core/work/node.py CHANGED
@@ -40,6 +40,13 @@ class Node(NodeBase):
40
40
  def _all_inputs_names(self):
41
41
  return self.all_inputs_names()
42
42
 
43
def all_end_nodes(self):
    """ Return, as a Python list, the names of the end nodes (nodes with no inputs) below this node. """
    end_names = self._all_end_nodes()
    return [end_name for end_name in end_names]
45
+
46
@njit(**default_jit_options)
def _all_end_nodes(self):
    """ Jitted entry point calling the `all_end_nodes` overload; returns a typed List of names. """
    end_names = self.all_end_nodes()
    return end_names
49
+
43
50
  @njit(**default_jit_options)
44
51
  def depends_on(self, obj_):
45
52
  return self.depends_on(obj_)
@@ -116,6 +123,31 @@ def ol_all_inputs_names(self_ty):
116
123
  return _
117
124
 
118
125
 
126
@njit(**default_jit_options)
def _all_end_nodes(node_, names_):
    """ Recursively append to `names_` the names of the end nodes below `node_`.

    An input counts as an end node when its own `inputs` list is empty. Names already in
    `names_` are not appended again. Every input is recursed into without a visited check,
    so a node reachable through several paths is traversed once per path.
    """
    node = _cast(node_, NodeType)
    num_inputs = len(node.inputs)
    for input_ind in range(num_inputs):
        child = node.get_input(input_ind)
        child_name = child.name
        # membership is tested first; the cast/len check only runs for names not seen yet
        if child_name not in names_ and len(_cast(child, NodeType).inputs) == 0:
            names_.append(child_name)
        _all_end_nodes(child, names_)
135
+
136
+
137
@overload_method(NodeTypeClass, "all_end_nodes", strict=False, jit_options=default_jit_options)
def ol_all_end_nodes(self_ty):
    """ Overload of `all_end_nodes` for nodes.

    Returns a typed List of unique end-node names (inputs whose own `inputs` are empty). Names
    are collected depth-first: each input is checked, then its subtree via `_all_end_nodes`.
    """
    def _(self):
        end_names = List.empty_list(unicode_type)
        for input_ind in range(len(self.inputs)):
            child = self.get_input(input_ind)
            child_name = child.name
            if child_name not in end_names and len(_cast(child, NodeType).inputs) == 0:
                end_names.append(child_name)
            _all_end_nodes(child, end_names)
        return end_names
    return _
149
+
150
+
119
151
  @overload_method(NodeTypeClass, "depends_on", strict=False, jit_options=default_jit_options)
120
152
  def ol_depends_on(self_ty, obj_ty):
121
153
  if isinstance(obj_ty, (Literal, UnicodeType,)):
numbox/core/work/work.py CHANGED
@@ -3,11 +3,12 @@ from io import StringIO
3
3
  from numba import njit
4
4
  from numba.core.errors import NumbaError
5
5
  from numba.core.types import (
6
- DictType, FunctionType, Integer, Literal, NoneType, Tuple, unicode_type, UnicodeType
6
+ boolean, DictType, FunctionType, Literal, NoneType, Tuple, unicode_type, UnicodeType
7
7
  )
8
8
  from numba.core.typing.context import Context
9
9
  from numba.experimental.structref import define_boxing, new
10
10
  from numba.extending import intrinsic, overload, overload_method
11
+ from numba.typed.typeddict import Dict
11
12
  from numba.typed.typedlist import List
12
13
 
13
14
  from numbox.core.any.erased_type import ErasedType
@@ -115,6 +116,13 @@ class Work(NodeBase):
115
116
  def _all_inputs_names(self):
116
117
  return self.all_inputs_names()
117
118
 
119
def all_end_nodes(self):
    """ Return, as a Python list, the names of the end nodes (nodes with no inputs) below this work. """
    end_names = self._all_end_nodes()
    return [end_name for end_name in end_names]
121
+
122
@njit(**default_jit_options)
def _all_end_nodes(self):
    """ Jitted entry point calling the `all_end_nodes` overload; returns a typed List of names. """
    end_names = self.all_end_nodes()
    return end_names
125
+
118
126
  @njit(**default_jit_options)
119
127
  def depends_on(self, obj_):
120
128
  return self.depends_on(obj_)
@@ -284,12 +292,14 @@ def _make_combine_code(num_sources):
284
292
  for source_ind_ in range(num_sources):
285
293
  code_txt.write(_make_source_getter(source_ind_))
286
294
  code_txt.write("""
287
- def _combine_(work_, data_, ct_=0):
288
- if ct_ == len(data_):
289
- return ct_
295
+ def _combine_(work_, data_, harvested_=None):
296
+ if harvested_ is None:
297
+ harvested_ = Dict.empty(key_type=unicode_type, value_type=boolean)
298
+ if len(harvested_) == len(data_):
299
+ return
290
300
  work_name = work_.name
291
301
  if work_name in data_:
292
- ct_ += 1
302
+ harvested_[work_name] = True
293
303
  data_[work_name].reset(work_.data)""")
294
304
  if num_sources > 0:
295
305
  code_txt.write("""
@@ -297,10 +307,10 @@ def _combine_(work_, data_, ct_=0):
297
307
  for source_ind_ in range(num_sources):
298
308
  code_txt.write(f"""
299
309
  source_{source_ind_} = _get_source_{source_ind_}(sources)
300
- ct_ = source_{source_ind_}.combine(data_, ct_)
310
+ source_{source_ind_}.combine(data_, harvested_)
301
311
  """)
302
312
  code_txt.write("""
303
- return ct_
313
+ return
304
314
  """)
305
315
  return code_txt.getvalue()
306
316
 
@@ -309,7 +319,7 @@ _combine_registry = {}
309
319
 
310
320
 
311
321
  @overload_method(WorkTypeClass, "combine", strict=False, jit_options=default_jit_options)
312
- def ol_combine(work_ty, data_ty: DictType, ct_ty=Integer):
322
+ def ol_combine(work_ty, data_ty: DictType, harvested_ty=NoneType):
313
323
  """ Harvest nodes data from the graph with the root node `work`.
314
324
  `data` is provided as dictionary mapping node name to `Any` type
315
325
  containing erased payload `p` to be reset to `data`. """
@@ -318,7 +328,7 @@ def ol_combine(work_ty, data_ty: DictType, ct_ty=Integer):
318
328
  _combine = _combine_registry.get(num_sources, None)
319
329
  if _combine is None:
320
330
  code_txt = _make_combine_code(num_sources)
321
- ns = getmodule(_file_anchor).__dict__
331
+ ns = {**getmodule(_file_anchor).__dict__, **{"boolean": boolean, "Dict": Dict}}
322
332
  code = compile(code_txt, getfile(_file_anchor), mode="exec")
323
333
  exec(code, ns)
324
334
  _combine = ns["_combine_"]
@@ -416,6 +426,14 @@ def ol_all_inputs_names(self_ty):
416
426
  return _
417
427
 
418
428
 
429
@overload_method(WorkTypeClass, "all_end_nodes", strict=False, jit_options=default_jit_options)
def ol_all_end_nodes(self_ty):
    """ Overload of `all_end_nodes` for works: delegate to the node returned by `as_node()`. """
    def _(self):
        return self.as_node().all_end_nodes()
    return _
435
+
436
+
419
437
  @overload_method(WorkTypeClass, "depends_on", strict=False, jit_options=default_jit_options)
420
438
  def ol_depends_on(self_ty, obj_ty):
421
439
  if isinstance(obj_ty, (Literal, UnicodeType,)):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: numbox
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Author: Mikhail Goykhman
5
5
  License: MIT License (with Citation Clause)
6
6
 
@@ -49,7 +49,7 @@ Documentation is available at [numbox](https://goykhman.github.io/numbox).
49
49
  - **Any**: A lightweight structure that wraps any value into the same type leveraging a variant of the type erasure technique.
50
50
  - **Bindings**: JIT-wrappers of functions imported from dynamically linked libraries.
51
51
  - **Node**: A lightweight JIT-compatible graph node with type-erased dependencies in a uniform dynamically-sized vector (numba List).
52
- - **proxy**: Create a proxy for a decorated function with specified signatures, enabling efficient JIT compilation and caching.
52
+ - **Proxy**: Create a proxy for a decorated function with specified signatures, enabling efficient JIT compilation and caching.
53
53
  - **Work**: JIT-compatible unit of calculation work with dependencies, inner states, and custom calculation.
54
54
 
55
55
  ## Installation
@@ -1,4 +1,4 @@
1
- numbox/__init__.py,sha256=PmcQ2PI2oP8irnLtJLJby2YfW6sBvLAmL-VpABzTqwc,22
1
+ numbox/__init__.py,sha256=XtWUl6HPylv5jZLd2KkgtPptuzuda93kC2REmOrF-Cs,22
2
2
  numbox/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  numbox/core/configurations.py,sha256=0bCmxXL-QMwtvyIDhpXLeT-1KJMf_QpH0wLuEvYLGxQ,68
4
4
  numbox/core/any/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -15,13 +15,14 @@ numbox/core/bindings/utils.py,sha256=aRtN8oUYBk9vgoUGaUJosGx0Za-vvCNwwbZg_g_-LRs
15
15
  numbox/core/proxy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  numbox/core/proxy/proxy.py,sha256=kGYlEdLK40lxxu56e_S32s9YuQ6AuQnFegt5uQrUw5w,3889
17
17
  numbox/core/work/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
+ numbox/core/work/builder.py,sha256=QaklRb-IciVNPlk1WnQ3Hfc8tHbWbuAQqdrlp51XCcM,3991
18
19
  numbox/core/work/combine_utils.py,sha256=qTVGke_ydzaTQ7o29DFjZWZzKjRNKb0L3yJMaR3TLII,2430
19
20
  numbox/core/work/loader_utils.py,sha256=g83mDWidZJ8oLWP3I3rK8aGISYOO2S-w6HDgtosCyck,1572
20
21
  numbox/core/work/lowlevel_work_utils.py,sha256=TgRRcNfks0oaOXGXXr3ptafd_Xv_lpmH8sjBrJ9bPuI,5748
21
- numbox/core/work/node.py,sha256=4-0p4o94bg54JOvmgAlEVlgR5tSev5SGTqMq9Ht3bd0,4076
22
+ numbox/core/work/node.py,sha256=CMolyoRQjG2A-pTQqZQ0kxKOYTKipWRC0mu8RWHuTUI,5096
22
23
  numbox/core/work/node_base.py,sha256=uI7asM2itQcHuOByXyJtqvrd4ovW6EXDRdHYp3JVHQ0,998
23
24
  numbox/core/work/print_tree.py,sha256=y2u7xmbHvpcA57y8PrGSqOunLNCqhgNXdVtXHqvy1M0,2340
24
- numbox/core/work/work.py,sha256=fdqeqtH8gDXkHFnc0lsFVCn5idP4reSBhYP_ccoiZBQ,13971
25
+ numbox/core/work/work.py,sha256=oC71F2y-NkIF9Ox63trlv061nYBxla6R-F2Nj-7NXI4,14595
25
26
  numbox/core/work/work_utils.py,sha256=3q_nnBdzuxWWcdFpbRL2H0T9ZNkUgx1J1uhiZkX3YG4,1039
26
27
  numbox/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
28
  numbox/utils/highlevel.py,sha256=gXYzLsFPKHQfGqaz-Z4DWcCQddv0dS6SKwIsM-_xjYg,1487
@@ -29,8 +30,8 @@ numbox/utils/lowlevel.py,sha256=ACpf8_HyOIsobPlZ31bapkEyuCsV5dojW3AFrcKykrw,1071
29
30
  numbox/utils/meminfo.py,sha256=ykFi8Vt0WcHI3ztgMwvpn6NqaflDSQGL8tjI01jrzm0,1759
30
31
  numbox/utils/timer.py,sha256=KkAkWOHQ72WtPjyiAzt_tF1q0DcOnCDkITTb85DvkUM,553
31
32
  numbox/utils/void_type.py,sha256=IkZsjNeAIShYJtvWbvERdHnl_mbF1rCRWiM3gp6II8U,404
32
- numbox-0.2.1.dist-info/LICENSE,sha256=YYgNvjH_p6-1NsdrIqGJnr1GUbZzA_8DxsP6vVfM6nY,1446
33
- numbox-0.2.1.dist-info/METADATA,sha256=Hk7ONk6PjPPkOpgYG7rgzUAICj86x4ubbLZ_JWepmP8,2792
34
- numbox-0.2.1.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
35
- numbox-0.2.1.dist-info/top_level.txt,sha256=A67jOkfqidCSYYm6ifjN_WZyIiR1B27fjxv6nNbPvjc,7
36
- numbox-0.2.1.dist-info/RECORD,,
33
+ numbox-0.2.3.dist-info/LICENSE,sha256=YYgNvjH_p6-1NsdrIqGJnr1GUbZzA_8DxsP6vVfM6nY,1446
34
+ numbox-0.2.3.dist-info/METADATA,sha256=e7MF4nI2ntAqX8CGKWS1FvUiEswd2rtRW-rysf4WWJc,2792
35
+ numbox-0.2.3.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
36
+ numbox-0.2.3.dist-info/top_level.txt,sha256=A67jOkfqidCSYYm6ifjN_WZyIiR1B27fjxv6nNbPvjc,7
37
+ numbox-0.2.3.dist-info/RECORD,,
File without changes