mapFolding 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mapFolding/JAX/lunnanJAX.py +206 -0
- mapFolding/JAX/taskJAX.py +313 -0
- mapFolding/__init__.py +21 -0
- mapFolding/babbage.py +12 -0
- mapFolding/beDRY.py +219 -0
- mapFolding/benchmarks/benchmarking.py +66 -0
- mapFolding/benchmarks/test_benchmarks.py +74 -0
- mapFolding/importPackages.py +5 -0
- mapFolding/lovelace.py +121 -0
- mapFolding/oeis.py +299 -0
- mapFolding/reference/hunterNumba.py +132 -0
- mapFolding/reference/irvineJavaPort.py +120 -0
- mapFolding/reference/lunnan.py +153 -0
- mapFolding/reference/lunnanNumpy.py +123 -0
- mapFolding/reference/lunnanWhile.py +121 -0
- mapFolding/reference/rotatedEntryPoint.py +240 -0
- mapFolding/startHere.py +54 -0
- mapFolding/theSSOT.py +62 -0
- mapFolding-0.2.0.dist-info/METADATA +170 -0
- mapFolding-0.2.0.dist-info/RECORD +28 -0
- mapFolding-0.2.0.dist-info/WHEEL +5 -0
- mapFolding-0.2.0.dist-info/entry_points.txt +4 -0
- mapFolding-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +1 -0
- tests/conftest.py +262 -0
- tests/test_oeis.py +195 -0
- tests/test_other.py +71 -0
- tests/test_tasks.py +18 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from mapFolding import validateListDimensions, getLeavesTotal, makeConnectionGraph
|
|
2
|
+
from typing import List, Tuple
|
|
3
|
+
import jax
|
|
4
|
+
import jaxtyping
|
|
5
|
+
|
|
6
|
+
# dtype used for the connection graph (countFolds) and for the per-leaf
# bookkeeping arrays A, B, count and gapter inside foldingsJAX.
dtypeDefault = jax.numpy.uint32
# dtype used for the `gap` array inside foldingsJAX; currently identical to
# dtypeDefault.
dtypeMaximum = jax.numpy.uint32
|
|
8
|
+
|
|
9
|
+
def countFolds(listDimensions: List[int]) -> int:
    """Count the foldings of a map with the given dimensions using the JAX kernel.

    Parameters:
        listDimensions: the length of each dimension of the map.

    Returns:
        The folding count computed by `foldingsJAX`, converted to a Python int.
    """
    listDimensionsPositive: List[int] = validateListDimensions(listDimensions)

    leavesTotal: int = getLeavesTotal(listDimensionsPositive)
    # Count dimensions from the *validated* list. `validateListDimensions` may
    # return a list that differs from the raw input, and the connection graph
    # below is built from the validated list. `foldingsJAX` indexes
    # `connectionGraph[dimension, ...]` for dimension in 1..dimensionsTotal,
    # and JAX gathers do not raise on out-of-range indices. A count taken from
    # the raw input could therefore silently produce wrong results.
    dimensionsTotal: int = len(listDimensionsPositive)

    import numpy
    arrayConnectionGraph: numpy.ndarray = makeConnectionGraph(listDimensionsPositive)
    connectionGraph = jax.numpy.asarray(arrayConnectionGraph, dtype=dtypeDefault)
    del listDimensionsPositive

    # foldingsJAX returns a 0-d JAX array; honor the `-> int` annotation.
    return int(foldingsJAX(leavesTotal, dimensionsTotal, connectionGraph))
|
|
20
|
+
|
|
21
|
+
def foldingsJAX(leavesTotal: int, dimensionsTotal: int, connectionGraph: jaxtyping.Array) -> jaxtyping.UInt32:
    """Count map foldings with a jit-compiled state machine built from lax control flow.

    All mutable state travels through `jax.lax.while_loop` / `jax.lax.cond` as an
    8-tuple (see the unpacking statements below and the construction at the end):
        (leafAbove, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere,
         foldingsTotal, activeLeaf1ndex, activeGap1ndex)
    which correspond to the arrays/scalars A, B, count, gapter, gap,
    foldingsTotal, l, g created at the bottom of this function.

    Parameters:
        leavesTotal: number of leaves; static under jit (static_argnums) because
            it sizes the state arrays via jax.numpy.zeros.
        dimensionsTotal: number of dimensions; static under jit.
        connectionGraph: array indexed as connectionGraph[dimension, leaf, leaf]
            (the indexing below fixes it as 3-D; exact extents come from the caller).

    Returns:
        The final value of foldingsTotal (a uint32 JAX scalar).
        NOTE(review): foldingsTotal is uint32, so totals above 2**32 - 1 wrap
        silently — confirm this is acceptable for the maps being counted.
    """

    def doNothing(argument):
        # Identity branch for jax.lax.cond.
        return argument

    def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
        # Outer loop condition: continue while activeLeaf1ndex (slot 6) > 0.
        comparand = comparisonValues[6]
        return comparand > 0

    def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        # One iteration of the outer loop: either find gaps, or add to the total,
        # or do neither; every path then runs `dao` (backtrack + place a leaf).
        _0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues

        sentinel = leafBelow.at[0].get().astype(jax.numpy.uint32)

        allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
                        lambda argumentX: dao(findGapsDo(argumentX)),
                        lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
                        allValues)

        return allValues

    def findGapsCondition(leafBelowSentinel, activeLeafNumber):
        # True when (leafBelow[0] == 1 and activeLeaf1ndex <= leavesTotal) or activeLeaf1ndex <= 1.
        return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)

    def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        """Collect candidate gaps for the active leaf into gapsWhere and filter them.

        Inner helpers read activeLeaf1ndex, leafBelow, gap1ndexCeiling and
        unconstrainedLeaf from this function's scope. Python closures bind late,
        so they see the values assigned below by the time each while_loop/cond
        traces them.
        """
        def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
            # Loop over dimension indices 1..dimensionsTotal (last tuple slot).
            return comparisonValues[-1] <= dimensionsTotal

        def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
            def ifLeafIsUnconstrainedCondition(comparand):
                # The active leaf connects to itself in this dimension.
                return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)

            def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                # Only bump the unconstrained-dimension counter (slot 3).
                unconstrained_unconstrainedLeaf = unconstrainedValues[3]
                unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
                return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)

            def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                # Walk the chain of connectees in this dimension until it returns
                # to the active leaf, recording each one as a potential gap.
                def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                    return comparisonValues[-1] != activeLeaf1ndex

                def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                    countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee = countGapsDoValues

                    # Write the connectee at the current end of the list; only
                    # advance the end if this connectee has not been seen yet,
                    # so each distinct connectee occupies one slot.
                    countGapsPotentialGaps = countGapsPotentialGaps.at[countGapsGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
                    countGapsGap1ndexLowerBound = jax.numpy.where(jax.numpy.equal(countGapsCountDimensionsGapped[countGapsLeaf1ndexConnectee], 0), countGapsGap1ndexLowerBound + 1, countGapsGap1ndexLowerBound)
                    countGapsCountDimensionsGapped = countGapsCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
                    # Next connectee: follow leafBelow of the current one through the graph.
                    countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.uint32)

                    return (countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee)

                unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues

                leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.uint32)

                countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
                countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
                unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues

                return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)

            dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values

            ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
            ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
            dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues

            dimensionNumber = 1 + dimensionNumber
            return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)

        def almostUselessCondition(comparand):
            # True when the active leaf was unconstrained in every dimension.
            return comparand == dimensionsTotal

        def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            # Append every leaf index 0..activeLeaf1ndex-1 as a potential gap.
            def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
                return comparisonValues[-1] < activeLeaf1ndex

            def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
                leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
                gapNumberLowerBound = 1 + gapNumberLowerBound
                leafNumber = 1 + leafNumber
                return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
            return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)

        def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            # Reads gap1ndexCeiling as reassigned after the loops below (late binding).
            return comparisonValues[-1] < gap1ndexCeiling

        def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            # Compact the candidate list in place. A candidate is kept (the
            # write position advances) only if its count equals
            # dimensionsTotal - unconstrainedLeaf. Its count is then reset to 0
            # for the next leaf.
            gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
            gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
            activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.uint32)
            gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
            index = 1 + index
            return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)

        _0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues

        unconstrainedLeaf = jax.numpy.uint32(0)
        dimension1ndex = jax.numpy.uint32(1)
        # New candidates for this leaf start where the previous leaf's gap range started.
        gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.uint32)
        activeGap1ndex = gap1ndexCeiling
        for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
        for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
        countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
        del dimension1ndex

        leaf1ndex = jax.numpy.uint32(0)
        for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
        for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
        gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
        del leaf1ndex

        indexMiniGap = activeGap1ndex
        miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
        miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
        countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
        del indexMiniGap

        return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)

    def incrementCondition(leafBelowSentinel, activeLeafNumber):
        # All leaves placed (activeLeaf1ndex > leavesTotal) and leafBelow[0] == 1.
        return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)

    def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        # Add leavesTotal (not 1) to the running total in slot 5.
        foldingsSubTotal = allValues[5]
        foldingsSubTotal = leavesTotal + foldingsSubTotal
        return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])

    def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        """Backtrack while the active leaf has no gaps left, then place the next leaf."""
        def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
            comparand = backtrackingValues[2]
            # jax.numpy.logical_and evaluates both operands, so `comparand - 1`
            # is gathered even when comparand == 0 (it wraps for uint32). The
            # `comparand > 0` operand makes the result False in that case.
            return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))

        def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
            backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues

            # Step back one leaf and unlink it from the leafAbove/leafBelow links.
            activeLeafNumber = activeLeafNumber - 1
            backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
            backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())

            return (backtrackAbove, backtrackBelow, activeLeafNumber)

        def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
            return activeLeafNumber > 0

        def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            # Consume the last remaining gap and insert the active leaf
            # immediately below the leaf recorded in that gap, relinking
            # leafAbove/leafBelow. Record where this leaf's gap range ends,
            # then advance to the next leaf.
            placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
            activeGapNumber = activeGapNumber - 1
            placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
            placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
            placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
            placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
            placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)

            activeLeafNumber = 1 + activeLeafNumber
            return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)

        leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues

        whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
        whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
        leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues

        if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
        if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
        leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values

        return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)

    # Dynamic values
    # Per-leaf arrays are sized leavesTotal + 1; gap holds up to leavesTotal**2 entries plus one.
    A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)

    foldingsTotal = jax.numpy.uint32(0)
    l = jax.numpy.uint32(1)
    g = jax.numpy.uint32(0)

    foldingsValues = (A, B, count, gapter, gap, foldingsTotal, l, g)
    foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
    return foldingsValues[5]

# leavesTotal and dimensionsTotal must be static (hashable Python values): they
# set array shapes via jax.numpy.zeros, which jit requires to be known at trace time.
foldingsJAX = jax.jit(foldingsJAX, static_argnums=(0, 1))
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
from mapFolding import validateListDimensions, getLeavesTotal
|
|
2
|
+
from typing import List, Tuple
|
|
3
|
+
import jax
|
|
4
|
+
import jaxtyping
|
|
5
|
+
|
|
6
|
+
# dtype for the task-index arrays built in countFolds and for the arrays
# foldingsTask constructs (dimensions, cumulative products, coordinate system,
# connection graph).
dtypeDefault = jax.numpy.int32
# NOTE(review): dtypeMaximum is not referenced in the portion of this module
# visible here — confirm whether it is still needed.
dtypeMaximum = jax.numpy.int32
|
|
8
|
+
|
|
9
|
+
def countFolds(listDimensions: List[int]):
    """Calculate foldings across multiple devices using pmap.

    The work is split into `leavesTotal` independent tasks with indices
    0 .. leavesTotal - 1. Each task is evaluated by `foldingsTask`, and the
    per-task results are summed. With more than one JAX device the tasks are
    padded to a multiple of the device count and spread with `jax.pmap`.
    Otherwise they are vectorized on a single device with `jax.vmap`.

    Parameters:
        listDimensions: the length of each dimension of the map.

    Returns:
        The sum of `foldingsTask` over all real task indices, as a JAX scalar.
    """
    listDimensionsValidated = validateListDimensions(listDimensions)
    leavesTotal = getLeavesTotal(listDimensionsValidated)
    # foldingsTask receives the dimensions as a tuple; build it once.
    tupleDimensions = tuple(listDimensionsValidated)

    # Number of devices (GPUs/TPUs) visible to JAX.
    deviceCount = jax.device_count()

    if deviceCount <= 1:
        # Fall back to vectorized execution on a single device.
        arrayTaskIndices = jax.numpy.arange(leavesTotal, dtype=dtypeDefault)
        batchedFoldingsTask = jax.vmap(lambda taskIndex: foldingsTask(tupleDimensions, taskIndex))
        return jax.numpy.sum(batchedFoldingsTask(arrayTaskIndices))

    # Split the work across devices, padding up to a multiple of deviceCount.
    tasksPerDevice = (leavesTotal + deviceCount - 1) // deviceCount
    paddedTaskCount = tasksPerDevice * deviceCount

    # Row d holds task indices d*tasksPerDevice .. (d+1)*tasksPerDevice - 1.
    arrayTaskIndices = jax.numpy.arange(paddedTaskCount, dtype=dtypeDefault)
    arrayTaskIndices = arrayTaskIndices.reshape((deviceCount, tasksPerDevice))

    parallelFoldingsTask = jax.pmap(
        lambda deviceTaskIndices: jax.vmap(lambda taskIndex: foldingsTask(tupleDimensions, taskIndex))(deviceTaskIndices)
    )

    arrayResults = parallelFoldingsTask(arrayTaskIndices)

    # Results share the row-major layout of arrayTaskIndices, so flattening
    # restores task order. The first `leavesTotal` entries are the real tasks
    # and everything after them is padding. Slicing columns instead
    # (`[:, :k]`) would drop real tasks from every row, and k can be <= 0.
    return jax.numpy.sum(arrayResults.reshape(-1)[:leavesTotal])
|
|
39
|
+
|
|
40
|
+
def foldingsTask(p, taskIndex) -> jaxtyping.UInt32:
|
|
41
|
+
arrayDimensions = jax.numpy.asarray(p, dtype=dtypeDefault)
|
|
42
|
+
leavesTotal = jax.numpy.prod(arrayDimensions)
|
|
43
|
+
dimensionsTotal = jax.numpy.size(arrayDimensions)
|
|
44
|
+
|
|
45
|
+
"""How to build a leaf connection graph, also called a "Cartesian Product Decomposition"
|
|
46
|
+
or a "Dimensional Product Mapping", with sentinels:
|
|
47
|
+
Step 1: find the cumulative product of the map's dimensions"""
|
|
48
|
+
cumulativeProduct = jax.numpy.ones(dimensionsTotal + 1, dtype=dtypeDefault)
|
|
49
|
+
cumulativeProduct = cumulativeProduct.at[1:].set(jax.numpy.cumprod(arrayDimensions))
|
|
50
|
+
|
|
51
|
+
"""Step 2: for each dimension, create a coordinate system """
|
|
52
|
+
"""coordinateSystem[dimension1ndex][leaf1ndex] holds the dimension1ndex-th coordinate of leaf leaf1ndex"""
|
|
53
|
+
coordinateSystem = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1), dtype=dtypeDefault)
|
|
54
|
+
|
|
55
|
+
# Create mesh of indices for vectorized computation
|
|
56
|
+
dimension1ndices, leaf1ndices = jax.numpy.meshgrid(
|
|
57
|
+
jax.numpy.arange(1, dimensionsTotal + 1),
|
|
58
|
+
jax.numpy.arange(1, leavesTotal + 1),
|
|
59
|
+
indexing='ij'
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
# Compute all coordinates at once using broadcasting
|
|
63
|
+
coordinateSystem = coordinateSystem.at[1:, 1:].set(
|
|
64
|
+
((leaf1ndices - 1) // cumulativeProduct.at[dimension1ndices - 1].get()) %
|
|
65
|
+
arrayDimensions.at[dimension1ndices - 1].get() + 1
|
|
66
|
+
)
|
|
67
|
+
del dimension1ndices, leaf1ndices
|
|
68
|
+
|
|
69
|
+
"""Step 3: create a huge empty connection graph"""
|
|
70
|
+
connectionGraph = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1, leavesTotal + 1), dtype=dtypeDefault)
|
|
71
|
+
|
|
72
|
+
# Create 3D mesh of indices for vectorized computation
|
|
73
|
+
dimension1ndices, activeLeaf1ndices, connectee1ndices = jax.numpy.meshgrid(
|
|
74
|
+
jax.numpy.arange(1, dimensionsTotal + 1),
|
|
75
|
+
jax.numpy.arange(1, leavesTotal + 1),
|
|
76
|
+
jax.numpy.arange(1, leavesTotal + 1),
|
|
77
|
+
indexing='ij'
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
# Create masks for valid indices
|
|
81
|
+
maskActiveConnectee = connectee1ndices <= activeLeaf1ndices
|
|
82
|
+
|
|
83
|
+
# Calculate coordinate parity comparison
|
|
84
|
+
coordsParity = (coordinateSystem.at[dimension1ndices, activeLeaf1ndices].get() & 1) == \
|
|
85
|
+
(coordinateSystem.at[dimension1ndices, connectee1ndices].get() & 1)
|
|
86
|
+
|
|
87
|
+
# Compute distance conditions
|
|
88
|
+
isFirstCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == 1
|
|
89
|
+
isLastCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == \
|
|
90
|
+
arrayDimensions.at[dimension1ndices - 1].get()
|
|
91
|
+
exceedsActive = connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get() > activeLeaf1ndices
|
|
92
|
+
|
|
93
|
+
# Compute connection values for even and odd parities
|
|
94
|
+
evenParityValues = jax.numpy.where(
|
|
95
|
+
isFirstCoord,
|
|
96
|
+
connectee1ndices,
|
|
97
|
+
connectee1ndices - cumulativeProduct.at[dimension1ndices - 1].get()
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
oddParityValues = jax.numpy.where(
|
|
101
|
+
jax.numpy.logical_or(isLastCoord, exceedsActive),
|
|
102
|
+
connectee1ndices,
|
|
103
|
+
connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get()
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
# Combine based on parity and valid indices
|
|
107
|
+
connectionValues = jax.numpy.where(
|
|
108
|
+
coordsParity,
|
|
109
|
+
evenParityValues,
|
|
110
|
+
oddParityValues
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
# Update only valid connections
|
|
114
|
+
connectionGraph = connectionGraph.at[dimension1ndices, activeLeaf1ndices, connectee1ndices].set(
|
|
115
|
+
jax.numpy.where(maskActiveConnectee, connectionValues, 0)
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
def doNothing(argument):
|
|
119
|
+
return argument
|
|
120
|
+
|
|
121
|
+
def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
|
|
122
|
+
comparand = comparisonValues[6]
|
|
123
|
+
return comparand > 0
|
|
124
|
+
|
|
125
|
+
def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
126
|
+
_0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues
|
|
127
|
+
|
|
128
|
+
sentinel = leafBelow.at[0].get().astype(jax.numpy.int32)
|
|
129
|
+
|
|
130
|
+
allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
|
|
131
|
+
lambda argumentX: dao(findGapsDo(argumentX)),
|
|
132
|
+
lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
|
|
133
|
+
allValues)
|
|
134
|
+
|
|
135
|
+
return allValues
|
|
136
|
+
|
|
137
|
+
def findGapsCondition(leafBelowSentinel, activeLeafNumber):
|
|
138
|
+
return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)
|
|
139
|
+
|
|
140
|
+
def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
141
|
+
def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
|
|
142
|
+
return comparisonValues[-1] <= dimensionsTotal
|
|
143
|
+
|
|
144
|
+
def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
145
|
+
def ifLeafIsUnconstrainedCondition(comparand):
|
|
146
|
+
return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)
|
|
147
|
+
|
|
148
|
+
def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
149
|
+
unconstrained_unconstrainedLeaf = unconstrainedValues[3]
|
|
150
|
+
unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
|
|
151
|
+
return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)
|
|
152
|
+
|
|
153
|
+
def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
154
|
+
def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
155
|
+
return comparisonValues[-1] != activeLeaf1ndex
|
|
156
|
+
|
|
157
|
+
def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
158
|
+
# if taskDivisions == False or activeLeaf1ndex != leavesTotal or leaf1ndexConnectee % leavesTotal == taskIndex:
|
|
159
|
+
def taskDivisionComparison():
|
|
160
|
+
return jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
|
|
161
|
+
# return taskDivisions == False or jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
|
|
162
|
+
|
|
163
|
+
def taskDivisionDo(taskDivisionDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
|
|
164
|
+
taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound = taskDivisionDoValues
|
|
165
|
+
|
|
166
|
+
taskDivisionPotentialGaps = taskDivisionPotentialGaps.at[taskDivisionGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
|
|
167
|
+
taskDivisionGap1ndexLowerBound = jax.numpy.where(
|
|
168
|
+
jax.numpy.equal(taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].get(), 0), taskDivisionGap1ndexLowerBound + 1, taskDivisionGap1ndexLowerBound)
|
|
169
|
+
taskDivisionCountDimensionsGapped = taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
|
|
170
|
+
|
|
171
|
+
return (taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound)
|
|
172
|
+
|
|
173
|
+
countGapsLeaf1ndexConnectee = countGapsDoValues[3]
|
|
174
|
+
taskDivisionValues = (countGapsDoValues[0], countGapsDoValues[1], countGapsDoValues[2])
|
|
175
|
+
taskDivisionValues = jax.lax.cond(taskDivisionComparison(), taskDivisionDo, doNothing, taskDivisionValues)
|
|
176
|
+
|
|
177
|
+
countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.int32)
|
|
178
|
+
|
|
179
|
+
return (taskDivisionValues[0], taskDivisionValues[1], taskDivisionValues[2], countGapsLeaf1ndexConnectee)
|
|
180
|
+
|
|
181
|
+
unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues
|
|
182
|
+
|
|
183
|
+
leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.int32)
|
|
184
|
+
|
|
185
|
+
countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
|
|
186
|
+
countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
|
|
187
|
+
unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues
|
|
188
|
+
|
|
189
|
+
return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)
|
|
190
|
+
|
|
191
|
+
dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
|
|
192
|
+
|
|
193
|
+
ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
|
|
194
|
+
ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
|
|
195
|
+
dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues
|
|
196
|
+
|
|
197
|
+
dimensionNumber = 1 + dimensionNumber
|
|
198
|
+
return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)
|
|
199
|
+
|
|
200
|
+
def almostUselessCondition(comparand):
|
|
201
|
+
return comparand == dimensionsTotal
|
|
202
|
+
|
|
203
|
+
def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
204
|
+
def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
|
|
205
|
+
return comparisonValues[-1] < activeLeaf1ndex
|
|
206
|
+
|
|
207
|
+
def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
208
|
+
leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
|
|
209
|
+
leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
|
|
210
|
+
gapNumberLowerBound = 1 + gapNumberLowerBound
|
|
211
|
+
leafNumber = 1 + leafNumber
|
|
212
|
+
return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
|
|
213
|
+
return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)
|
|
214
|
+
|
|
215
|
+
def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
216
|
+
return comparisonValues[-1] < gap1ndexCeiling
|
|
217
|
+
|
|
218
|
+
def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
    """Perform one compaction step over the potential-gap array.

    The carry is (per-leaf dimension tallies, potential-gap array, write cursor, read cursor).
    Each step does the following:
      1. Copies the leaf at the read cursor to the write cursor.
      2. Advances the write cursor only when that leaf's tally equals
         `dimensionsTotal - unconstrainedLeaf`. Both names come from the enclosing scope.
      3. Resets that leaf's tally to 0.
      4. Advances the read cursor.

    Returns:
        The updated carry, with the same layout as the input.
    """
    tallies, potentialGaps, writeCursor, readCursor = gapToGapValues

    # The write cursor never passes the read cursor, so the entry at readCursor is the same
    # before and after the copy below. Reading it once is therefore equivalent.
    candidateLeaf = potentialGaps.at[readCursor].get()
    potentialGaps = potentialGaps.at[writeCursor].set(candidateLeaf)

    # Presumably: keep the leaf only if it was a gap in every dimension that constrains the active leaf.
    keepCandidate = jax.numpy.equal(tallies.at[candidateLeaf].get(), dimensionsTotal - unconstrainedLeaf)
    writeCursor = jax.numpy.where(keepCandidate, writeCursor + 1, writeCursor).astype(jax.numpy.int32)

    tallies = tallies.at[candidateLeaf].set(0)
    return (tallies, potentialGaps, writeCursor, readCursor + 1)
|
|
225
|
+
|
|
226
|
+
_0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
|
|
227
|
+
|
|
228
|
+
unconstrainedLeaf = jax.numpy.int32(0)
|
|
229
|
+
dimension1ndex = jax.numpy.int32(1)
|
|
230
|
+
gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.int32)
|
|
231
|
+
activeGap1ndex = gap1ndexCeiling
|
|
232
|
+
for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
|
|
233
|
+
for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
|
|
234
|
+
countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
|
|
235
|
+
del dimension1ndex
|
|
236
|
+
|
|
237
|
+
leaf1ndex = jax.numpy.int32(0)
|
|
238
|
+
for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
|
|
239
|
+
for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
|
|
240
|
+
gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
|
|
241
|
+
del leaf1ndex
|
|
242
|
+
|
|
243
|
+
indexMiniGap = activeGap1ndex
|
|
244
|
+
miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
|
|
245
|
+
miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
|
|
246
|
+
countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
|
|
247
|
+
del indexMiniGap
|
|
248
|
+
|
|
249
|
+
return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
|
|
250
|
+
|
|
251
|
+
def incrementCondition(leafBelowSentinel, activeLeafNumber):
    """Return a JAX boolean: `activeLeafNumber > leavesTotal` and `leafBelowSentinel == 1`.

    `leavesTotal` comes from the enclosing scope.
    NOTE(review): the caller is not visible in this chunk. Presumably `leafBelowSentinel`
    is `leafBelow[0]` — confirm.
    """
    pastLastLeaf = activeLeafNumber > leavesTotal
    sentinelIsOne = leafBelowSentinel == 1
    return jax.numpy.logical_and(pastLastLeaf, sentinelIsOne)
|
|
253
|
+
|
|
254
|
+
def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
    """Add `leavesTotal` (from the enclosing scope) to the running fold count in carry slot 5.

    Slots 0-4, 6 and 7 pass through unchanged. Returns an 8-tuple.
    NOTE(review): at the outer call site slot 5 starts as `jax.numpy.int32(0)`.
    Fold counts for larger maps may exceed int32 — confirm the intended range.
    """
    foldsSoFar = leavesTotal + allValues[5]
    return (*allValues[:5], foldsSoFar, allValues[6], allValues[7])
|
|
258
|
+
|
|
259
|
+
def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
    """Backtrack out of exhausted leaves, then place the active leaf into its next gap.

    Carry layout (8 slots): leafAbove, leafBelow, <slot 2>, gapRangeStart, gapsWhere,
    <slot 5>, activeLeaf1ndex, activeGap1ndex. Slots 2 and 5 pass through untouched.
    `gapsWhere` is read here but never written.

    Phase 1 (backtrack): while activeLeaf1ndex > 0 and activeGap1ndex equals
    gapRangeStart[activeLeaf1ndex - 1], step activeLeaf1ndex back by one and unlink that
    leaf from the leafAbove/leafBelow links.

    Phase 2 (place): if activeLeaf1ndex > 0, consume one gap by decrementing activeGap1ndex.
    Then link the active leaf below gapsWhere[activeGap1ndex], record the gap cursor in
    gapRangeStart[activeLeaf1ndex], and advance activeLeaf1ndex.

    Returns:
        The updated 8-slot carry.
    """
    leafAbove, leafBelow, _, gapRangeStart, gapsWhere, _, activeLeaf1ndex, activeGap1ndex = allValues

    def isStillBacktracking(carry: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
        leaf = carry[2]
        # logical_and evaluates both operands. When leaf == 0 the gather reads index -1,
        # but the `leaf > 0` term makes the result False anyway.
        return jax.numpy.logical_and(leaf > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[leaf - 1].get()))

    def unlinkPreviousLeaf(carry: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
        above, below, leaf = carry
        leaf = leaf - 1
        leafAboveIt = above.at[leaf].get()
        # The leaf above `leaf` now points down past it...
        below = below.at[leafAboveIt].set(below.at[leaf].get())
        # ...and the leaf below `leaf` points up past it. This is re-read after the update above.
        leafBelowIt = below.at[leaf].get()
        above = above.at[leafBelowIt].set(leafAboveIt)
        return (above, below, leaf)

    def placeActiveLeaf(carry: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
        above, below, rangeStart, leaf, gap = carry
        gap = gap - 1
        above = above.at[leaf].set(gapsWhere.at[gap].get())
        # Read back through `above` so the index carries leafAbove's dtype.
        leafAboveIt = above.at[leaf].get()
        below = below.at[leaf].set(below.at[leafAboveIt].get())
        below = below.at[leafAboveIt].set(leaf)
        # Must be read after both `below` updates: if leafAboveIt == leaf the second one changes it.
        leafBelowIt = below.at[leaf].get()
        above = above.at[leafBelowIt].set(leaf)
        rangeStart = rangeStart.at[leaf].set(gap)
        return (above, below, rangeStart, leaf + 1, gap)

    leafAbove, leafBelow, activeLeaf1ndex = jax.lax.while_loop(isStillBacktracking, unlinkPreviousLeaf, (leafAbove, leafBelow, activeLeaf1ndex))

    placementCarry = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
    leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = jax.lax.cond(activeLeaf1ndex > 0, placeActiveLeaf, doNothing, placementCarry)

    return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
|
|
299
|
+
|
|
300
|
+
# Dynamic values
|
|
301
|
+
A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
|
|
302
|
+
B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
|
|
303
|
+
count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
|
|
304
|
+
gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
|
|
305
|
+
gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)
|
|
306
|
+
|
|
307
|
+
foldingsSubTotal = jax.numpy.int32(0)
|
|
308
|
+
l = jax.numpy.int32(1)
|
|
309
|
+
g = jax.numpy.int32(0)
|
|
310
|
+
|
|
311
|
+
foldingsValues = (A, B, count, gapter, gap, foldingsSubTotal, l, g)
|
|
312
|
+
foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
|
|
313
|
+
return foldingsValues[5]
|
mapFolding/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Test concept: Import priority levels. Larger priority values should be imported before smaller priority values.
|
|
2
|
+
This seems to be a little silly: no useful information is encoded in the priority value, so I don't know if a
|
|
3
|
+
new import should have a lower or higher priority.
|
|
4
|
+
Crazy concept: what if Python didn't cram at least two import roles into one system, call it `import`, and tell us how
|
|
5
|
+
awesome Python is? Alternatively, I could learn about the secret system for mapping physical names to logical names."""
|
|
6
|
+
|
|
7
|
+
# TODO Across the entire package, restructure computationDivisions.
|
|
8
|
+
# TODO The test modules still need updating.
|
|
9
|
+
|
|
10
|
+
from .theSSOT import *
|
|
11
|
+
from .beDRY import getTaskDivisions, makeConnectionGraph, outfitFoldings, setCPUlimit
|
|
12
|
+
from .beDRY import getLeavesTotal, parseDimensions, validateListDimensions
|
|
13
|
+
from .startHere import countFolds
|
|
14
|
+
from .oeis import oeisIDfor_n, getOEISids, clearOEIScache
|
|
15
|
+
|
|
16
|
+
# Public names exported by `from mapFolding import *`. Everything else imported above is
# still reachable as an attribute, e.g. `mapFolding.getLeavesTotal`.
__all__ = ['clearOEIScache', 'countFolds', 'getOEISids', 'oeisIDfor_n']
|
mapFolding/babbage.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from mapFolding.lovelace import countFoldsCompiled
|
|
2
|
+
from numpy import integer
|
|
3
|
+
from numpy.typing import NDArray
|
|
4
|
+
from typing import Any, Tuple
|
|
5
|
+
import numba
|
|
6
|
+
import numpy
|
|
7
|
+
|
|
8
|
+
@numba.jit(cache=True)
def _countFolds(connectionGraph: NDArray[integer[Any]], foldsTotal: NDArray[integer[Any]], mapShape: Tuple[int, ...], my: NDArray[integer[Any]], gapsWhere: NDArray[integer[Any]], the: NDArray[integer[Any]], track: NDArray[integer[Any]]):
    """Forward the counting state to `countFoldsCompiled` and return whatever it returns.

    `mapShape` is accepted but never read here. Every other argument is passed through
    positionally. NOTE(review): `mapShape` is presumably kept in the signature so numba
    compiles a separate specialization per map shape — see the TODO below.
    """
    # TODO learn if I really must change this jitted function to get the super jit to recompile
    foldsCounted = countFoldsCompiled(connectionGraph, foldsTotal, my, gapsWhere, the, track)
    return foldsCounted
|