mapFolding 0.3.6__py3-none-any.whl → 0.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  from mapFolding import getPathFilenameFoldsTotal, indexMy, indexTrack
2
- from mapFolding import make_dtype, datatypeLargeDEFAULT, datatypeMediumDEFAULT, datatypeSmallDEFAULT, datatypeModuleDEFAULT
2
+ from mapFolding import setDatatypeElephino, setDatatypeFoldsTotal, setDatatypeLeavesTotal, setDatatypeModule, hackSSOTdatatype
3
3
  from someAssemblyRequired import makeStateJob
4
4
  from typing import Optional
5
5
  import importlib
@@ -13,9 +13,36 @@ import python_minifier
13
13
  identifierCallableLaunch = "goGoGadgetAbsurdity"
14
14
 
15
15
  def makeStrRLEcompacted(arrayTarget: numpy.ndarray, identifierName: str) -> str:
16
- def process_nested_array(arraySlice):
16
+ """Converts a NumPy array into a compressed string representation using run-length encoding (RLE).
17
+
18
+ This function takes a NumPy array and converts it into an optimized string representation by:
19
+ 1. Compressing consecutive sequences of numbers into range objects
20
+ 2. Minimizing repeated zeros using array multiplication syntax
21
+ 3. Converting the result into a valid Python array initialization statement
22
+
23
+ Parameters:
24
+ arrayTarget (numpy.ndarray): The input NumPy array to be converted
25
+ identifierName (str): The variable name to use in the output string
26
+
27
+ Returns:
28
+ str: A string containing Python code that recreates the input array in compressed form.
29
+ Format: "{identifierName} = numpy.array({compressed_data}, dtype=numpy.{dtype})"
30
+
31
+ Example:
32
+ >>> arr = numpy.array([[0,0,0,1,2,3,4,0,0]])
33
+ >>> print(makeStrRLEcompacted(arr, "myArray"))
34
+ "myArray = numpy.array([[0]*3,*range(1,5),[0]*2], dtype=numpy.int64)"
35
+
36
+ Notes:
37
+ - Sequences of 4 or fewer numbers are kept as individual values
38
+ - Sequences longer than 4 numbers are converted to range objects
39
+ - Consecutive zeros are compressed using multiplication syntax
40
+ - The function preserves the original array's dtype
41
+ """
42
+
43
+ def compressRangesNDArrayNoFlatten(arraySlice):
17
44
  if isinstance(arraySlice, numpy.ndarray) and arraySlice.ndim > 1:
18
- return [process_nested_array(arraySlice[index]) for index in range(arraySlice.shape[0])]
45
+ return [compressRangesNDArrayNoFlatten(arraySlice[index]) for index in range(arraySlice.shape[0])]
19
46
  elif isinstance(arraySlice, numpy.ndarray) and arraySlice.ndim == 1:
20
47
  listWithRanges = []
21
48
  for group in more_itertools.consecutive_groups(arraySlice.tolist()):
@@ -28,7 +55,7 @@ def makeStrRLEcompacted(arrayTarget: numpy.ndarray, identifierName: str) -> str:
28
55
  return listWithRanges
29
56
  return arraySlice
30
57
 
31
- arrayAsNestedLists = process_nested_array(arrayTarget)
58
+ arrayAsNestedLists = compressRangesNDArrayNoFlatten(arrayTarget)
32
59
 
33
60
  stringMinimized = python_minifier.minify(str(arrayAsNestedLists))
34
61
  commaZeroMaximum = arrayTarget.shape[-1] - 1
@@ -40,25 +67,47 @@ def makeStrRLEcompacted(arrayTarget: numpy.ndarray, identifierName: str) -> str:
40
67
 
41
68
  return f"{identifierName} = numpy.array({stringMinimized}, dtype=numpy.{arrayTarget.dtype})"
42
69
 
43
- def writeModuleWithNumba(listDimensions, **keywordArguments: Optional[str]) -> pathlib.Path:
44
- datatypeLarge = keywordArguments.get('datatypeLarge', datatypeLargeDEFAULT)
45
- datatypeMedium = keywordArguments.get('datatypeMedium', datatypeMediumDEFAULT)
46
- datatypeSmall = keywordArguments.get('datatypeSmall', datatypeSmallDEFAULT)
47
- datatypeModule = keywordArguments.get('datatypeModule', datatypeModuleDEFAULT)
48
-
49
- dtypeLarge = make_dtype(datatypeLarge, datatypeModule) # type: ignore
50
- dtypeMedium = make_dtype(datatypeMedium, datatypeModule) # type: ignore
51
- dtypeSmall = make_dtype(datatypeSmall, datatypeModule) # type: ignore
52
-
53
- stateJob = makeStateJob(listDimensions, writeJob=False, dtypeLarge = dtypeLarge, dtypeMedium = dtypeMedium, dtypeSmall = dtypeSmall)
70
+ def writeModuleWithNumba(listDimensions) -> pathlib.Path:
71
+ """
72
+ Writes a Numba-optimized Python module for map folding calculations.
73
+
74
+ This function takes map dimensions and generates a specialized Python module with Numba
75
+ optimizations. It processes a sequential counting algorithm, adds Numba decorators and
76
+ necessary data structures, and writes the resulting code to a file.
77
+
78
+ Parameters:
79
+ listDimensions: List of integers representing the dimensions of the map to be folded.
80
+
81
+ Returns:
82
+ pathlib.Path: Path to the generated Python module file.
83
+
84
+ The generated module includes:
85
+ - Numba JIT compilation decorators for performance optimization
86
+ - Required numpy and numba imports
87
+ - Dynamic and static data structures needed for folding calculations
88
+ - Processed algorithm from the original sequential counter
89
+ - Launch code for standalone execution
90
+ - Code to write the final fold count to a file
91
+ The function handles:
92
+ - Translation of original code to Numba-compatible syntax
93
+ - Insertion of pre-calculated values from the state job
94
+ - Management of variable declarations and assignments
95
+ - Setup of proper data types for Numba optimization
96
+ - Organization of the output file structure
97
+
98
+ Note:
99
+ The generated module requires Numba and numpy to be installed.
100
+ The output file will be placed in the same directory as the folds total file,
101
+ with a .py extension.
102
+ """
103
+ stateJob = makeStateJob(listDimensions, writeJob=False)
54
104
  pathFilenameFoldsTotal = getPathFilenameFoldsTotal(stateJob['mapShape'])
55
105
 
56
106
  from syntheticModules import countSequential
57
107
  algorithmSource = countSequential
58
108
  codeSource = inspect.getsource(algorithmSource)
59
109
 
60
- if datatypeLarge:
61
- lineNumba = f"@numba.jit(numba.types.{datatypeLarge}(), cache=True, nopython=True, fastmath=True, forceinline=True, inline='always', looplift=False, _nrt=True, error_model='numpy', parallel=False, boundscheck=False, no_cfunc_wrapper=True, no_cpython_wrapper=False)"
110
+ lineNumba = f"@numba.jit(numba.types.{hackSSOTdatatype('datatypeFoldsTotal')}(), cache=True, nopython=True, fastmath=True, forceinline=True, inline='always', looplift=False, _nrt=True, error_model='numpy', parallel=False, boundscheck=False, no_cfunc_wrapper=False, no_cpython_wrapper=False)"
62
111
 
63
112
  linesImport = "\n".join([
64
113
  "import numpy"
@@ -68,8 +117,6 @@ def writeModuleWithNumba(listDimensions, **keywordArguments: Optional[str]) -> p
68
117
  ImaIndent = ' '
69
118
  linesDataDynamic = """"""
70
119
  linesDataDynamic = "\n".join([linesDataDynamic
71
- # , ImaIndent + f"foldsTotal = numba.types.{datatypeLarge}(0)"
72
- # , ImaIndent + makeStrRLEcompacted(stateJob['foldGroups'], 'foldGroups')
73
120
  , ImaIndent + makeStrRLEcompacted(stateJob['gapsWhere'], 'gapsWhere')
74
121
  ])
75
122
 
@@ -97,11 +144,18 @@ def writeModuleWithNumba(listDimensions, **keywordArguments: Optional[str]) -> p
97
144
  elif 'my[indexMy.' in lineSource:
98
145
  if 'dimensionsTotal' in lineSource:
99
146
  continue
100
- # leaf1ndex = my[indexMy.leaf1ndex.value]
147
+ # Statements are in the form: leaf1ndex = my[indexMy.leaf1ndex.value]
101
148
  identifier, statement = lineSource.split('=')
102
- lineSource = ImaIndent + identifier.strip() + f"=numba.types.{datatypeSmall}({str(eval(statement.strip()))})"
149
+ lineSource = ImaIndent + identifier.strip() + f"=numba.types.{hackSSOTdatatype(identifier.strip())}({str(eval(statement.strip()))})"
150
+ elif ': int =' in lineSource or ':int=' in lineSource:
151
+ if 'dimensionsTotal' in lineSource:
152
+ continue
153
+ # Statements are in the form: groupsOfFolds: int = 0
154
+ assignment, statement = lineSource.split('=')
155
+ identifier = assignment.split(':')[0].strip()
156
+ lineSource = ImaIndent + identifier.strip() + f"=numba.types.{hackSSOTdatatype(identifier.strip())}({str(eval(statement.strip()))})"
103
157
  elif 'track[indexTrack.' in lineSource:
104
- # leafAbove = track[indexTrack.leafAbove.value]
158
+ # Statements are in the form: leafAbove = track[indexTrack.leafAbove.value]
105
159
  identifier, statement = lineSource.split('=')
106
160
  lineSource = ImaIndent + makeStrRLEcompacted(eval(statement.strip()), identifier.strip())
107
161
  elif 'foldGroups[-1]' in lineSource:
@@ -116,10 +170,11 @@ def writeModuleWithNumba(listDimensions, **keywordArguments: Optional[str]) -> p
116
170
  linesLaunch = """"""
117
171
  linesLaunch = linesLaunch + f"""
118
172
  if __name__ == '__main__':
119
- import time
120
- timeStart = time.perf_counter()
173
+ # import time
174
+ # timeStart = time.perf_counter()
121
175
  {identifierCallableLaunch}()
122
- print(time.perf_counter() - timeStart)"""
176
+ # print(time.perf_counter() - timeStart)
177
+ """
123
178
 
124
179
  linesWriteFoldsTotal = """"""
125
180
  linesWriteFoldsTotal = "\n".join([linesWriteFoldsTotal
@@ -143,11 +198,11 @@ if __name__ == '__main__':
143
198
  return pathFilenameDestination
144
199
 
145
200
  if __name__ == '__main__':
146
- listDimensions = [6,6]
147
- datatypeLarge = 'int64'
148
- datatypeMedium = 'uint8'
149
- datatypeSmall = datatypeMedium
150
- pathFilenameModule = writeModuleWithNumba(listDimensions, datatypeLarge=datatypeLarge, datatypeMedium=datatypeMedium, datatypeSmall=datatypeSmall)
201
+ listDimensions = [5,5]
202
+ setDatatypeFoldsTotal('int64', sourGrapes=True)
203
+ setDatatypeElephino('uint8', sourGrapes=True)
204
+ setDatatypeLeavesTotal('int8', sourGrapes=True)
205
+ pathFilenameModule = writeModuleWithNumba(listDimensions)
151
206
 
152
207
  # Induce numba.jit compilation
153
208
  moduleSpec = importlib.util.spec_from_file_location(pathFilenameModule.stem, pathFilenameModule)
@@ -0,0 +1,446 @@
1
+ from mapFolding import EnumIndices, relativePathSyntheticModules, setDatatypeElephino, setDatatypeFoldsTotal, setDatatypeLeavesTotal, setDatatypeModule
2
+ from mapFolding import indexMy, indexTrack, getAlgorithmSource, ParametersNumba, parametersNumbaDEFAULT, hackSSOTdatatype, hackSSOTdtype
3
+ from typing import cast, Dict, List, Optional, Sequence, Set, Type, Union
4
+ from types import ModuleType
5
+ import ast
6
+ import inspect
7
+ import numba
8
+ import numpy
9
+ import pathlib
10
+
11
+ """TODO
12
+ Convert types
13
+ e.g. `groupsOfFolds: int = 0` to `groupsOfFolds = numba.types.{datatypeLarge}(0)`
14
+ This isn't necessary for Numba, but I may need the infrastructure for other compilers or paradigms."""
15
+
16
+ class RecursiveInliner(ast.NodeTransformer):
17
+ """
18
+ Class RecursiveInliner:
19
+ A custom AST NodeTransformer designed to recursively inline function calls from a given dictionary
20
+ of function definitions into the AST. Once a particular function has been inlined, it is marked
21
+ as completed to avoid repeated inlining. This transformation modifies the AST in-place by substituting
22
+ eligible function calls with the body of their corresponding function.
23
+ Attributes:
24
+ dictionaryFunctions (Dict[str, ast.FunctionDef]):
25
+ A mapping of function name to its AST definition, used as a source for inlining.
26
+ callablesCompleted (Set[str]):
27
+ A set to track function names that have already been inlined to prevent multiple expansions.
28
+ Methods:
29
+ inlineFunctionBody(callableTargetName: str) -> Optional[ast.FunctionDef]:
30
+ Retrieves the AST definition for a given function name from dictionaryFunctions
31
+ and recursively inlines any function calls within it. Returns the function definition
32
+ that was inlined or None if the function was already processed.
33
+ visit_Call(callNode: ast.Call) -> ast.AST:
34
+ Inspects calls within the AST. If a function call matches one in dictionaryFunctions,
35
+ it is replaced by the inlined body. If the last statement in the inlined body is a return
36
+ or an expression, that value or expression is substituted; otherwise, a constant is returned.
37
+ visit_Expr(node: ast.Expr) -> Union[ast.AST, List[ast.AST]]:
38
+ Handles expression nodes in the AST. If the expression is a function call from
39
+ dictionaryFunctions, its statements are expanded in place, effectively inlining
40
+ the called function's statements into the surrounding context.
41
+ """
42
+ def __init__(self, dictionaryFunctions: Dict[str, ast.FunctionDef]):
43
+ self.dictionaryFunctions = dictionaryFunctions
44
+ self.callablesCompleted: Set[str] = set()
45
+
46
+ def inlineFunctionBody(self, callableTargetName: str) -> Optional[ast.FunctionDef]:
47
+ if (callableTargetName in self.callablesCompleted):
48
+ return None
49
+
50
+ self.callablesCompleted.add(callableTargetName)
51
+ inlineDefinition = self.dictionaryFunctions[callableTargetName]
52
+ for astNode in ast.walk(inlineDefinition):
53
+ self.visit(astNode)
54
+ return inlineDefinition
55
+
56
+ def visit_Call(self, callNode: ast.Call) -> ast.AST:
57
+ callNodeVisited = self.generic_visit(callNode)
58
+ if (isinstance(callNodeVisited, ast.Call) and isinstance(callNodeVisited.func, ast.Name) and callNodeVisited.func.id in self.dictionaryFunctions):
59
+ inlineDefinition = self.inlineFunctionBody(callNodeVisited.func.id)
60
+ if (inlineDefinition and inlineDefinition.body):
61
+ statementTerminating = inlineDefinition.body[-1]
62
+ if (isinstance(statementTerminating, ast.Return) and statementTerminating.value is not None):
63
+ return self.visit(statementTerminating.value)
64
+ elif (isinstance(statementTerminating, ast.Expr) and statementTerminating.value is not None):
65
+ return self.visit(statementTerminating.value)
66
+ return ast.Constant(value=None)
67
+ return callNodeVisited
68
+
69
+ def visit_Expr(self, node: ast.Expr) -> Union[ast.AST, List[ast.AST]]:
70
+ if (isinstance(node.value, ast.Call)):
71
+ if (isinstance(node.value.func, ast.Name) and node.value.func.id in self.dictionaryFunctions):
72
+ inlineDefinition = self.inlineFunctionBody(node.value.func.id)
73
+ if (inlineDefinition):
74
+ return [self.visit(stmt) for stmt in inlineDefinition.body]
75
+ return self.generic_visit(node)
76
+
77
+ def decorateCallableWithNumba(astCallable: ast.FunctionDef, parallel: bool=False) -> ast.FunctionDef:
78
+ """
79
+ Decorates an AST function definition with Numba JIT compilation parameters.
80
+
81
+ This function processes an AST FunctionDef node and adds Numba-specific decorators
82
+ for JIT compilation. It handles array parameter typing and compilation options.
83
+
84
+ Parameters
85
+ ----------
86
+ astCallable : ast.FunctionDef
87
+ The AST node representing the function to be decorated with Numba JIT.
88
+ parallel : bool, optional
89
+ Whether to enable parallel execution in Numba compilation.
90
+ Default is False.
91
+
92
+ Returns
93
+ -------
94
+ ast.FunctionDef
95
+ The modified AST function definition node with added Numba decorators.
96
+
97
+ Notes
98
+ -----
99
+ The function performs the following main tasks:
100
+ 1. Processes function parameters to create Numba-compatible type signatures
101
+ 2. Constructs appropriate Numba compilation parameters
102
+ 3. Creates and attaches a @numba.jit decorator to the function
103
+ Special handling is included for the 'countInitialize' function, which receives
104
+ empty compilation parameters.
105
+ The function relies on external parameters:
106
+ - parametersNumbaDEFAULT: Default Numba compilation parameters
107
+ - ParametersNumba: Class/type for handling Numba parameters
108
+ - hackSSOTdatatype: Function for determining default datatypes
109
+ """
110
+ def makeNumbaParameterSignatureElement(signatureElement: ast.arg):
111
+ """
112
+ Converts an AST function parameter signature element into a Numba-compatible type annotation.
113
+
114
+ This function processes parameter annotations for array types, handling both shape and datatype
115
+ specifications. It supports multi-dimensional arrays through tuple-based shape definitions and
116
+ various numeric datatypes.
117
+
118
+ Parameters
119
+ ----------
120
+ signatureElement : ast.arg
121
+ The AST argument node containing the parameter's name and type annotation.
122
+ Expected annotation format: Type[shape_tuple, dtype]
123
+ where shape_tuple can be either a single dimension or a tuple of dimensions,
124
+ and dtype specifies the data type.
125
+
126
+ Returns
127
+ -------
128
+ ast.Subscript
129
+ A Numba-compatible type annotation as an AST node, representing an array type
130
+ with the specified shape and datatype.
131
+
132
+ Notes
133
+ -----
134
+ The function handles two main cases for shape specifications:
135
+ 1. Multi-dimensional arrays with tuple-based shapes
136
+ 2. Single-dimension arrays with simple slice notation
137
+ The datatype can be either explicitly specified in the annotation or determined
138
+ through a fallback mechanism using hackSSOTdatatype().
139
+ """
140
+ if isinstance(signatureElement.annotation, ast.Subscript) and isinstance(signatureElement.annotation.slice, ast.Tuple):
141
+ annotationShape = signatureElement.annotation.slice.elts[0]
142
+ if isinstance(annotationShape, ast.Subscript) and isinstance(annotationShape.slice, ast.Tuple):
143
+ shapeAsListSlices: Sequence[ast.expr] = [ast.Slice() for axis in range(len(annotationShape.slice.elts))]
144
+ shapeAsListSlices[-1] = ast.Slice(step=ast.Constant(value=1))
145
+ shapeAST = ast.Tuple(elts=list(shapeAsListSlices), ctx=ast.Load())
146
+ else:
147
+ shapeAST = ast.Slice(step=ast.Constant(value=1))
148
+
149
+ annotationDtype = signatureElement.annotation.slice.elts[1]
150
+ if (isinstance(annotationDtype, ast.Subscript) and isinstance(annotationDtype.slice, ast.Attribute)):
151
+ datatypeAST = annotationDtype.slice.attr
152
+ else:
153
+ datatypeAST = None
154
+
155
+ ndarrayName = signatureElement.arg
156
+ Z0Z_hacky_dtype = hackSSOTdatatype(ndarrayName)
157
+ datatype_attr = datatypeAST or Z0Z_hacky_dtype
158
+
159
+ datatypeNumba = ast.Attribute(value=ast.Name(id='numba', ctx=ast.Load()), attr=datatype_attr, ctx=ast.Load())
160
+
161
+ return ast.Subscript(value=datatypeNumba, slice=shapeAST, ctx=ast.Load())
162
+
163
+ # TODO: more explicit handling of decorators. I'm able to ignore this because I know `algorithmSource` doesn't have any decorators.
164
+ # callableSourceDecorators = [decorator for decorator in callableInlined.decorator_list]
165
+
166
+ listNumbaParameterSignature: Sequence[ast.expr] = []
167
+ for parameter in astCallable.args.args:
168
+ signatureElement = makeNumbaParameterSignatureElement(parameter)
169
+ if (signatureElement):
170
+ listNumbaParameterSignature.append(signatureElement)
171
+
172
+ astArgsNumbaSignature = ast.Tuple(elts=listNumbaParameterSignature, ctx=ast.Load())
173
+
174
+ if astCallable.name == 'countInitialize':
175
+ parametersNumba = {}
176
+ else:
177
+ parametersNumba = parametersNumbaDEFAULT if not parallel else ParametersNumba({**parametersNumbaDEFAULT, 'parallel': True})
178
+ listKeywordsNumbaSignature = [ast.keyword(arg=parameterName, value=ast.Constant(value=parameterValue)) for parameterName, parameterValue in parametersNumba.items()]
179
+
180
+ astDecoratorNumba = ast.Call(func=ast.Attribute(value=ast.Name(id='numba', ctx=ast.Load()), attr='jit', ctx=ast.Load()), args=[astArgsNumbaSignature], keywords=listKeywordsNumbaSignature)
181
+
182
+ astCallable.decorator_list = [astDecoratorNumba]
183
+ return astCallable
184
+
185
+ class UnpackArrayAccesses(ast.NodeTransformer):
186
+ """
187
+ A class that transforms array accesses using enum indices into local variables.
188
+
189
+ This AST transformer identifies array accesses using enum indices and replaces them
190
+ with local variables, adding initialization statements at the start of functions.
191
+
192
+ Parameters:
193
+ enumIndexClass (Type[EnumIndices]): The enum class used for array indexing
194
+ arrayName (str): The name of the array being accessed
195
+
196
+ Attributes:
197
+ enumIndexClass (Type[EnumIndices]): Stored enum class for index lookups
198
+ arrayName (str): Name of the array being transformed
199
+ substitutions (dict): Tracks variable substitutions and their original nodes
200
+
201
+ The transformer handles two main cases:
202
+ 1. Scalar array access - array[EnumIndices.MEMBER]
203
+ 2. Array slice access - array[EnumIndices.MEMBER, other_indices...]
204
+ For each identified access pattern, it:
205
+ 1. Creates a local variable named after the enum member
206
+ 2. Adds initialization code at function start
207
+ 3. Replaces original array access with the local variable
208
+ """
209
+
210
+ def __init__(self, enumIndexClass: Type[EnumIndices], arrayName: str):
211
+ self.enumIndexClass = enumIndexClass
212
+ self.arrayName = arrayName
213
+ self.substitutions = {}
214
+
215
+ def extract_member_name(self, node: ast.AST) -> Optional[str]:
216
+ """Recursively extract enum member name from any node in the AST."""
217
+ if isinstance(node, ast.Attribute) and node.attr == 'value':
218
+ innerAttribute = node.value
219
+ while isinstance(innerAttribute, ast.Attribute):
220
+ if (isinstance(innerAttribute.value, ast.Name) and innerAttribute.value.id == self.enumIndexClass.__name__):
221
+ return innerAttribute.attr
222
+ innerAttribute = innerAttribute.value
223
+ return None
224
+
225
+ def transform_slice_element(self, node: ast.AST) -> ast.AST:
226
+ """Transform any enum references within a slice element."""
227
+ if isinstance(node, ast.Subscript):
228
+ if isinstance(node.slice, ast.Attribute):
229
+ member_name = self.extract_member_name(node.slice)
230
+ if member_name:
231
+ return ast.Name(id=member_name, ctx=node.ctx)
232
+ elif isinstance(node, ast.Tuple):
233
+ # Handle tuple slices by transforming each element
234
+ return ast.Tuple(elts=cast(List[ast.expr], [self.transform_slice_element(elt) for elt in node.elts]), ctx=node.ctx)
235
+ elif isinstance(node, ast.Attribute):
236
+ member_name = self.extract_member_name(node)
237
+ if member_name:
238
+ return ast.Name(id=member_name, ctx=ast.Load())
239
+ return node
240
+
241
+ def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
242
+ # Recursively visit any nested subscripts in value or slice
243
+ node.value = self.visit(node.value)
244
+ node.slice = self.visit(node.slice)
245
+ # If node.value is not our arrayName, just return node
246
+ if not (isinstance(node.value, ast.Name) and node.value.id == self.arrayName):
247
+ return node
248
+
249
+ # Handle scalar array access
250
+ if isinstance(node.slice, ast.Attribute):
251
+ memberName = self.extract_member_name(node.slice)
252
+ if memberName:
253
+ self.substitutions[memberName] = ('scalar', node)
254
+ return ast.Name(id=memberName, ctx=ast.Load())
255
+
256
+ # Handle array slice access
257
+ if isinstance(node.slice, ast.Tuple) and node.slice.elts:
258
+ firstElement = node.slice.elts[0]
259
+ memberName = self.extract_member_name(firstElement)
260
+ sliceRemainder = [self.visit(elem) for elem in node.slice.elts[1:]]
261
+ if memberName:
262
+ self.substitutions[memberName] = ('array', node)
263
+ if len(sliceRemainder) == 0:
264
+ return ast.Name(id=memberName, ctx=ast.Load())
265
+ return ast.Subscript(value=ast.Name(id=memberName, ctx=ast.Load()), slice=ast.Tuple(elts=sliceRemainder, ctx=ast.Load()) if len(sliceRemainder) > 1 else sliceRemainder[0], ctx=ast.Load())
266
+
267
+ # If single-element tuple, unwrap
268
+ if isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 1:
269
+ node.slice = node.slice.elts[0]
270
+
271
+ return node
272
+
273
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
274
+ node = cast(ast.FunctionDef, self.generic_visit(node))
275
+
276
+ initializations = []
277
+ for name, (kind, original_node) in self.substitutions.items():
278
+ if kind == 'scalar':
279
+ initializations.append(ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=original_node))
280
+ else: # array
281
+ initializations.append(
282
+ ast.Assign(
283
+ targets=[ast.Name(id=name, ctx=ast.Store())],
284
+ value=ast.Subscript(value=ast.Name(id=self.arrayName, ctx=ast.Load()),
285
+ slice=ast.Attribute(value=ast.Attribute(
286
+ value=ast.Name(id=self.enumIndexClass.__name__, ctx=ast.Load()),
287
+ attr=name, ctx=ast.Load()), attr='value', ctx=ast.Load()), ctx=ast.Load())))
288
+
289
+ node.body = initializations + node.body
290
+ return node
291
+
292
+ def inlineOneCallable(codeSource, callableTarget):
293
+ """
294
+ Inlines a target callable function and its dependencies within the provided code source.
295
+
296
+ This function performs function inlining, optionally adds Numba decorators, and handles array access unpacking
297
+ for specific callable targets. It processes the source code through AST manipulation and returns the modified source.
298
+
299
+ Parameters:
300
+ codeSource (str): The source code containing the callable to be inlined.
301
+ callableTarget (str): The name of the callable function to be inlined. Special handling is provided for
302
+ 'countParallel', 'countInitialize', and 'countSequential'.
303
+
304
+ Returns:
305
+ str: The modified source code with the inlined callable and necessary imports.
306
+
307
+ The function performs the following operations:
308
+ 1. Parses the source code into an AST
309
+ 2. Extracts import statements and function definitions
310
+ 3. Inlines the target function using RecursiveInliner
311
+ 4. Applies Numba decoration if needed
312
+ 5. Handles special array access unpacking for 'countSequential'
313
+ 6. Reconstructs and returns the modified source code
314
+
315
+ Note:
316
+ - Special handling is provided for 'countParallel', 'countInitialize', and 'countSequential' targets
317
+ - For 'countSequential', additional array access unpacking is performed for 'my' and 'track' indices
318
+ - `UnpackArrayAccesses` would need modification to handle 'countParallel'
319
+ """
320
+
321
+ codeParsed: ast.Module = ast.parse(codeSource, type_comments=True)
322
+ codeSourceImportStatements = {statement for statement in codeParsed.body if isinstance(statement, (ast.Import, ast.ImportFrom))}
323
+ dictionaryFunctions = {statement.name: statement for statement in codeParsed.body if isinstance(statement, ast.FunctionDef)}
324
+ callableInlinerWorkhorse = RecursiveInliner(dictionaryFunctions)
325
+ callableInlined = callableInlinerWorkhorse.inlineFunctionBody(callableTarget)
326
+
327
+ if callableInlined:
328
+ ast.fix_missing_locations(callableInlined)
329
+ parallel = callableTarget == 'countParallel'
330
+ callableDecorated = decorateCallableWithNumba(callableInlined, parallel)
331
+
332
+ if callableTarget == 'countSequential':
333
+ unpackerMy = UnpackArrayAccesses(indexMy, 'my')
334
+ callableDecorated = cast(ast.FunctionDef, unpackerMy.visit(callableDecorated))
335
+ ast.fix_missing_locations(callableDecorated)
336
+
337
+ unpackerTrack = UnpackArrayAccesses(indexTrack, 'track')
338
+ callableDecorated = cast(ast.FunctionDef, unpackerTrack.visit(callableDecorated))
339
+ ast.fix_missing_locations(callableDecorated)
340
+
341
+ moduleAST = ast.Module(body=cast(List[ast.stmt], list(codeSourceImportStatements) + [callableDecorated]), type_ignores=[])
342
+ ast.fix_missing_locations(moduleAST)
343
+ moduleSource = ast.unparse(moduleAST)
344
+ return moduleSource
345
+
346
+ class AppendDunderInit(ast.NodeTransformer):
347
+ """AST transformer that validates and appends imports to __init__.py files."""
348
+
349
+ def __init__(self, listPathFilenamesDestination: list[tuple[pathlib.Path, str]]):
350
+ self.listPathFilenamesDestination = listPathFilenamesDestination
351
+ self.listTuplesDunderInit = []
352
+
353
+ def process_init_files(self) -> list[tuple[pathlib.Path, str]]:
354
+ for pathFilename, callableTarget in self.listPathFilenamesDestination:
355
+ pathDunderInit = pathFilename.parent / "__init__.py"
356
+
357
+ # Create an empty __init__.py if it doesn't exist
358
+ if not pathDunderInit.exists():
359
+ pathDunderInit.write_text("")
360
+
361
+ # Parse existing init file
362
+ try:
363
+ treeInit = ast.parse(pathDunderInit.read_text())
364
+ except SyntaxError:
365
+ treeInit = ast.Module(body=[], type_ignores=[])
366
+
367
+ # Compute the lowercase module target
368
+ moduleTarget = "." + pathFilename.stem
369
+ moduleTargetLower = moduleTarget.lower()
370
+
371
+ # Track existing imports as (normalizedModule, name)
372
+ setImportsExisting = set()
373
+ for node in treeInit.body:
374
+ if isinstance(node, ast.ImportFrom) and node.module:
375
+ # Compare on a lowercase basis
376
+ if node.module.lower() == moduleTargetLower:
377
+ for alias in node.names:
378
+ setImportsExisting.add((moduleTargetLower, alias.name))
379
+
380
+ # Only append if this exact import doesn't exist
381
+ if (moduleTargetLower, callableTarget) not in setImportsExisting:
382
+ newImport = ast.ImportFrom(
383
+ module=moduleTarget,
384
+ names=[ast.alias(name=callableTarget, asname=None)],
385
+ level=0
386
+ )
387
+ treeInit.body.append(newImport)
388
+ ast.fix_missing_locations(treeInit)
389
+ pathDunderInit.write_text(ast.unparse(treeInit))
390
+
391
+ self.listTuplesDunderInit.append((pathDunderInit, callableTarget))
392
+
393
+ return self.listTuplesDunderInit
394
+
395
+ def inlineMapFoldingNumba(listCallablesAsStr: List[str], algorithmSource: Optional[ModuleType] = None):
396
+ """Synthesizes numba-optimized versions of map folding functions.
397
+ This function creates specialized versions of map folding functions by inlining
398
+ target callables and generating optimized modules. It handles the code generation
399
+ and file writing process.
400
+
401
+ Parameters:
402
+ listCallablesAsStr (List[str]): List of callable names to be processed as strings.
403
+ algorithmSource (Optional[ModuleType], optional): Source module containing the algorithms.
404
+ If None, will be obtained via getAlgorithmSource(). Defaults to None.
405
+
406
+ Returns:
407
+ List[Tuple[pathlib.Path, str]]: List of tuples containing:
408
+ - Generated file paths
409
+ - Associated callable names
410
+
411
+ Raises:
412
+ Exception: If inline operation fails during code generation.
413
+
414
+ Note:
415
+ - Generated files are placed in a synthetic modules subdirectory
416
+ - Modifies __init__.py files to expose generated modules
417
+ - Current implementation contains hardcoded paths that should be abstracted
418
+ """
419
+ if not algorithmSource:
420
+ algorithmSource = getAlgorithmSource()
421
+
422
+ listPathFilenamesDestination: list[tuple[pathlib.Path, str]] = []
423
+
424
+ # TODO abstract this process
425
+ # especially remove the hardcoded paths and filenames
426
+
427
+ for callableTarget in listCallablesAsStr:
428
+ codeSource = inspect.getsource(algorithmSource)
429
+ moduleSource = inlineOneCallable(codeSource, callableTarget)
430
+ if not moduleSource:
431
+ raise Exception("Pylance, OMG! The sky is falling!")
432
+ pathFilenameAlgorithm = pathlib.Path(inspect.getfile(algorithmSource))
433
+ pathFilenameDestination = pathFilenameAlgorithm.parent / relativePathSyntheticModules / pathFilenameAlgorithm.with_stem("numba"+callableTarget[5:None]).name
434
+ pathFilenameDestination.write_text(moduleSource)
435
+ listPathFilenamesDestination.append((pathFilenameDestination, callableTarget))
436
+
437
+ # This almost works: it duplicates existing imports, though
438
+ listTuplesDunderInit = AppendDunderInit(listPathFilenamesDestination).process_init_files()
439
+
440
+ if __name__ == '__main__':
441
+ listCallablesAsStr: List[str] = ['countInitialize', 'countParallel', 'countSequential']
442
+ setDatatypeModule('numpy', sourGrapes=True)
443
+ setDatatypeFoldsTotal('int64', sourGrapes=True)
444
+ setDatatypeElephino('uint8', sourGrapes=True)
445
+ setDatatypeLeavesTotal('uint8', sourGrapes=True)
446
+ inlineMapFoldingNumba(listCallablesAsStr)
@@ -1,4 +1,3 @@
1
- from .Initialize import countInitialize
2
- from .Parallel import countParallel
3
- from .Sequential import countSequential
4
-
1
+ from .numbaInitialize import countInitialize
2
+ from .numbaParallel import countParallel
3
+ from .numbaSequential import countSequential
@@ -1,11 +1,14 @@
1
- import numba
2
- from numpy import integer
3
- from mapFolding import indexMy, indexTrack
4
1
  import numpy
5
2
  from typing import Any, Tuple
3
+ from numpy import integer
4
+ from mapFolding import indexMy, indexTrack
5
+ import numba
6
6
 
7
7
  @numba.jit((numba.uint8[:, :, ::1], numba.uint8[::1], numba.uint8[::1], numba.uint8[:, ::1]))
8
- def countInitialize(connectionGraph: numpy.ndarray[Tuple[int, int, int], numpy.dtype[integer[Any]]], gapsWhere: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], my: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], track: numpy.ndarray[Tuple[int, int], numpy.dtype[integer[Any]]]):
8
+ def countInitialize(connectionGraph: numpy.ndarray[Tuple[int, int, int], numpy.dtype[integer[Any]]]
9
+ , gapsWhere: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]]
10
+ , my: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]]
11
+ , track: numpy.ndarray[Tuple[int, int], numpy.dtype[integer[Any]]]):
9
12
  while my[indexMy.leaf1ndex.value]:
10
13
  if my[indexMy.leaf1ndex.value] <= 1 or track[indexTrack.leafBelow.value, 0] == 1:
11
14
  my[indexMy.dimensionsUnconstrained.value] = my[indexMy.dimensionsTotal.value]
@@ -45,4 +48,4 @@ def countInitialize(connectionGraph: numpy.ndarray[Tuple[int, int, int], numpy.d
45
48
  track[indexTrack.gapRangeStart.value, my[indexMy.leaf1ndex.value]] = my[indexMy.gap1ndex.value]
46
49
  my[indexMy.leaf1ndex.value] += 1
47
50
  if my[indexMy.gap1ndex.value] > 0:
48
- return
51
+ return