mapFolding 0.3.6__py3-none-any.whl → 0.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,17 @@
1
- from numpy import integer
2
1
  from typing import Any, Tuple
3
- import numba
4
- from mapFolding import indexMy, indexTrack
2
+ from numpy import integer
5
3
  import numpy
4
+ from mapFolding import indexMy, indexTrack
5
+ import numba
6
6
 
7
7
  @numba.jit((numba.uint8[:, :, ::1], numba.int64[::1], numba.uint8[::1], numba.uint8[::1], numba.uint8[:, ::1]), _nrt=True, boundscheck=False, cache=True, error_model='numpy', fastmath=True, forceinline=False, inline='never', looplift=False, no_cfunc_wrapper=True, no_cpython_wrapper=True, nopython=True, parallel=True)
8
- def countParallel(connectionGraph: numpy.ndarray[Tuple[int, int, int], numpy.dtype[integer[Any]]], foldGroups: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], gapsWherePARALLEL: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], myPARALLEL: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], trackPARALLEL: numpy.ndarray[Tuple[int, int], numpy.dtype[integer[Any]]]):
9
- for indexSherpa in numba.prange(myPARALLEL[indexMy.taskDivisions.value]):
10
- groupsOfFolds = numba.types.int64(0)
8
+ def countParallel(connectionGraph: numpy.ndarray[Tuple[int, int, int], numpy.dtype[integer[Any]]], foldGroups: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], gapsWhere: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], my: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], track: numpy.ndarray[Tuple[int, int], numpy.dtype[integer[Any]]]):
9
+ gapsWherePARALLEL = gapsWhere.copy()
10
+ myPARALLEL = my.copy()
11
+ trackPARALLEL = track.copy()
12
+ taskDivisionsPrange = myPARALLEL[indexMy.taskDivisions.value]
13
+ for indexSherpa in numba.prange(taskDivisionsPrange):
14
+ groupsOfFolds: int = 0
11
15
  gapsWhere = gapsWherePARALLEL.copy()
12
16
  my = myPARALLEL.copy()
13
17
  my[indexMy.taskIndex.value] = indexSherpa
@@ -1,8 +1,8 @@
1
- from numpy import integer
2
- import numba
3
1
  from typing import Any, Tuple
4
- from mapFolding import indexMy, indexTrack
2
+ from numpy import integer
5
3
  import numpy
4
+ from mapFolding import indexMy, indexTrack
5
+ import numba
6
6
 
7
7
  @numba.jit((numba.uint8[:, :, ::1], numba.int64[::1], numba.uint8[::1], numba.uint8[::1], numba.uint8[:, ::1]), _nrt=True, boundscheck=False, cache=True, error_model='numpy', fastmath=True, forceinline=False, inline='never', looplift=False, no_cfunc_wrapper=True, no_cpython_wrapper=True, nopython=True, parallel=False)
8
8
  def countSequential(connectionGraph: numpy.ndarray[Tuple[int, int, int], numpy.dtype[integer[Any]]], foldGroups: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], gapsWhere: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], my: numpy.ndarray[Tuple[int], numpy.dtype[integer[Any]]], track: numpy.ndarray[Tuple[int, int], numpy.dtype[integer[Any]]]):
@@ -19,7 +19,7 @@ def countSequential(connectionGraph: numpy.ndarray[Tuple[int, int, int], numpy.d
19
19
  indexMiniGap = my[indexMy.indexMiniGap.value]
20
20
  gap1ndex = my[indexMy.gap1ndex.value]
21
21
  taskIndex = my[indexMy.taskIndex.value]
22
- groupsOfFolds = numba.types.int64(0)
22
+ groupsOfFolds: int = 0
23
23
  doFindGaps = True
24
24
  while leaf1ndex:
25
25
  if (doFindGaps := (leaf1ndex <= 1 or leafBelow[0] == 1)) and leaf1ndex > foldGroups[-1]:
@@ -1,125 +0,0 @@
1
- from cffconvert.cli.create_citation import create_citation
2
- from cffconvert.cli.validate_or_write_output import validate_or_write_output
3
- from typing import Any, Dict
4
- import cffconvert
5
- import pathlib
6
- import packaging.metadata
7
- import tomli
8
- import ruamel.yaml
9
- import packaging
10
- from packaging.metadata import Metadata as PyPAMetadata
11
- import packaging.utils
12
- import packaging.version
13
-
14
- def addPypaMetadata(citation: cffconvert.Citation, metadata: PyPAMetadata) -> cffconvert.Citation:
15
- """
16
- Map the PyPA metadata to the citation's internal representation.
17
-
18
- Mapping:
19
- - title: metadata.name
20
- - version: metadata.version (converted to string)
21
- - keywords: metadata.keywords
22
- - license: metadata.license_expression
23
- - url: from project URLs (homepage)
24
- - repository: from project URLs (repository)
25
- """
26
- # Access the internal dictionary (used for conversion)
27
- citationData: Dict[str, Any] = citation._cffobj
28
-
29
- # Update title from PyPA metadata name
30
- if metadata.name:
31
- citationData["title"] = metadata.name
32
-
33
- # Update version from PyPA metadata version
34
- if metadata.version:
35
- citationData["version"] = str(metadata.version)
36
-
37
- # Update keywords from PyPA metadata keywords
38
- if metadata.keywords:
39
- citationData["keywords"] = metadata.keywords
40
-
41
- # Update license from PyPA metadata license_expression
42
- if metadata.license_expression:
43
- citationData["license"] = metadata.license_expression
44
-
45
- # Retrieve the project URLs that were attached in getPypaMetadata
46
- projectURLs: Dict[str, str] = getattr(metadata, "_project_urls", {})
47
-
48
- # Update the homepage URL
49
- if "homepage" in projectURLs:
50
- citationData["url"] = projectURLs["homepage"]
51
-
52
- # Update the repository URL
53
- if "repository" in projectURLs:
54
- citationData["repository"] = projectURLs["repository"]
55
-
56
- return citation
57
-
58
- def getPypaMetadata(packageData: Dict[str, Any]) -> PyPAMetadata:
59
- """
60
- Create a PyPA metadata object (version 2.4) from packageData.
61
-
62
- Mapping for project URLs:
63
- - 'homepage' and 'repository' are accepted from packageData['urls'].
64
- """
65
- dictionaryProjectURLs: Dict[str, str] = {}
66
- for urlKey, urlValue in packageData.get("urls", {}).items():
67
- lowerUrlKey = urlKey.lower()
68
- if lowerUrlKey in ("homepage", "repository"):
69
- dictionaryProjectURLs[lowerUrlKey] = urlValue
70
-
71
- metadataRaw = packaging.metadata.RawMetadata(
72
- keywords=packageData.get("keywords", []),
73
- license_expression=packageData.get("license", {}).get("text", ""),
74
- metadata_version="2.4",
75
- name=packaging.utils.canonicalize_name(packageData.get("name", None), validate=True),
76
- project_urls=dictionaryProjectURLs,
77
- version=packageData.get("version", None),
78
- )
79
-
80
- metadata = PyPAMetadata().from_raw(metadataRaw)
81
- # Attach the project URLs dictionary so it can be used later.
82
- setattr(metadata, "_project_urls", dictionaryProjectURLs)
83
- return metadata
84
-
85
- def logistics():
86
- # Determine paths from your SSOT.
87
- packageName: str = "mapFolding"
88
- pathRepoRoot = pathlib.Path(__file__).parent.parent.parent
89
- pathFilenamePackageSSOT = pathRepoRoot / "pyproject.toml"
90
- filenameGitHubAction = "updateCitation.yml"
91
- pathFilenameGitHubAction = pathRepoRoot / ".github" / "workflows" / filenameGitHubAction
92
-
93
- filenameCitationDOTcff = "CITATION.cff"
94
- pathCitations = pathRepoRoot / packageName / "citations"
95
- pathFilenameCitationSSOT = pathCitations / filenameCitationDOTcff
96
- pathFilenameCitationDOTcffRepo = pathRepoRoot / filenameCitationDOTcff
97
-
98
- # Create a citation object from the SSOT citation file.
99
- citationObject: cffconvert.Citation = create_citation(infile=pathFilenameCitationSSOT, url=None)
100
- # Print the current citation in CFF format (for debugging) using the as_cff method.
101
- print(citationObject.as_cff())
102
-
103
- # Load package metadata from pyproject.toml.
104
- tomlPackageData: Dict[str, Any] = tomli.loads(pathFilenamePackageSSOT.read_text())["project"]
105
- pypaMetadata: PyPAMetadata = getPypaMetadata(tomlPackageData)
106
-
107
- # Map the PyPA metadata into the citation's internal representation.
108
- citationObject = addPypaMetadata(citation=citationObject, metadata=pypaMetadata)
109
-
110
- # Validate and write out the updated citation file in both locations.
111
- # validate_or_write_output(
112
- # outfile=pathFilenameCitationSSOT,
113
- # outputformat="cff",
114
- # validate_only=False,
115
- # citation=citationObject,
116
- # )
117
- validate_or_write_output(
118
- outfile=pathFilenameCitationDOTcffRepo,
119
- outputformat="cff",
120
- validate_only=False,
121
- citation=citationObject,
122
- )
123
-
124
- if __name__ == "__main__":
125
- logistics()
@@ -1,27 +0,0 @@
1
- benchmarks/benchmarking.py,sha256=HD_0NSvuabblg94ftDre6LFnXShTe8MYj3hIodW-zV0,3076
2
- citations/updateCitation.py,sha256=PPxOERlYnw9b9xydiZL8utTU-sC2B4rBfOgjXc1S0OY,5264
3
- citations/updateCitationgpt.py,sha256=NtgSP4BCO5YcaYcYYb31vOxXcp3hmooug8VYpbhTc_w,4751
4
- reference/flattened.py,sha256=6blZ2Y9G8mu1F3gV8SKndPE398t2VVFlsgKlyeJ765A,16538
5
- reference/hunterNumba.py,sha256=HWndRgsajOf76rbb2LDNEZ6itsdYbyV-k3wgOFjeR6c,7104
6
- reference/irvineJavaPort.py,sha256=Sj-63Z-OsGuDoEBXuxyjRrNmmyl0d7Yz_XuY7I47Oyg,4250
7
- reference/jax.py,sha256=rojyK80lOATtbzxjGOHWHZngQa47CXCLJHZwIdN2MwI,14955
8
- reference/lunnan.py,sha256=XEcql_gxvCCghb6Or3qwmPbn4IZUbZTaSmw_fUjRxZE,5037
9
- reference/lunnanNumpy.py,sha256=HqDgSwTOZA-G0oophOEfc4zs25Mv4yw2aoF1v8miOLk,4653
10
- reference/lunnanWhile.py,sha256=7NY2IKO5XBgol0aWWF_Fi-7oTL9pvu_z6lB0TF1uVHk,4063
11
- reference/rotatedEntryPoint.py,sha256=z0QyDQtnMvXNj5ntWzzJUQUMFm1-xHGLVhtYzwmczUI,11530
12
- reference/total_countPlus1vsPlusN.py,sha256=usenM8Yn_G1dqlPl7NKKkcnbohBZVZBXTQRm2S3_EDA,8106
13
- someAssemblyRequired/__init__.py,sha256=3JnAKXfaYPtmxV_4AnZ6KpCosT_0GFV5Nw7K8sz4-Uo,34
14
- someAssemblyRequired/generalizeSourceCode.py,sha256=qyJD0ZdG0t-SYTItL_JjaIXm3-joWt3e-2nMSAH4Dbg,6392
15
- someAssemblyRequired/getLLVMforNoReason.py,sha256=FtJzw2pZS3A4NimWdZsegXaU-vKeCw8m67kcfb5wvGM,894
16
- someAssemblyRequired/makeJob.py,sha256=RTC80FhDrR19GqHtEeo6GpmlWZQESuf8FXqBqVzdOpk,1465
17
- someAssemblyRequired/synthesizeModuleJob.py,sha256=uyODwdI1_a76Pu21JsCNEapVW78yUD-CfInX-vg8U-w,7419
18
- someAssemblyRequired/synthesizeModules.py,sha256=foxk-mG-HGVap2USiA3ppCyWWXUmkLFzQiKacp5DD9M,11569
19
- syntheticModules/Initialize.py,sha256=KIAxLSyblzDTL8QJYINmdRjk2iRVYzXOWeqY8P6wPgw,4024
20
- syntheticModules/Parallel.py,sha256=Kq1uo5kfeeczk871yxaagsaNz8zaM8GWy0S3hZAEQz4,5343
21
- syntheticModules/Sequential.py,sha256=JwpHNFt_w77J0RBVoBji-OLnYNSTMnusRlYU-6b4P2w,3643
22
- syntheticModules/__init__.py,sha256=lUDBXOiislfP2sIxT13_GZgElaytoYqk0ODUsucMYew,117
23
- mapFolding-0.3.6.dist-info/METADATA,sha256=ViqnejEpnb4391VOp-nqnGXqyAwaXiqwUx3wpZbqyxM,7688
24
- mapFolding-0.3.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
25
- mapFolding-0.3.6.dist-info/entry_points.txt,sha256=F3OUeZR1XDTpoH7k3wXuRb3KF_kXTTeYhu5AGK1SiOQ,146
26
- mapFolding-0.3.6.dist-info/top_level.txt,sha256=yVG9dNZywoaddcsUdEDg7o0XOBzJd_4Z-sDaXGHpiMY,69
27
- mapFolding-0.3.6.dist-info/RECORD,,
@@ -1,122 +0,0 @@
1
- from mapFolding import datatypeLargeDEFAULT, datatypeMediumDEFAULT, datatypeSmallDEFAULT
2
- from typing import Dict, Optional, List, Set, Union
3
- import ast
4
-
5
- class RecursiveInlinerWithEnum(ast.NodeTransformer):
6
- """Process AST nodes to inline functions and substitute enum values.
7
- Also handles function decorators during inlining."""
8
-
9
- def __init__(self, dictionaryFunctions: Dict[str, ast.FunctionDef], dictionaryEnumValues: Dict[str, int]) -> None:
10
- self.dictionaryFunctions = dictionaryFunctions
11
- self.dictionaryEnumValues = dictionaryEnumValues
12
- self.processed = set()
13
-
14
- def inlineFunctionBody(self, functionName: str) -> Optional[ast.FunctionDef]:
15
- if functionName in self.processed:
16
- return None
17
-
18
- self.processed.add(functionName)
19
- inlineDefinition = self.dictionaryFunctions[functionName]
20
- # Recursively process the function body
21
- for node in ast.walk(inlineDefinition):
22
- self.visit(node)
23
- return inlineDefinition
24
-
25
- def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
26
- # Substitute enum identifiers (e.g., indexMy.leaf1ndex.value)
27
- if isinstance(node.value, ast.Attribute) and isinstance(node.value.value, ast.Name):
28
- enumPath = f"{node.value.value.id}.{node.value.attr}.{node.attr}"
29
- if enumPath in self.dictionaryEnumValues:
30
- return ast.Constant(value=self.dictionaryEnumValues[enumPath])
31
- return self.generic_visit(node)
32
-
33
- def visit_Call(self, node: ast.Call) -> ast.AST:
34
- callNode = self.generic_visit(node)
35
- if isinstance(callNode, ast.Call) and isinstance(callNode.func, ast.Name) and callNode.func.id in self.dictionaryFunctions:
36
- inlineDefinition = self.inlineFunctionBody(callNode.func.id)
37
- if (inlineDefinition and inlineDefinition.body):
38
- lastStmt = inlineDefinition.body[-1]
39
- if isinstance(lastStmt, ast.Return) and lastStmt.value is not None:
40
- return self.visit(lastStmt.value)
41
- elif isinstance(lastStmt, ast.Expr) and lastStmt.value is not None:
42
- return self.visit(lastStmt.value)
43
- return ast.Constant(value=None)
44
- return callNode
45
-
46
- def visit_Expr(self, node: ast.Expr) -> Union[ast.AST, List[ast.AST]]:
47
- if isinstance(node.value, ast.Call):
48
- if isinstance(node.value.func, ast.Name) and node.value.func.id in self.dictionaryFunctions:
49
- inlineDefinition = self.inlineFunctionBody(node.value.func.id)
50
- if inlineDefinition:
51
- return [self.visit(stmt) for stmt in inlineDefinition.body]
52
- return self.generic_visit(node)
53
-
54
- def findRequiredImports(node: ast.AST) -> Set[str]:
55
- """Find all modules that need to be imported based on AST analysis.
56
- NOTE: due to hardcoding, this is a glorified regex. No, wait, this is less versatile than regex."""
57
- requiredImports = set()
58
-
59
- class ImportFinder(ast.NodeVisitor):
60
- def visit_Name(self, node: ast.Name) -> None:
61
- if node.id in {'numba'}:
62
- requiredImports.add(node.id)
63
- self.generic_visit(node)
64
-
65
- def visitDecorator(self, node: ast.AST) -> None:
66
- if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
67
- if node.func.id == 'jit':
68
- requiredImports.add('numba')
69
- self.generic_visit(node)
70
-
71
- ImportFinder().visit(node)
72
- return requiredImports
73
-
74
- def generateImports(requiredImports: Set[str]) -> str:
75
- """Generate import statements based on required modules."""
76
- importStatements = {'import numba', 'from mapFolding import indexMy, indexTrack'}
77
-
78
- importMapping = {
79
- 'numba': 'import numba',
80
- }
81
-
82
- for moduleName in sorted(requiredImports):
83
- if moduleName in importMapping:
84
- importStatements.add(importMapping[moduleName])
85
-
86
- return '\n'.join(importStatements)
87
-
88
- def makeInlineFunction(sourceCode: str, targetFunctionName: str, dictionaryEnumValues: Dict[str, int], skipEnum: bool=False, **keywordArguments: Optional[str]):
89
- datatypeLarge = keywordArguments.get('datatypeLarge', datatypeLargeDEFAULT)
90
- datatypeMedium = keywordArguments.get('datatypeMedium', datatypeMediumDEFAULT)
91
- datatypeSmall = keywordArguments.get('datatypeSmall', datatypeSmallDEFAULT)
92
- if skipEnum:
93
- dictionaryEnumValues = {}
94
- dictionaryParsed = ast.parse(sourceCode)
95
- dictionaryFunctions = {
96
- element.name: element
97
- for element in dictionaryParsed.body
98
- if isinstance(element, ast.FunctionDef)
99
- }
100
- nodeTarget = dictionaryFunctions[targetFunctionName]
101
- nodeInliner = RecursiveInlinerWithEnum(dictionaryFunctions, dictionaryEnumValues)
102
- nodeInlined = nodeInliner.visit(nodeTarget)
103
- ast.fix_missing_locations(nodeInlined)
104
- callableInlinedDecorators = [decorator for decorator in nodeInlined.decorator_list]
105
-
106
- requiredImports = findRequiredImports(nodeInlined)
107
- importStatements = generateImports(requiredImports)
108
- importsRequired = importStatements
109
- dictionaryDecoratorsNumba={
110
- 'countInitialize':
111
- f'@numba.jit((numba.{datatypeSmall}[:,:,::1], numba.{datatypeMedium}[::1], numba.{datatypeSmall}[::1], numba.{datatypeMedium}[:,::1]), parallel=False, boundscheck=False, cache=True, error_model="numpy", fastmath=True, looplift=False, nogil=True, nopython=True)\n',
112
- 'countParallel':
113
- f'@numba.jit((numba.{datatypeSmall}[:,:,::1], numba.{datatypeLarge}[::1], numba.{datatypeMedium}[::1], numba.{datatypeSmall}[::1], numba.{datatypeMedium}[:,::1]), parallel=True, boundscheck=False, cache=True, error_model="numpy", fastmath=True, looplift=False, nogil=True, nopython=True)\n',
114
- 'countSequential':
115
- f'@numba.jit((numba.{datatypeSmall}[:,:,::1], numba.{datatypeLarge}[::1], numba.{datatypeMedium}[::1], numba.{datatypeSmall}[::1], numba.{datatypeMedium}[:,::1]), parallel=False, boundscheck=False, cache=True, error_model="numpy", fastmath=True, looplift=False, nogil=True, nopython=True)\n',
116
- }
117
-
118
- lineNumbaDecorator = dictionaryDecoratorsNumba[targetFunctionName]
119
-
120
- # inlinedCode = ast.unparse(ast.Module(body=[nodeInlined], type_ignores=[]))
121
- callableInlined = lineNumbaDecorator + ast.unparse(nodeInlined)
122
- return (callableInlined, callableInlinedDecorators, importsRequired)
@@ -1,216 +0,0 @@
1
- from mapFolding import indexMy, indexTrack, getAlgorithmSource, ParametersNumba, parametersNumbaDEFAULT, hackSSOTdtype
2
- from mapFolding import datatypeLargeDEFAULT, datatypeMediumDEFAULT, datatypeSmallDEFAULT, EnumIndices
3
- import pathlib
4
- import inspect
5
- import numpy
6
- import numba
7
- from typing import Dict, Optional, List, Union, Sequence, Type, cast
8
- import ast
9
-
10
- algorithmSource = getAlgorithmSource()
11
-
12
- class RecursiveInliner(ast.NodeTransformer):
13
- def __init__(self, dictionaryFunctions: Dict[str, ast.FunctionDef]):
14
- self.dictionaryFunctions = dictionaryFunctions
15
- self.processed = set()
16
-
17
- def inlineFunctionBody(self, functionName: str) -> Optional[ast.FunctionDef]:
18
- if (functionName in self.processed):
19
- return None
20
-
21
- self.processed.add(functionName)
22
- inlineDefinition = self.dictionaryFunctions[functionName]
23
- # Recursively process the function body
24
- for node in ast.walk(inlineDefinition):
25
- self.visit(node)
26
- return inlineDefinition
27
-
28
- def visit_Call(self, node: ast.Call) -> ast.AST:
29
- callNode = self.generic_visit(node)
30
- if (isinstance(callNode, ast.Call) and isinstance(callNode.func, ast.Name) and callNode.func.id in self.dictionaryFunctions):
31
- inlineDefinition = self.inlineFunctionBody(callNode.func.id)
32
- if (inlineDefinition and inlineDefinition.body):
33
- lastStmt = inlineDefinition.body[-1]
34
- if (isinstance(lastStmt, ast.Return) and lastStmt.value is not None):
35
- return self.visit(lastStmt.value)
36
- elif (isinstance(lastStmt, ast.Expr) and lastStmt.value is not None):
37
- return self.visit(lastStmt.value)
38
- return ast.Constant(value=None)
39
- return callNode
40
-
41
- def visit_Expr(self, node: ast.Expr) -> Union[ast.AST, List[ast.AST]]:
42
- if (isinstance(node.value, ast.Call)):
43
- if (isinstance(node.value.func, ast.Name) and node.value.func.id in self.dictionaryFunctions):
44
- inlineDefinition = self.inlineFunctionBody(node.value.func.id)
45
- if (inlineDefinition):
46
- return [self.visit(stmt) for stmt in inlineDefinition.body]
47
- return self.generic_visit(node)
48
-
49
- def decorateCallableWithNumba(astCallable: ast.FunctionDef, parallel: bool=False, **keywordArguments: Optional[str]) -> ast.FunctionDef:
50
- def makeNumbaParameterSignatureElement(signatureElement: ast.arg):
51
- if isinstance(signatureElement.annotation, ast.Subscript) and isinstance(signatureElement.annotation.slice, ast.Tuple):
52
- annotationShape = signatureElement.annotation.slice.elts[0]
53
- if isinstance(annotationShape, ast.Subscript) and isinstance(annotationShape.slice, ast.Tuple):
54
- shapeAsListSlices: Sequence[ast.expr] = [ast.Slice() for axis in range(len(annotationShape.slice.elts))]
55
- shapeAsListSlices[-1] = ast.Slice(step=ast.Constant(value=1))
56
- shapeAST = ast.Tuple(elts=list(shapeAsListSlices), ctx=ast.Load())
57
- else:
58
- shapeAST = ast.Slice(step=ast.Constant(value=1))
59
-
60
- annotationDtype = signatureElement.annotation.slice.elts[1]
61
- if (isinstance(annotationDtype, ast.Subscript) and isinstance(annotationDtype.slice, ast.Attribute)):
62
- datatypeAST = annotationDtype.slice.attr
63
- else:
64
- datatypeAST = None
65
-
66
- ndarrayName = signatureElement.arg
67
- Z0Z_hackyStr = hackSSOTdtype[ndarrayName]
68
- Z0Z_hackyStr = Z0Z_hackyStr[0] + 'ata' + Z0Z_hackyStr[1:]
69
- datatype_attr = keywordArguments.get(Z0Z_hackyStr, None) or datatypeAST or eval(Z0Z_hackyStr+'DEFAULT')
70
-
71
- datatypeNumba = ast.Attribute(value=ast.Name(id='numba', ctx=ast.Load()), attr=datatype_attr, ctx=ast.Load())
72
-
73
- return ast.Subscript(value=datatypeNumba, slice=shapeAST, ctx=ast.Load())
74
-
75
- # callableSourceDecorators = [decorator for decorator in callableInlined.decorator_list]
76
-
77
- listNumbaParameterSignature: Sequence[ast.expr] = []
78
- for parameter in astCallable.args.args:
79
- signatureElement = makeNumbaParameterSignatureElement(parameter)
80
- if (signatureElement):
81
- listNumbaParameterSignature.append(signatureElement)
82
-
83
- astArgsNumbaSignature = ast.Tuple(elts=listNumbaParameterSignature, ctx=ast.Load())
84
-
85
- if astCallable.name == 'countInitialize':
86
- parametersNumba = {}
87
- else:
88
- parametersNumba = parametersNumbaDEFAULT if not parallel else ParametersNumba({**parametersNumbaDEFAULT, 'parallel': True})
89
- listKeywordsNumbaSignature = [ast.keyword(arg=parameterName, value=ast.Constant(value=parameterValue)) for parameterName, parameterValue in parametersNumba.items()]
90
-
91
- astDecoratorNumba = ast.Call(func=ast.Attribute(value=ast.Name(id='numba', ctx=ast.Load()), attr='jit', ctx=ast.Load()), args=[astArgsNumbaSignature], keywords=listKeywordsNumbaSignature)
92
-
93
- astCallable.decorator_list = [astDecoratorNumba]
94
- return astCallable
95
-
96
- class UnpackArrayAccesses(ast.NodeTransformer):
97
- """AST transformer that replaces array accesses with simpler variables."""
98
-
99
- def __init__(self, enumIndexClass: Type[EnumIndices], arrayName: str):
100
- self.enumIndexClass = enumIndexClass
101
- self.arrayName = arrayName
102
- self.substitutions = {}
103
-
104
- def extract_member_name(self, node: ast.AST) -> Optional[str]:
105
- """Recursively extract enum member name from any node in the AST."""
106
- if isinstance(node, ast.Attribute) and node.attr == 'value':
107
- innerAttribute = node.value
108
- while isinstance(innerAttribute, ast.Attribute):
109
- if (isinstance(innerAttribute.value, ast.Name) and innerAttribute.value.id == self.enumIndexClass.__name__):
110
- return innerAttribute.attr
111
- innerAttribute = innerAttribute.value
112
- return None
113
-
114
- def transform_slice_element(self, node: ast.AST) -> ast.AST:
115
- """Transform any enum references within a slice element."""
116
- if isinstance(node, ast.Subscript):
117
- if isinstance(node.slice, ast.Attribute):
118
- member_name = self.extract_member_name(node.slice)
119
- if member_name:
120
- return ast.Name(id=member_name, ctx=node.ctx)
121
- elif isinstance(node, ast.Tuple):
122
- # Handle tuple slices by transforming each element
123
- return ast.Tuple(elts=cast(List[ast.expr], [self.transform_slice_element(elt) for elt in node.elts]), ctx=node.ctx)
124
- elif isinstance(node, ast.Attribute):
125
- member_name = self.extract_member_name(node)
126
- if member_name:
127
- return ast.Name(id=member_name, ctx=ast.Load())
128
- return node
129
-
130
- def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
131
- # Recursively visit any nested subscripts in value or slice
132
- node.value = self.visit(node.value)
133
- node.slice = self.visit(node.slice)
134
- # If node.value is not our arrayName, just return node
135
- if not (isinstance(node.value, ast.Name) and node.value.id == self.arrayName):
136
- return node
137
-
138
- # Handle scalar array access
139
- if isinstance(node.slice, ast.Attribute):
140
- memberName = self.extract_member_name(node.slice)
141
- if memberName:
142
- self.substitutions[memberName] = ('scalar', node)
143
- return ast.Name(id=memberName, ctx=ast.Load())
144
-
145
- # Handle array slice access
146
- if isinstance(node.slice, ast.Tuple) and node.slice.elts:
147
- firstElement = node.slice.elts[0]
148
- memberName = self.extract_member_name(firstElement)
149
- sliceRemainder = [self.visit(elem) for elem in node.slice.elts[1:]]
150
- if memberName:
151
- self.substitutions[memberName] = ('array', node)
152
- if len(sliceRemainder) == 0:
153
- return ast.Name(id=memberName, ctx=ast.Load())
154
- return ast.Subscript(value=ast.Name(id=memberName, ctx=ast.Load()), slice=ast.Tuple(elts=sliceRemainder, ctx=ast.Load()) if len(sliceRemainder) > 1 else sliceRemainder[0], ctx=ast.Load())
155
-
156
- # If single-element tuple, unwrap
157
- if isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 1:
158
- node.slice = node.slice.elts[0]
159
-
160
- return node
161
-
162
- def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
163
- node = cast(ast.FunctionDef, self.generic_visit(node))
164
-
165
- initializations = []
166
- for name, (kind, original_node) in self.substitutions.items():
167
- if kind == 'scalar':
168
- initializations.append(ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=original_node))
169
- else: # array
170
- initializations.append(
171
- ast.Assign(
172
- targets=[ast.Name(id=name, ctx=ast.Store())],
173
- value=ast.Subscript(value=ast.Name(id=self.arrayName, ctx=ast.Load()),
174
- slice=ast.Attribute(value=ast.Attribute(
175
- value=ast.Name(id=self.enumIndexClass.__name__, ctx=ast.Load()),
176
- attr=name, ctx=ast.Load()), attr='value', ctx=ast.Load()), ctx=ast.Load())))
177
-
178
- node.body = initializations + node.body
179
- return node
180
-
181
- def inlineMapFoldingNumba(**keywordArguments: Optional[str]):
182
- codeSource = inspect.getsource(algorithmSource)
183
- pathFilenameAlgorithm = pathlib.Path(inspect.getfile(algorithmSource))
184
-
185
- listPathFilenamesDestination: list[pathlib.Path] = []
186
- listCallables = [ 'countInitialize', 'countParallel', 'countSequential', ]
187
- for callableTarget in listCallables:
188
- codeParsed: ast.Module = ast.parse(codeSource, type_comments=True)
189
- codeSourceImportStatements = {statement for statement in codeParsed.body if isinstance(statement, (ast.Import, ast.ImportFrom))}
190
- dictionaryFunctions = {statement.name: statement for statement in codeParsed.body if isinstance(statement, ast.FunctionDef)}
191
- callableInlinerWorkhorse = RecursiveInliner(dictionaryFunctions)
192
- parallel = callableTarget == 'countParallel'
193
- callableInlined = callableInlinerWorkhorse.inlineFunctionBody(callableTarget)
194
- if callableInlined:
195
- ast.fix_missing_locations(callableInlined)
196
- callableDecorated = decorateCallableWithNumba(callableInlined, parallel, **keywordArguments)
197
-
198
- if callableTarget == 'countSequential':
199
- myUnpacker = UnpackArrayAccesses(indexMy, 'my')
200
- callableDecorated = cast(ast.FunctionDef, myUnpacker.visit(callableDecorated))
201
- ast.fix_missing_locations(callableDecorated)
202
-
203
- trackUnpacker = UnpackArrayAccesses(indexTrack, 'track')
204
- callableDecorated = cast(ast.FunctionDef, trackUnpacker.visit(callableDecorated))
205
- ast.fix_missing_locations(callableDecorated)
206
-
207
- moduleAST = ast.Module(body=cast(List[ast.stmt], list(codeSourceImportStatements) + [callableDecorated]), type_ignores=[])
208
- ast.fix_missing_locations(moduleAST)
209
- moduleSource = ast.unparse(moduleAST)
210
-
211
- pathFilenameDestination = pathFilenameAlgorithm.parent / "syntheticModules" / pathFilenameAlgorithm.with_stem(callableTarget).name[5:None]
212
- pathFilenameDestination.write_text(moduleSource)
213
- listPathFilenamesDestination.append(pathFilenameDestination)
214
-
215
- if __name__ == '__main__':
216
- inlineMapFoldingNumba()