numbox-0.2.5-py3-none-any.whl → numbox-0.2.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numbox might be problematic; consult the package's registry page for more details (the original details link was not preserved).

numbox/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.2.5'
1
+ __version__ = '0.2.6'
@@ -1,10 +1,11 @@
1
+ from collections import namedtuple
1
2
  from hashlib import md5
2
3
  from inspect import getfile, getmodule, getsource
3
4
  from io import StringIO
4
5
  from itertools import chain
5
6
  from numba import njit, typeof
6
7
  from numba.core.types import Type
7
- from typing import Any, Callable, NamedTuple, Optional, Sequence, Union
8
+ from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Union
8
9
 
9
10
  from numbox.core.configurations import default_jit_options
10
11
  from numbox.core.work.lowlevel_work_utils import ll_make_work
@@ -15,7 +16,7 @@ def _file_anchor():
15
16
  raise NotImplementedError
16
17
 
17
18
 
18
- _nodes_names = set()
19
+ _specs_registry = dict()
19
20
 
20
21
 
21
22
  class _End(NamedTuple):
@@ -27,10 +28,11 @@ class _End(NamedTuple):
27
28
  def _new(cls, super_proxy, *args, **kwargs):
28
29
  name = kwargs.get("name")
29
30
  assert name, "`name` key-word argument has not been provided"
30
- if name in _nodes_names:
31
+ if name in _specs_registry:
31
32
  raise ValueError(f"Node '{name}' has already been defined on this graph. Pick a different name.")
32
- _nodes_names.add(name)
33
- return super_proxy.__new__(cls, *args, **kwargs)
33
+ spec_ = super_proxy.__new__(cls, *args, **kwargs)
34
+ _specs_registry[name] = spec_
35
+ return spec_
34
36
 
35
37
 
36
38
  class End(_End):
@@ -111,29 +113,34 @@ def _derived_line(
111
113
  return f"""{name_} = ll_make_work("{name_}", {init_name}, ({sources_}), {derive_name})"""
112
114
 
113
115
 
114
- def _verify_access_nodes(
115
- all_inputs_: Sequence[End],
116
- all_derived_: Sequence[Derived],
117
- access_nodes: Sequence[SpecTy]
118
- ):
119
- for access_node in access_nodes:
120
- assert access_node in all_inputs_ or access_node in all_derived_, f"{access_node} cannot be reached"
121
-
122
-
123
116
  def code_block_hash(code_txt: str):
124
117
  """ Re-compile and re-save cache when source code has changed. """
125
118
  return md5(code_txt.encode("utf-8")).hexdigest()
126
119
 
127
120
 
128
- def make_graph(
129
- all_inputs_: Sequence[End],
130
- all_derived_: Sequence[Derived],
131
- access_nodes: SpecTy | Sequence[SpecTy],
132
- jit_options: Optional[dict] = None
133
- ):
134
- if isinstance(access_nodes, SpecTy):
135
- access_nodes = (access_nodes,)
136
- _verify_access_nodes(all_inputs_, all_derived_, access_nodes)
121
+ def _infer_end_and_derived_nodes(spec: SpecTy, all_inputs_: Dict[str, Type], all_derived_: Dict[str, Type]):
122
+ if spec.name in all_inputs_ or spec.name in all_derived_:
123
+ return
124
+ if isinstance(spec, End):
125
+ all_inputs_[spec.name] = get_ty(spec)
126
+ return
127
+ for source in spec.sources:
128
+ _infer_end_and_derived_nodes(source, all_inputs_, all_derived_)
129
+ all_derived_[spec.name] = get_ty(spec)
130
+
131
+
132
+ def infer_end_and_derived_nodes(access_nodes: SpecTy | Sequence[SpecTy]):
133
+ all_inputs_ = dict()
134
+ all_derived_ = dict()
135
+ for access_node in access_nodes:
136
+ _infer_end_and_derived_nodes(access_node, all_inputs_, all_derived_)
137
+ all_inputs_lst = [_specs_registry[name] for name in all_inputs_.keys()]
138
+ all_derived_lst = [_specs_registry[name] for name in all_derived_.keys()]
139
+ return all_inputs_lst, all_derived_lst
140
+
141
+
142
+ def make_graph(*access_nodes: SpecTy | Sequence[SpecTy], jit_options: Optional[dict] = None):
143
+ all_inputs_, all_derived_ = infer_end_and_derived_nodes(access_nodes)
137
144
  if jit_options is None:
138
145
  jit_options = {}
139
146
  jit_options = {**default_jit_options, **jit_options}
@@ -144,26 +151,30 @@ def make_graph(
144
151
  _make_args = []
145
152
  code_txt = StringIO()
146
153
  initializers = {}
147
- derive_hashes=[]
154
+ derive_hashes = []
148
155
  for input_ in all_inputs_:
149
156
  line_ = _input_line(input_, ns, initializers)
150
157
  code_txt.write(f"\n\t{line_}")
151
158
  for derived_ in all_derived_:
152
159
  line_ = _derived_line(derived_, ns, initializers, derive_hashes, _make_args, jit_options)
153
160
  code_txt.write(f"\n\t{line_}")
154
- hash_str = f"code_block = {code_txt.getvalue()} initializers = {list(initializers.values())} derive_hashes = {derive_hashes}"
161
+ hash_str = f"code_block = {code_txt.getvalue()} initializers = {list(initializers.values())} derive_hashes = {derive_hashes}" # noqa: E501
155
162
  hash_ = code_block_hash(hash_str)
156
- code_txt.write(f"""\n\taccess_tuple = ({", ".join([n.name for n in access_nodes])})""")
157
- code_txt.write(f"\n\treturn access_tuple")
163
+ access_nodes_names = [n.name for n in access_nodes]
164
+ tup_ = ", ".join(access_nodes_names)
165
+ tup_ = tup_ + ", " if ", " not in tup_ else tup_
166
+ code_txt.write(f"""\n\taccess_tuple = ({tup_})""")
167
+ code_txt.write("\n\treturn access_tuple")
158
168
  code_txt = code_txt.getvalue()
159
169
  make_params = ", ".join(chain(_make_args, initializers.keys()))
160
170
  make_name = f"_make_{hash_}"
161
171
  code_txt = f"""
162
172
  @njit(**jit_options)
163
173
  def {make_name}({make_params}):""" + code_txt + f"""
164
- return_node_ = {make_name}({make_params})
174
+ access_tuple_ = {make_name}({make_params})
165
175
  """
166
176
  code = compile(code_txt, getfile(_file_anchor), mode="exec")
167
177
  exec(code, ns)
168
- return_node_ = ns["return_node_"]
169
- return return_node_
178
+ access_tuple_ = ns["access_tuple_"]
179
+ Access = namedtuple("Access", access_nodes_names)
180
+ return Access(*access_tuple_)
@@ -30,6 +30,9 @@ def _explain(work: Work, derivation_: list, derived_: set):
30
30
 
31
31
 
32
32
  def explain(work: Work):
33
+ all_end_nodes = work.all_end_nodes()
33
34
  derivation = []
34
35
  _explain(work, derivation, set())
35
- return "\n".join(derivation)
36
+ explain_txt = f"All required end nodes: {all_end_nodes}\n\n"
37
+ explain_txt += "\n".join(derivation)
38
+ return explain_txt
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: numbox
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Author: Mikhail Goykhman
5
5
  License: MIT License (with Citation Clause)
6
6
 
@@ -1,4 +1,4 @@
1
- numbox/__init__.py,sha256=8f4ummWTPl6PtdE9DzSUgZhPjGioA_NDC-hDpqzB9qI,22
1
+ numbox/__init__.py,sha256=T150U4daRZ7ULOwxGUzzoBDAqm3bW9UydvCf0KJii9I,22
2
2
  numbox/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  numbox/core/configurations.py,sha256=0bCmxXL-QMwtvyIDhpXLeT-1KJMf_QpH0wLuEvYLGxQ,68
4
4
  numbox/core/any/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -15,9 +15,9 @@ numbox/core/bindings/utils.py,sha256=aRtN8oUYBk9vgoUGaUJosGx0Za-vvCNwwbZg_g_-LRs
15
15
  numbox/core/proxy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  numbox/core/proxy/proxy.py,sha256=kGYlEdLK40lxxu56e_S32s9YuQ6AuQnFegt5uQrUw5w,3889
17
17
  numbox/core/work/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- numbox/core/work/builder.py,sha256=tfrizyZuBy21zsoILELzQIj-wtYAuPitDqtVg7V1atw,5362
18
+ numbox/core/work/builder.py,sha256=WFZmH8dHTQNx1TwjpJVtZXN5ud4K2j500HgbF-Msu4c,6063
19
19
  numbox/core/work/combine_utils.py,sha256=qTVGke_ydzaTQ7o29DFjZWZzKjRNKb0L3yJMaR3TLII,2430
20
- numbox/core/work/explain.py,sha256=qEorYHRSe8P2MygodCYsz57UtWeRfaS27xHSw3hBgmo,1063
20
+ numbox/core/work/explain.py,sha256=ESwvsTgfe0w7UnM13yyVpVDtfJyAK2A1sNdF3RNb-jU,1200
21
21
  numbox/core/work/loader_utils.py,sha256=g83mDWidZJ8oLWP3I3rK8aGISYOO2S-w6HDgtosCyck,1572
22
22
  numbox/core/work/lowlevel_work_utils.py,sha256=TgRRcNfks0oaOXGXXr3ptafd_Xv_lpmH8sjBrJ9bPuI,5748
23
23
  numbox/core/work/node.py,sha256=CMolyoRQjG2A-pTQqZQ0kxKOYTKipWRC0mu8RWHuTUI,5096
@@ -31,8 +31,8 @@ numbox/utils/lowlevel.py,sha256=ACpf8_HyOIsobPlZ31bapkEyuCsV5dojW3AFrcKykrw,1071
31
31
  numbox/utils/meminfo.py,sha256=ykFi8Vt0WcHI3ztgMwvpn6NqaflDSQGL8tjI01jrzm0,1759
32
32
  numbox/utils/timer.py,sha256=KkAkWOHQ72WtPjyiAzt_tF1q0DcOnCDkITTb85DvkUM,553
33
33
  numbox/utils/void_type.py,sha256=IkZsjNeAIShYJtvWbvERdHnl_mbF1rCRWiM3gp6II8U,404
34
- numbox-0.2.5.dist-info/LICENSE,sha256=YYgNvjH_p6-1NsdrIqGJnr1GUbZzA_8DxsP6vVfM6nY,1446
35
- numbox-0.2.5.dist-info/METADATA,sha256=kHcNZstr-whJ0XoJ6IcoXACjnAn33yjvJgdn4Ksx6sk,2792
36
- numbox-0.2.5.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
37
- numbox-0.2.5.dist-info/top_level.txt,sha256=A67jOkfqidCSYYm6ifjN_WZyIiR1B27fjxv6nNbPvjc,7
38
- numbox-0.2.5.dist-info/RECORD,,
34
+ numbox-0.2.6.dist-info/LICENSE,sha256=YYgNvjH_p6-1NsdrIqGJnr1GUbZzA_8DxsP6vVfM6nY,1446
35
+ numbox-0.2.6.dist-info/METADATA,sha256=H7eyqJTneiiqllMj4Boqaxpj-gwZpLjkQlGEhiOuJlE,2792
36
+ numbox-0.2.6.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
37
+ numbox-0.2.6.dist-info/top_level.txt,sha256=A67jOkfqidCSYYm6ifjN_WZyIiR1B27fjxv6nNbPvjc,7
38
+ numbox-0.2.6.dist-info/RECORD,,
Remaining files (e.g. dist-info LICENSE, WHEEL, top_level.txt — identical hashes in RECORD): without changes