mapFolding 0.3.11__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mapFolding/__init__.py +44 -32
- mapFolding/basecamp.py +50 -50
- mapFolding/beDRY.py +336 -336
- mapFolding/oeis.py +262 -262
- mapFolding/reference/flattened.py +294 -293
- mapFolding/reference/hunterNumba.py +126 -126
- mapFolding/reference/irvineJavaPort.py +99 -99
- mapFolding/reference/jax.py +153 -153
- mapFolding/reference/lunnan.py +148 -148
- mapFolding/reference/lunnanNumpy.py +115 -115
- mapFolding/reference/lunnanWhile.py +114 -114
- mapFolding/reference/rotatedEntryPoint.py +183 -183
- mapFolding/reference/total_countPlus1vsPlusN.py +203 -203
- mapFolding/someAssemblyRequired/__init__.py +2 -1
- mapFolding/someAssemblyRequired/getLLVMforNoReason.py +12 -12
- mapFolding/someAssemblyRequired/makeJob.py +48 -48
- mapFolding/someAssemblyRequired/synthesizeModuleJAX.py +17 -17
- mapFolding/someAssemblyRequired/synthesizeNumba.py +345 -803
- mapFolding/someAssemblyRequired/synthesizeNumbaGeneralized.py +371 -0
- mapFolding/someAssemblyRequired/synthesizeNumbaJob.py +150 -0
- mapFolding/someAssemblyRequired/synthesizeNumbaModules.py +75 -0
- mapFolding/syntheticModules/__init__.py +0 -0
- mapFolding/syntheticModules/numba_countInitialize.py +2 -2
- mapFolding/syntheticModules/numba_countParallel.py +3 -3
- mapFolding/syntheticModules/numba_countSequential.py +28 -28
- mapFolding/syntheticModules/numba_doTheNeedful.py +6 -6
- mapFolding/theDao.py +168 -169
- mapFolding/theSSOT.py +190 -162
- mapFolding/theSSOTnumba.py +91 -75
- mapFolding-0.4.0.dist-info/METADATA +122 -0
- mapFolding-0.4.0.dist-info/RECORD +41 -0
- tests/conftest.py +238 -128
- tests/test_oeis.py +80 -80
- tests/test_other.py +137 -224
- tests/test_tasks.py +21 -21
- tests/test_types.py +2 -2
- mapFolding-0.3.11.dist-info/METADATA +0 -155
- mapFolding-0.3.11.dist-info/RECORD +0 -39
- tests/conftest_tmpRegistry.py +0 -62
- tests/conftest_uniformTests.py +0 -53
- {mapFolding-0.3.11.dist-info → mapFolding-0.4.0.dist-info}/LICENSE +0 -0
- {mapFolding-0.3.11.dist-info → mapFolding-0.4.0.dist-info}/WHEEL +0 -0
- {mapFolding-0.3.11.dist-info → mapFolding-0.4.0.dist-info}/entry_points.txt +0 -0
- {mapFolding-0.3.11.dist-info → mapFolding-0.4.0.dist-info}/top_level.txt +0 -0
mapFolding/reference/jax.py
CHANGED
|
@@ -9,200 +9,200 @@ dtypeMedium = jax.numpy.uint32
|
|
|
9
9
|
dtypeMaximum = jax.numpy.uint32
|
|
10
10
|
|
|
11
11
|
def countFolds(listDimensions: List[int]) -> int:
|
|
12
|
-
|
|
12
|
+
listDimensionsPositive: List[int] = validateListDimensions(listDimensions)
|
|
13
13
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
14
|
+
n: int = getLeavesTotal(listDimensionsPositive)
|
|
15
|
+
d: int = len(listDimensions)
|
|
16
|
+
import numpy
|
|
17
|
+
D: numpy.ndarray = makeConnectionGraph(listDimensionsPositive)
|
|
18
|
+
connectionGraph = jax.numpy.asarray(D, dtype=dtypeMedium)
|
|
19
|
+
del listDimensionsPositive
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
return foldingsJAX(n, d, connectionGraph)
|
|
22
22
|
|
|
23
23
|
def foldingsJAX(leavesTotal: jaxtyping.UInt32, dimensionsTotal: jaxtyping.UInt32, connectionGraph: jaxtyping.Array) -> jaxtyping.UInt32:
|
|
24
24
|
|
|
25
|
-
|
|
26
|
-
|
|
25
|
+
def doNothing(argument):
|
|
26
|
+
return argument
|
|
27
27
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
28
|
+
def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
|
|
29
|
+
comparand = comparisonValues[6]
|
|
30
|
+
return comparand > 0
|
|
31
31
|
|
|
32
|
-
|
|
33
|
-
|
|
32
|
+
def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
33
|
+
_0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues
|
|
34
34
|
|
|
35
|
-
|
|
35
|
+
sentinel = leafBelow.at[0].get().astype(jax.numpy.uint32)
|
|
36
36
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
|
|
38
|
+
lambda argumentX: dao(findGapsDo(argumentX)),
|
|
39
|
+
lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
|
|
40
|
+
allValues)
|
|
41
41
|
|
|
42
|
-
|
|
42
|
+
return allValues
|
|
43
43
|
|
|
44
|
-
|
|
45
|
-
|
|
44
|
+
def findGapsCondition(leafBelowSentinel, activeLeafNumber):
|
|
45
|
+
return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)
|
|
46
46
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
47
|
+
def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
48
|
+
def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
|
|
49
|
+
return comparisonValues[-1] <= dimensionsTotal
|
|
50
50
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
51
|
+
def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
52
|
+
def ifLeafIsUnconstrainedCondition(comparand):
|
|
53
|
+
return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)
|
|
54
54
|
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
55
|
+
def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
56
|
+
unconstrained_unconstrainedLeaf = unconstrainedValues[3]
|
|
57
|
+
unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
|
|
58
|
+
return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)
|
|
59
59
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
60
|
+
def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
61
|
+
def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
62
|
+
return comparisonValues[-1] != activeLeaf1ndex
|
|
63
63
|
|
|
64
|
-
|
|
65
|
-
|
|
64
|
+
def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
65
|
+
countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee = countGapsDoValues
|
|
66
66
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
67
|
+
countGapsPotentialGaps = countGapsPotentialGaps.at[countGapsGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
|
|
68
|
+
countGapsGap1ndexLowerBound = jax.numpy.where(jax.numpy.equal(countGapsCountDimensionsGapped[countGapsLeaf1ndexConnectee], 0), countGapsGap1ndexLowerBound + 1, countGapsGap1ndexLowerBound)
|
|
69
|
+
countGapsCountDimensionsGapped = countGapsCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
|
|
70
|
+
countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.uint32)
|
|
71
71
|
|
|
72
|
-
|
|
72
|
+
return (countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee)
|
|
73
73
|
|
|
74
|
-
|
|
74
|
+
unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues
|
|
75
75
|
|
|
76
|
-
|
|
76
|
+
leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.uint32)
|
|
77
77
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
78
|
+
countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
|
|
79
|
+
countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
|
|
80
|
+
unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues
|
|
81
81
|
|
|
82
|
-
|
|
82
|
+
return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)
|
|
83
83
|
|
|
84
|
-
|
|
84
|
+
dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
|
|
85
85
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
86
|
+
ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
|
|
87
|
+
ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
|
|
88
|
+
dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues
|
|
89
89
|
|
|
90
|
-
|
|
91
|
-
|
|
90
|
+
dimensionNumber = 1 + dimensionNumber
|
|
91
|
+
return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)
|
|
92
92
|
|
|
93
|
-
|
|
94
|
-
|
|
93
|
+
def almostUselessCondition(comparand):
|
|
94
|
+
return comparand == dimensionsTotal
|
|
95
95
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
96
|
+
def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
97
|
+
def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
|
|
98
|
+
return comparisonValues[-1] < activeLeaf1ndex
|
|
99
99
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
100
|
+
def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
101
|
+
leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
|
|
102
|
+
leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
|
|
103
|
+
gapNumberLowerBound = 1 + gapNumberLowerBound
|
|
104
|
+
leafNumber = 1 + leafNumber
|
|
105
|
+
return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
|
|
106
|
+
return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)
|
|
107
107
|
|
|
108
|
-
|
|
109
|
-
|
|
108
|
+
def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
109
|
+
return comparisonValues[-1] < gap1ndexCeiling
|
|
110
110
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
111
|
+
def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
112
|
+
gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
|
|
113
|
+
gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
|
|
114
|
+
activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.uint32)
|
|
115
|
+
gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
|
|
116
|
+
index = 1 + index
|
|
117
|
+
return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)
|
|
118
118
|
|
|
119
|
-
|
|
119
|
+
_0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
|
|
120
120
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
121
|
+
unconstrainedLeaf = jax.numpy.uint32(0)
|
|
122
|
+
dimension1ndex = jax.numpy.uint32(1)
|
|
123
|
+
gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.uint32)
|
|
124
|
+
activeGap1ndex = gap1ndexCeiling
|
|
125
|
+
for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
|
|
126
|
+
for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
|
|
127
|
+
countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
|
|
128
|
+
del dimension1ndex
|
|
129
129
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
130
|
+
leaf1ndex = jax.numpy.uint32(0)
|
|
131
|
+
for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
|
|
132
|
+
for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
|
|
133
|
+
gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
|
|
134
|
+
del leaf1ndex
|
|
135
135
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
136
|
+
indexMiniGap = activeGap1ndex
|
|
137
|
+
miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
|
|
138
|
+
miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
|
|
139
|
+
countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
|
|
140
|
+
del indexMiniGap
|
|
141
141
|
|
|
142
|
-
|
|
142
|
+
return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
|
|
143
143
|
|
|
144
|
-
|
|
145
|
-
|
|
144
|
+
def incrementCondition(leafBelowSentinel, activeLeafNumber):
|
|
145
|
+
return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)
|
|
146
146
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
147
|
+
def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
148
|
+
foldingsSubTotal = allValues[5]
|
|
149
|
+
foldingsSubTotal = leavesTotal + foldingsSubTotal
|
|
150
|
+
return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])
|
|
151
|
+
|
|
152
|
+
def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
153
|
+
def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
|
|
154
|
+
comparand = backtrackingValues[2]
|
|
155
|
+
return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))
|
|
156
156
|
|
|
157
|
-
|
|
158
|
-
|
|
157
|
+
def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
|
|
158
|
+
backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues
|
|
159
159
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
160
|
+
activeLeafNumber = activeLeafNumber - 1
|
|
161
|
+
backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
|
|
162
|
+
backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())
|
|
163
163
|
|
|
164
|
-
|
|
164
|
+
return (backtrackAbove, backtrackBelow, activeLeafNumber)
|
|
165
165
|
|
|
166
|
-
|
|
167
|
-
|
|
166
|
+
def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
|
|
167
|
+
return activeLeafNumber > 0
|
|
168
168
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
169
|
+
def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
170
|
+
placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
|
|
171
|
+
activeGapNumber = activeGapNumber - 1
|
|
172
|
+
placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
|
|
173
|
+
placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
|
|
174
|
+
placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
|
|
175
|
+
placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
|
|
176
|
+
placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)
|
|
177
|
+
|
|
178
|
+
activeLeafNumber = 1 + activeLeafNumber
|
|
179
|
+
return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)
|
|
180
|
+
|
|
181
|
+
leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
|
|
182
|
+
|
|
183
|
+
whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
|
|
184
|
+
whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
|
|
185
|
+
leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues
|
|
186
|
+
|
|
187
|
+
if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
|
|
188
|
+
if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
|
|
189
|
+
leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values
|
|
190
|
+
|
|
191
|
+
return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
|
|
192
|
+
|
|
193
|
+
# Dynamic values
|
|
194
|
+
A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
|
|
195
|
+
B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
|
|
196
|
+
count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
|
|
197
|
+
gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
|
|
198
|
+
gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)
|
|
199
|
+
|
|
200
|
+
foldingsTotal = jax.numpy.uint32(0)
|
|
201
|
+
l = jax.numpy.uint32(1)
|
|
202
|
+
g = jax.numpy.uint32(0)
|
|
203
|
+
|
|
204
|
+
foldingsValues = (A, B, count, gapter, gap, foldingsTotal, l, g)
|
|
205
|
+
foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
|
|
206
|
+
return foldingsValues[5]
|
|
207
207
|
|
|
208
208
|
foldingsJAX = jax.jit(foldingsJAX, static_argnums=(0, 1))
|