mapFolding 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mapFolding/__init__.py +96 -58
- mapFolding/basecamp.py +5 -7
- mapFolding/beDRY.py +11 -41
- mapFolding/oeis.py +71 -74
- mapFolding/theConfiguration.py +58 -0
- mapFolding/theDao.py +1 -1
- mapFolding/theSSOT.py +14 -48
- mapFolding/theSSOTdatatypes.py +25 -36
- mapFolding/theWrongWay.py +7 -0
- {mapFolding-0.5.0.dist-info → mapfolding-0.6.0.dist-info}/METADATA +6 -4
- mapfolding-0.6.0.dist-info/RECORD +16 -0
- {mapFolding-0.5.0.dist-info → mapfolding-0.6.0.dist-info}/WHEEL +1 -1
- {mapFolding-0.5.0.dist-info → mapfolding-0.6.0.dist-info}/top_level.txt +0 -1
- mapFolding/reference/flattened.py +0 -377
- mapFolding/reference/hunterNumba.py +0 -132
- mapFolding/reference/irvineJavaPort.py +0 -120
- mapFolding/reference/jax.py +0 -208
- mapFolding/reference/lunnan.py +0 -153
- mapFolding/reference/lunnanNumpy.py +0 -123
- mapFolding/reference/lunnanWhile.py +0 -121
- mapFolding/reference/rotatedEntryPoint.py +0 -240
- mapFolding/reference/total_countPlus1vsPlusN.py +0 -211
- mapFolding/someAssemblyRequired/__init__.py +0 -5
- mapFolding/someAssemblyRequired/getLLVMforNoReason.py +0 -19
- mapFolding/someAssemblyRequired/makeJob.py +0 -56
- mapFolding/someAssemblyRequired/synthesizeModuleJAX.py +0 -27
- mapFolding/someAssemblyRequired/synthesizeNumba.py +0 -345
- mapFolding/someAssemblyRequired/synthesizeNumbaGeneralized.py +0 -397
- mapFolding/someAssemblyRequired/synthesizeNumbaJob.py +0 -155
- mapFolding/someAssemblyRequired/synthesizeNumbaModules.py +0 -123
- mapFolding/syntheticModules/numbaCount.py +0 -158
- mapFolding/syntheticModules/numba_doTheNeedful.py +0 -13
- mapFolding-0.5.0.dist-info/RECORD +0 -39
- tests/__init__.py +0 -1
- tests/conftest.py +0 -335
- tests/test_computations.py +0 -42
- tests/test_oeis.py +0 -128
- tests/test_other.py +0 -175
- tests/test_tasks.py +0 -40
- /mapFolding/{syntheticModules/__init__.py → py.typed} +0 -0
- {mapFolding-0.5.0.dist-info → mapfolding-0.6.0.dist-info}/LICENSE +0 -0
- {mapFolding-0.5.0.dist-info → mapfolding-0.6.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
import importlib
|
|
2
|
-
import importlib.util
|
|
3
|
-
import llvmlite.binding
|
|
4
|
-
import pathlib
|
|
5
|
-
|
|
6
|
-
def writeModuleLLVM(pathFilename: pathlib.Path, identifierCallable: str) -> pathlib.Path:
    """Import a generated module from disk and write the LLVM IR of one of its callables.

    The file at `pathFilename` is executed as a module named "generatedModule". The object
    named `identifierCallable` in that module is asked for `inspect_llvm()`, the entry keyed
    by the empty tuple `()` is parsed with llvmlite, and the result is written next to the
    source file with the ".ll" suffix.

    Parameters:
        pathFilename: Path of the Python module to import.
        identifierCallable: Name of the compiled callable defined in that module.

    Returns:
        Path of the written ".ll" file.

    Raises:
        ImportError: If no module spec or loader can be created for `pathFilename`.
    """
    specModule = importlib.util.spec_from_file_location("generatedModule", pathFilename)
    if specModule is None or specModule.loader is None:
        raise ImportError(f"Could not create module spec or loader for {pathFilename}")
    loadedModule = importlib.util.module_from_spec(specModule)
    specModule.loader.exec_module(loadedModule)

    # NOTE(review): the `()` key presumably selects the compiled signature that takes no
    # arguments (numba-style dispatcher); confirm before using with parameterized callables.
    compiledCallable = vars(loadedModule)[identifierCallable]
    textLLVM = compiledCallable.inspect_llvm()[()]
    parsedLLVM = llvmlite.binding.module.parse_assembly(textLLVM)

    pathLLVM = pathFilename.with_suffix(".ll")
    pathLLVM.write_text(str(parsedLLVM))
    return pathLLVM
|
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
from collections.abc import Sequence
|
|
2
|
-
from mapFolding import getPathFilenameFoldsTotal, computationState, outfitCountFolds, getAlgorithmSource
|
|
3
|
-
from pathlib import Path
|
|
4
|
-
from types import ModuleType
|
|
5
|
-
from typing import Any, Literal, overload
|
|
6
|
-
import pickle
|
|
7
|
-
|
|
8
|
-
@overload
def makeStateJob(listDimensions: Sequence[int], *, writeJob: Literal[True] = True, **keywordArguments: Any | None) -> Path: ...
@overload
def makeStateJob(listDimensions: Sequence[int], *, writeJob: Literal[False], **keywordArguments: Any | None) -> computationState: ...
def makeStateJob(listDimensions: Sequence[int], *, writeJob: bool = True, **keywordArguments: Any | None) -> computationState | Path:
    """
    Creates a computation state job for map folding calculations and optionally saves it to disk.

    The state is produced by `outfitCountFolds`, then initialized in place by calling
    `countInitialize` from the module returned by `getAlgorithmSource()`.

    Parameters
    ----------
    listDimensions : Sequence[int]
        The dimensions of the map to be folded.
    writeJob : bool, optional
        If True (the default), pickle the state to disk and return the file path.
        If False, return the state object directly.
    **keywordArguments : Any | None
        Additional keyword arguments forwarded unchanged to `outfitCountFolds`.

    Returns
    -------
    computationState | Path
        The initialized state when `writeJob` is False. Otherwise, the path of the
        written 'stateJob.pkl' file.

    Notes
    -----
    The job directory is the path from `getPathFilenameFoldsTotal(state['mapShape'])`
    with its final suffix removed. It is created (with parents) if missing.
    """
    stateUniversal: computationState = outfitCountFolds(listDimensions, **keywordArguments)

    moduleSource: ModuleType = getAlgorithmSource()
    moduleSource.countInitialize(stateUniversal['connectionGraph'], stateUniversal['gapsWhere'], stateUniversal['my'], stateUniversal['track'])

    if not writeJob:
        return stateUniversal

    # `with_suffix('')` strips only the final suffix. The previous string slice
    # `[0:-len(suffix)]` collapsed to '' (the current directory) when there was no suffix.
    pathJob = getPathFilenameFoldsTotal(stateUniversal['mapShape']).with_suffix('')
    pathJob.mkdir(parents=True, exist_ok=True)
    pathFilenameJob = pathJob / 'stateJob.pkl'

    pathFilenameJob.write_bytes(pickle.dumps(stateUniversal))
    return pathFilenameJob
|
|
@@ -1,27 +0,0 @@
|
|
|
1
|
-
from mapFolding import getAlgorithmSource, getPathSyntheticModules
|
|
2
|
-
from mapFolding import setDatatypeModule, setDatatypeFoldsTotal, setDatatypeElephino, setDatatypeLeavesTotal
|
|
3
|
-
import ast
|
|
4
|
-
import inspect
|
|
5
|
-
import pathlib
|
|
6
|
-
|
|
7
|
-
def transformPythonToJAX(codePython: str) -> None:
    """Placeholder for a Python-to-JAX source transformation.

    For now the source is only parsed, so invalid Python raises `SyntaxError`. The parsed
    tree is discarded and the function returns None.
    """
    ast.parse(codePython)  # TODO: transform the parsed tree instead of discarding it.
    return None
|
|
9
|
-
|
|
10
|
-
def writeJax(*, codeSource: str | None = None, pathFilenameAlgorithm: pathlib.Path | None = None, pathFilenameDestination: pathlib.Path | None = None) -> None:
|
|
11
|
-
if codeSource is None and pathFilenameAlgorithm is None:
|
|
12
|
-
algorithmSource = getAlgorithmSource()
|
|
13
|
-
codeSource = inspect.getsource(algorithmSource)
|
|
14
|
-
transformedText = transformPythonToJAX(codeSource)
|
|
15
|
-
pathFilenameAlgorithm = pathlib.Path(inspect.getfile(algorithmSource))
|
|
16
|
-
else:
|
|
17
|
-
raise NotImplementedError("You haven't written this part yet.")
|
|
18
|
-
if pathFilenameDestination is None:
|
|
19
|
-
pathFilenameDestination = getPathSyntheticModules() / "countJax.py"
|
|
20
|
-
# pathFilenameDestination.write_text(transformedText)
|
|
21
|
-
|
|
22
|
-
if __name__ == '__main__':
    # Configure the datatype settings for the JAX target (each with sourGrapes=True),
    # in the same order as before, then run the synthesizer.
    setDatatypeModule('jax.numpy', sourGrapes=True)
    for setDatatype, datatypeName in (
        (setDatatypeFoldsTotal, 'int64'),
        (setDatatypeElephino, 'uint8'),
        (setDatatypeLeavesTotal, 'uint8'),
    ):
        setDatatype(datatypeName, sourGrapes=True)
    writeJax()
|
|
@@ -1,345 +0,0 @@
|
|
|
1
|
-
"""I think this module is free of hardcoded values.
|
|
2
|
-
TODO: consolidate the logic in this module."""
|
|
3
|
-
from mapFolding.someAssemblyRequired.synthesizeNumbaGeneralized import *
|
|
4
|
-
|
|
5
|
-
def insertArrayIn_body(FunctionDefTarget: ast.FunctionDef, identifier: str, arrayTarget: numpy.ndarray, allImports: UniversalImportTracker, unrollSlices: int | None = None) -> tuple[ast.FunctionDef, UniversalImportTracker]:
    """Prepend assignments that rebuild `arrayTarget` from literal data inside `FunctionDefTarget`.

    Each generated statement has the form
    `<name> = <constructor>(<data>, dtype=<module>_<dtypeName>)`, where `<data>` is the
    expression parsed from `autoDecodingRLE(..., addSpaces=True)`. The constructor and the
    aliased dtype are registered in `allImports`.

    Parameters:
        FunctionDefTarget: Function whose body is modified in place.
        identifier: Name assigned (or prefix, when unrolling).
        arrayTarget: Array whose data is embedded.
        allImports: Import tracker that receives the needed imports.
        unrollSlices: Only its truthiness is used. When truthy, one assignment
            `<identifier>_<index>` is emitted per first-axis slice instead of one for the
            whole array.

    Returns:
        `(FunctionDefTarget, allImports)`. Every assignment is inserted at body index 0, so
        unrolled slices appear in reverse index order.
    """
    typeArray = type(arrayTarget)
    moduleConstructor = typeArray.__module__
    # NOTE hack: the array class is `ndarray`, but the generated code calls `array`.
    constructorName = typeArray.__name__.replace('ndarray', 'array')
    datatypeName = arrayTarget.dtype.name
    dtypeAsName = f"{moduleConstructor}_{datatypeName}"

    allImports.addImportFromStr(moduleConstructor, constructorName)
    allImports.addImportFromStr(moduleConstructor, datatypeName, dtypeAsName)

    def prependAssignment(nameAssignee: str, arrayData: numpy.ndarray) -> None:
        # Only the body list is mutated, so no `nonlocal` rebinding is required.
        expressionData = cast(ast.Expr, ast.parse(autoDecodingRLE(arrayData, addSpaces=True)).body[0]).value
        keywordDtype = ast.keyword(arg='dtype', value=ast.Name(id=dtypeAsName, ctx=ast.Load()))
        callConstructor = Then.make_astCall(name=constructorName, args=[expressionData], list_astKeywords=[keywordDtype])
        statementAssign = ast.Assign(targets=[ast.Name(id=nameAssignee, ctx=ast.Store())], value=callConstructor)
        FunctionDefTarget.body.insert(0, statementAssign)

    if unrollSlices:
        for indexSlice, arraySlice in enumerate(arrayTarget):
            prependAssignment(f"{identifier}_{indexSlice}", arraySlice)
    else:
        prependAssignment(identifier, arrayTarget)

    return FunctionDefTarget, allImports
|
|
36
|
-
|
|
37
|
-
def findAndReplaceTrackArrayIn_body(FunctionDefTarget: ast.FunctionDef, identifier: str , arrayTarget: numpy.ndarray , allImports: UniversalImportTracker) -> tuple[ast.FunctionDef, UniversalImportTracker]:
    """Replace statements that unpack a slice of array `identifier` with literal array constructors.

    For each top-level body statement matched by `ifThis.isUnpackingAnArray(identifier)`:
    - its subscript index is unparsed and `eval`-ed to select a slice of `arrayTarget`;
    - that slice is embedded as `<target> = <constructor>(<data>, dtype=<module>_<dtypeName>)`,
      where the dtype name comes from `hackSSOTdatatype(<target name>)`;
    - the new assignment is inserted at body index 0 and the original statement removed.

    Returns:
        `(FunctionDefTarget, allImports)`, with the constructor and dtype aliases registered.
    """
    typeArray = type(arrayTarget)
    moduleConstructor = typeArray.__module__
    # NOTE hack: the array class is `ndarray`, but the generated code calls `array`.
    constructorName = typeArray.__name__.replace('ndarray', 'array')
    allImports.addImportFromStr(moduleConstructor, constructorName)

    for statement in list(FunctionDefTarget.body):
        if not ifThis.isUnpackingAnArray(identifier)(statement):
            continue
        nameTarget = statement.targets[0] # type: ignore
        datatypeName = hackSSOTdatatype(nameTarget.id)
        dtypeAsName = f"{moduleConstructor}_{datatypeName}"
        # NOTE(review): `eval` resolves names in this function's globals/locals; the index
        # expressions presumably reference module-level names only. The source comes from
        # the package's own algorithm module, not untrusted input.
        arraySlice = arrayTarget[eval(ast.unparse(statement.value.slice))] # type: ignore

        expressionData = cast(ast.Expr, ast.parse(autoDecodingRLE(arraySlice, addSpaces=True)).body[0]).value
        keywordDtype = ast.keyword(arg='dtype', value=ast.Name(id=dtypeAsName, ctx=ast.Load()))
        callConstructor = Then.make_astCall(name=constructorName, args=[expressionData], list_astKeywords=[keywordDtype])

        FunctionDefTarget.body.insert(0, ast.Assign(targets=[nameTarget], value=callConstructor))
        FunctionDefTarget.body.remove(statement)
        allImports.addImportFromStr(moduleConstructor, datatypeName, dtypeAsName)
    return FunctionDefTarget, allImports
|
|
64
|
-
|
|
65
|
-
def findAndReplaceArraySubscriptIn_body(FunctionDefTarget: ast.FunctionDef, identifier: str, arrayTarget: numpy.ndarray, Z0Z_listChaff: list[str], allImports: UniversalImportTracker) -> tuple[ast.FunctionDef, UniversalImportTracker]:
    """Replace statements that unpack one element of array `identifier` with typed scalar constants.

    For each top-level body statement matched by `ifThis.isUnpackingAnArray(identifier)`:
    - the datatype name `hackSSOTdatatype(<target name>)` is imported from
      `Z0Z_getDatatypeModuleScalar()`;
    - the subscript index is unparsed, `eval`-ed, and used to read
      `arrayTarget[index].item()`;
    - unless the target name is in `Z0Z_listChaff`, `<target> = <datatype>(<value>)` is
      inserted at body index 0;
    - the original statement is removed in either case.

    Returns:
        `(FunctionDefTarget, allImports)`.
    """
    moduleConstructor = Z0Z_getDatatypeModuleScalar()
    for statement in list(FunctionDefTarget.body):
        if not ifThis.isUnpackingAnArray(identifier)(statement):
            continue
        astAssignee: ast.Name = statement.targets[0] # type: ignore
        argData_dtypeName = hackSSOTdatatype(astAssignee.id)
        allImports.addImportFromStr(moduleConstructor, argData_dtypeName)

        indexAsStr = ast.unparse(statement.value.slice) # type: ignore
        # NOTE(review): `eval` runs in this function's namespace; the index expressions
        # presumably reference module-level names only.
        argDataSlice: int = arrayTarget[eval(indexAsStr)].item()

        if astAssignee.id not in Z0Z_listChaff:
            callDatatype = ast.Call(func=ast.Name(id=argData_dtypeName, ctx=ast.Load()), args=[ast.Constant(value=argDataSlice)], keywords=[])
            FunctionDefTarget.body.insert(0, ast.Assign(targets=[astAssignee], value=callDatatype))
        FunctionDefTarget.body.remove(statement)
    return FunctionDefTarget, allImports
|
|
82
|
-
|
|
83
|
-
def removeAssignTargetFrom_body(FunctionDefTarget: ast.FunctionDef, identifier: str) -> ast.FunctionDef:
    """Delete assignments whose first target is `identifier` or a subscript of the name `identifier`.

    Matching `ast.Assign` nodes are handed to `NodeReplacer` with a replacement of None.

    Returns:
        The visited function, with missing locations filled in.

    Raises:
        FREAKOUT: If visiting the function yields a falsy result.
    """
    def assignsToIdentifier(astNode: ast.AST) -> bool:
        if not (isinstance(astNode, ast.Assign) and astNode.targets):
            return False
        firstTarget = astNode.targets[0]
        if isinstance(firstTarget, ast.Subscript) and isinstance(firstTarget.value, ast.Name) and firstTarget.value.id == identifier:
            return True
        return ifThis.nameIs(identifier)(firstTarget)

    def dropNode(astNode: ast.AST) -> ast.stmt | None:
        # Returning None removes the node.
        return None

    functionVisited = NodeReplacer(assignsToIdentifier, dropNode).visit(FunctionDefTarget)
    if not functionVisited:
        raise FREAKOUT("Dude, where's my function?")
    FunctionDefTarget = cast(ast.FunctionDef, functionVisited)
    ast.fix_missing_locations(FunctionDefTarget)
    return FunctionDefTarget
|
|
100
|
-
|
|
101
|
-
def findAndReplaceAnnAssignIn_body(FunctionDefTarget: ast.FunctionDef, allImports: UniversalImportTracker) -> tuple[ast.FunctionDef, UniversalImportTracker]:
    """Rewrite top-level `name: T = <constant>` statements as `name = <datatype>(<constant>)`.

    The datatype name is `hackSSOTdatatype(name)`, imported from
    `Z0Z_getDatatypeModuleScalar()`. Each rewritten assignment is moved to body index 0 and
    the annotated original is removed. Annotated assignments whose target is not a plain
    name, or whose value is not a constant, are left alone.

    Returns:
        `(FunctionDefTarget, allImports)`.
    """
    moduleConstructor = Z0Z_getDatatypeModuleScalar()
    for statement in list(FunctionDefTarget.body):
        if not (isinstance(statement, ast.AnnAssign)
                and isinstance(statement.target, ast.Name)
                and isinstance(statement.value, ast.Constant)):
            continue
        nameAssignee: ast.Name = statement.target
        datatypeName = hackSSOTdatatype(nameAssignee.id)
        allImports.addImportFromStr(moduleConstructor, datatypeName)
        callDatatype = ast.Call(func=ast.Name(id=datatypeName, ctx=ast.Load()), args=[statement.value], keywords=[])
        FunctionDefTarget.body.insert(0, ast.Assign(targets=[nameAssignee], value=callDatatype))
        FunctionDefTarget.body.remove(statement)
    return FunctionDefTarget, allImports
|
|
114
|
-
|
|
115
|
-
def findThingyReplaceWithConstantIn_body(FunctionDefTarget: ast.FunctionDef, object: str, value: int) -> ast.FunctionDef:
    """
    Replace every subtree of `FunctionDefTarget` equal to the expression `object` with `Constant(value)`.

    `object` is parsed as a single expression. Nodes are compared by their
    `ast.dump(..., annotate_fields=False)` text, which excludes location attributes. Each
    replacement copies the location of the node it replaces.
    """
    dumpTarget = ast.dump(ast.parse(object, mode='eval').body, annotate_fields=False)

    transformer = NodeReplacer(
        lambda node: ast.dump(node, annotate_fields=False) == dumpTarget,
        lambda node: ast.copy_location(ast.Constant(value=value), node),
    )
    functionTransformed = cast(ast.FunctionDef, transformer.visit(FunctionDefTarget))
    ast.fix_missing_locations(functionTransformed)
    return functionTransformed
|
|
133
|
-
|
|
134
|
-
def findAstNameReplaceWithConstantIn_body(FunctionDefTarget: ast.FunctionDef, name: str, value: int) -> ast.FunctionDef:
    """Replace each node matched by `ifThis.nameIs(name)` with `Constant(value)` at the same location."""
    def constantAtNode(node: ast.AST) -> ast.AST:
        return ast.copy_location(ast.Constant(value=value), node)

    replacer = NodeReplacer(ifThis.nameIs(name), constantAtNode)
    return cast(ast.FunctionDef, replacer.visit(FunctionDefTarget))
|
|
139
|
-
|
|
140
|
-
def insertReturnStatementIn_body(FunctionDefTarget: ast.FunctionDef, arrayTarget: numpy.ndarray, allImports: UniversalImportTracker) -> tuple[ast.FunctionDef, UniversalImportTracker]:
    """Append `return <Z0Z_identifierCountFolds> * <int(arrayTarget[-1])>` and set the return annotation.

    The return annotation and the imported name are both `hackSSOTdatatype(Z0Z_identifierCountFolds)`,
    imported from `Z0Z_getDatatypeModuleScalar()`.

    Returns:
        `(FunctionDefTarget, allImports)`.
    """
    multiplicand = Z0Z_identifierCountFolds
    # Computed once: the same datatype names the return annotation and the import
    # (the previous version called hackSSOTdatatype twice with the same argument).
    datatype = hackSSOTdatatype(multiplicand)

    multiplyOperation = ast.BinOp(
        left=ast.Name(id=multiplicand, ctx=ast.Load()),
        op=ast.Mult(),
        right=ast.Constant(value=int(arrayTarget[-1])),
    )
    returnStatement = ast.Return(value=multiplyOperation)

    FunctionDefTarget.returns = ast.Name(id=datatype, ctx=ast.Load())
    allImports.addImportFromStr(Z0Z_getDatatypeModuleScalar(), datatype)

    FunctionDefTarget.body.append(returnStatement)
    return FunctionDefTarget, allImports
|
|
159
|
-
|
|
160
|
-
def findAndReplaceWhileLoopIn_body(FunctionDefTarget: ast.FunctionDef, iteratorName: str, iterationsTotal: int) -> ast.FunctionDef:
    """
    Unroll all nested while loops matching the condition that their test uses `iteratorName`.

    A `while` loop is unrolled when its test is an `ast.Compare` whose left operand satisfies
    `ifThis.nameIs(iteratorName)`. The loop is replaced by `iterationsTotal` copies of its
    body. In copy k, every `Name` node with id `iteratorName` becomes the constant k
    (k = 0 .. iterationsTotal - 1). Direct `iteratorName += 1` statements in the body are
    dropped. Statements from the loop's `else` block are appended after the copies.

    NOTE(review): the unrolled indices always start at 0. The loop's actual initial value
    and bound are not inspected, and the iterator is not assigned a final value afterwards.

    Returns:
        The transformed function definition, with missing locations filled in.
    """
    def _flattenStatements(visited: ast.AST | list[ast.stmt] | None) -> list[ast.stmt]:
        # `NodeTransformer.visit` may return one node, a list of nodes (our `visit_While`
        # always returns a list), or None (node removed). Normalize to a flat list so that
        # lists never end up nested inside statement lists.
        if visited is None:
            return []
        if isinstance(visited, list):
            return visited
        return [cast(ast.stmt, visited)]

    # Helper transformer to replace iterator occurrences with a constant.
    class ReplaceIterator(ast.NodeTransformer):
        def __init__(self, iteratorName: str, constantValue: int) -> None:
            super().__init__()
            self.iteratorName = iteratorName
            self.constantValue = constantValue

        def visit_Name(self, node: ast.Name) -> ast.AST:
            if node.id == self.iteratorName:
                return ast.copy_location(ast.Constant(value=self.constantValue), node)
            return self.generic_visit(node)

    # NodeTransformer that finds while loops (even if deeply nested) and unrolls them.
    class WhileLoopUnroller(ast.NodeTransformer):
        def __init__(self, iteratorName: str, iterationsTotal: int) -> None:
            super().__init__()
            self.iteratorName = iteratorName
            self.iterationsTotal = iterationsTotal

        def _isIncrementOfIterator(self, statement: ast.stmt) -> bool:
            # Matches exactly `<iteratorName> += 1`.
            return (isinstance(statement, ast.AugAssign)
                    and isinstance(statement.target, ast.Name)
                    and statement.target.id == self.iteratorName
                    and isinstance(statement.op, ast.Add)
                    and isinstance(statement.value, ast.Constant)
                    and statement.value.value == 1)

        def visit_While(self, node: ast.While) -> list[ast.stmt]:
            # Check if the while loop's test uses the iterator.
            if not (isinstance(node.test, ast.Compare) and ifThis.nameIs(self.iteratorName)(node.test.left)):
                return [cast(ast.stmt, self.generic_visit(node))]

            # Recurse into the body first (handles nested loops) and drop the increment.
            cleanBodyStatements: list[ast.stmt] = []
            for loopStatement in node.body:
                if self._isIncrementOfIterator(loopStatement):
                    continue
                cleanBodyStatements.extend(_flattenStatements(self.visit(loopStatement)))

            # Unroll using the filtered body.
            newStatements: list[ast.stmt] = []
            for iterationIndex in range(self.iterationsTotal):
                replacer = ReplaceIterator(self.iteratorName, iterationIndex)
                for loopStatement in cleanBodyStatements:
                    newStatement = replacer.visit(copy.deepcopy(loopStatement))
                    ast.fix_missing_locations(newStatement)
                    newStatements.append(newStatement)

            # The orelse block runs once after the (unrolled) loop.
            for elseStatement in node.orelse:
                newStatements.extend(_flattenStatements(self.visit(elseStatement)))
            return newStatements

    newFunctionDef = WhileLoopUnroller(iteratorName, iterationsTotal).visit(FunctionDefTarget)
    ast.fix_missing_locations(newFunctionDef)
    return newFunctionDef
|
|
224
|
-
|
|
225
|
-
def makeLauncherBasicJobNumba(callableTarget: str, pathFilenameFoldsTotal: Path) -> ast.Module:
    """Build an `if __name__ == '__main__':` launcher for a generated job module.

    When the generated module runs as a script, the launcher does the following:
    - calls `callableTarget()` once;
    - prints the result and the elapsed `time.perf_counter()` seconds;
    - writes `str(result)` to `pathFilenameFoldsTotal`.

    Parameters:
        callableTarget: Identifier of the zero-argument callable; inserted verbatim as code.
        pathFilenameFoldsTotal: File that receives the total.

    Returns:
        The parsed launcher as an `ast.Module`.
    """
    # `repr` yields a correctly quoted and escaped string literal. Splicing the path into
    # '...' by hand broke the generated code for paths containing quotes or backslashes.
    literalPathFilename = repr(pathFilenameFoldsTotal.as_posix())
    linesLaunch = f"""
if __name__ == '__main__':
    import time
    timeStart = time.perf_counter()
    foldsTotal = {callableTarget}()
    print(foldsTotal, time.perf_counter() - timeStart)
    with open({literalPathFilename}, 'w') as writeStream:
        writeStream.write(str(foldsTotal))
"""
    return ast.parse(linesLaunch)
|
|
237
|
-
|
|
238
|
-
def makeFunctionDef(astModule: ast.Module,
                    callableTarget: str,
                    parametersNumba: ParametersNumba | None = None,
                    inlineCallables: bool | None = False,
                    unpackArrays: bool | None = False,
                    allImports: UniversalImportTracker | None = None) -> tuple[ast.FunctionDef, UniversalImportTracker]:
    """Extract `callableTarget` from `astModule` and prepare it as a numba-decorated function.

    Steps:
    1. Record every top-level `import` / `from ... import` of `astModule` in `allImports`
       (a fresh `UniversalImportTracker` is created when none is given).
    2. Obtain the function. With `inlineCallables`, use `RecursiveInliner` over all
       top-level functions. Otherwise, take the first top-level `FunctionDef` with that name.
    3. Pass it through `decorateCallableWithNumba` with `parametersNumba`.
    4. With `unpackArrays`, apply `UnpackArrays` for `(indexMy, 'my')`, then
       `(indexTrack, 'track')`.

    Returns:
        `(FunctionDefTarget, allImports)`.

    Raises:
        ValueError: If no function definition is obtained for `callableTarget`.
    """
    if allImports is None:
        allImports = UniversalImportTracker()
    statementsImport = [statement for statement in astModule.body if isinstance(statement, (ast.Import, ast.ImportFrom))]
    for statementImport in statementsImport:
        allImports.addAst(statementImport)

    if not inlineCallables:
        FunctionDefTarget = next((node for node in astModule.body if isinstance(node, ast.FunctionDef) and node.name == callableTarget), None)
    else:
        dictionaryFunctionDef = {statement.name: statement for statement in astModule.body if isinstance(statement, ast.FunctionDef)}
        # NOTE the inliner assumes each function is not called more than once
        # TODO change the inliner to handle multiple calls to the same function
        FunctionDefTarget = RecursiveInliner(dictionaryFunctionDef).inlineFunctionBody(callableTarget)
    if not FunctionDefTarget:
        raise ValueError(f"Could not find function {callableTarget} in source code")

    ast.fix_missing_locations(FunctionDefTarget)

    FunctionDefTarget, allImports = decorateCallableWithNumba(FunctionDefTarget, allImports, parametersNumba)

    # NOTE vestigial hardcoding
    if unpackArrays:
        for indexArray, nameArray in ((indexMy, 'my'), (indexTrack, 'track')):
            FunctionDefTarget = cast(ast.FunctionDef, UnpackArrays(indexArray, nameArray).visit(FunctionDefTarget))
            ast.fix_missing_locations(FunctionDefTarget)

    return FunctionDefTarget, allImports
|
|
273
|
-
|
|
274
|
-
def decorateCallableWithNumba(FunctionDefTarget: ast.FunctionDef, allImports: UniversalImportTracker, parametersNumba: ParametersNumba | None = None) -> tuple[ast.FunctionDef, UniversalImportTracker]:
    """Replace the decorators of `FunctionDefTarget` with one `<Z0Z_getDecoratorCallable()>(...)` call.

    Positional decorator argument (the signature):
    - Each parameter whose annotation is `X[shape, dtype]` contributes `datatype[slices]`
      (see `make_arg4parameter`).
    - If the return annotation is a plain name, the argument is `ReturnName(arg1, ...)`.
    - Otherwise, if any parameter contributed, it is the tuple `(arg1, ...)`.
    - Otherwise there is no positional argument.

    Keyword arguments come from `parametersNumba`. When that is None, they are copied (via
    `Then.copy_astCallKeywords`) from the first decorator recognized by
    `thisIsAnyNumbaJitDecorator` whose copy is non-empty, falling back to
    `parametersNumbaDEFAULT`.

    Recognized jit decorators are removed silently; any other decorator is removed with a
    warning.

    Returns:
        `(FunctionDefTarget, allImports)`, with every referenced name recorded in `allImports`.
    """
    def removeUnhandledDecorators(astCallable: ast.FunctionDef) -> ast.FunctionDef:
        # TODO: more explicit handling of decorators. I'm able to ignore this because I know `algorithmSource` doesn't have any decorators.
        import warnings
        for decoratorItem in list(astCallable.decorator_list):
            astCallable.decorator_list.remove(decoratorItem)
            warnings.warn(f"Removed decorator {ast.unparse(decoratorItem)} from {astCallable.name}")
        return astCallable

    def make_arg4parameter(signatureElement: ast.arg) -> ast.Subscript | None:
        annotation = signatureElement.annotation
        if not (isinstance(annotation, ast.Subscript) and isinstance(annotation.slice, ast.Tuple)):
            return None

        annotationShape = annotation.slice.elts[0]
        if isinstance(annotationShape, ast.Subscript) and isinstance(annotationShape.slice, ast.Tuple):
            # One `:` per axis, with `::1` on the last axis (presumably numba's notation for
            # a C-contiguous layout).
            slicesAxes = [ast.Slice() for _axis in annotationShape.slice.elts]
            slicesAxes[-1] = ast.Slice(step=ast.Constant(value=1))
            shapeAST = ast.Tuple(elts=slicesAxes, ctx=ast.Load())
        else:
            shapeAST = ast.Slice(step=ast.Constant(value=1))

        annotationDtype = annotation.slice.elts[1]
        datatypeFromAnnotation = None
        if isinstance(annotationDtype, ast.Subscript) and isinstance(annotationDtype.slice, ast.Attribute):
            datatypeFromAnnotation = annotationDtype.slice.attr
        # The fallback is always computed, matching the original evaluation order.
        datatypeFallback = hackSSOTdatatype(signatureElement.arg)
        datatypeName = datatypeFromAnnotation or datatypeFallback

        allImports.addImportFromStr(datatypeModuleDecorator, datatypeName)
        return ast.Subscript(value=ast.Name(id=datatypeName, ctx=ast.Load()), slice=shapeAST, ctx=ast.Load())

    datatypeModuleDecorator = Z0Z_getDatatypeModuleScalar()

    list_arg4signature_or_function: list[ast.expr] = [
        elementSignature
        for elementSignature in map(make_arg4parameter, FunctionDefTarget.args.args)
        if elementSignature is not None
    ]

    list_argsDecorator: list[ast.expr] = []
    annotationReturn = FunctionDefTarget.returns
    if isinstance(annotationReturn, ast.Name):
        list_argsDecorator = [ast.Call(func=ast.Name(id=annotationReturn.id, ctx=ast.Load()), args=list_arg4signature_or_function, keywords=[])]
    elif list_arg4signature_or_function:
        list_argsDecorator = [ast.Tuple(elts=list_arg4signature_or_function, ctx=ast.Load())]

    for decorator in list(FunctionDefTarget.decorator_list):
        if not thisIsAnyNumbaJitDecorator(decorator):
            continue
        decorator = cast(ast.Call, decorator)
        if parametersNumba is None:
            keywordsCopied = Then.copy_astCallKeywords(decorator)
            if keywordsCopied:
                parametersNumba = cast(ParametersNumba, keywordsCopied)
        FunctionDefTarget.decorator_list.remove(decorator)

    FunctionDefTarget = removeUnhandledDecorators(FunctionDefTarget)
    if parametersNumba is None:
        parametersNumba = parametersNumbaDEFAULT
    listDecoratorKeywords = [ast.keyword(arg=keywordName, value=ast.Constant(value=keywordValue)) for keywordName, keywordValue in parametersNumba.items()]

    decoratorModule = Z0Z_getDatatypeModuleScalar()
    decoratorCallable = Z0Z_getDecoratorCallable()
    allImports.addImportFromStr(decoratorModule, decoratorCallable)
    FunctionDefTarget.decorator_list = [Then.make_astCall(decoratorCallable, list_argsDecorator, listDecoratorKeywords, None)]
    return FunctionDefTarget, allImports
|