mapFolding 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,206 @@
1
+ from mapFolding import validateListDimensions, getLeavesTotal, makeConnectionGraph
2
+ from typing import List, Tuple
3
+ import jax
4
+ import jaxtyping
5
+
6
# Integer dtype used for the connection graph and for the per-leaf work arrays
# (A, B, count, gapter) allocated in foldingsJAX.
dtypeDefault = jax.numpy.uint32
# Integer dtype used for the `gap` work array, which holds
# leavesTotal * leavesTotal + 1 entries.
dtypeMaximum = jax.numpy.uint32
8
+
9
def countFolds(listDimensions: List[int]) -> int:
    """Count the foldings of a map whose dimensions are `listDimensions`.

    The dimensions are validated, the connection graph is built with
    `makeConnectionGraph`, and the counting itself is delegated to the
    jit-compiled `foldingsJAX`.

    Parameters:
        listDimensions: the dimensions of the map, as accepted by
            `validateListDimensions`.

    Returns:
        The total produced by `foldingsJAX`, converted to a Python `int`.

    Raises:
        Whatever `validateListDimensions` raises for unacceptable input.
    """
    import numpy

    listDimensionsPositive: List[int] = validateListDimensions(listDimensions)

    n: int = getLeavesTotal(listDimensionsPositive)
    # The connection graph below is built from the validated list, so the
    # dimension count must come from that same list; the raw input may have a
    # different length if validation alters it.
    d: int = len(listDimensionsPositive)
    D: numpy.ndarray = makeConnectionGraph(listDimensionsPositive)
    connectionGraph = jax.numpy.asarray(D, dtype=dtypeDefault)

    # foldingsJAX returns a JAX scalar; convert it so the declared `-> int` holds.
    return int(foldingsJAX(n, d, connectionGraph))
20
+
21
def foldingsJAX(leavesTotal: jaxtyping.UInt32, dimensionsTotal: jaxtyping.UInt32, connectionGraph: jaxtyping.Array) -> jaxtyping.UInt32:
    """Count foldings with one `jax.lax.while_loop` over an 8-element state tuple.

    The loop state, in order, is
    `(leafAbove, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere,
    foldingsSubTotal, activeLeaf1ndex, activeGap1ndex)`; the nested helpers
    unpack it under those names. The loop runs while `activeLeaf1ndex > 0` and
    element 5 of the final state is returned.

    Directly after this definition the name is rebound to a `jax.jit`-compiled
    version in which arguments 0 and 1 are static.

    Parameters:
        leavesTotal: sizes the work arrays and bounds `activeLeaf1ndex`.
        dimensionsTotal: inclusive upper bound of the dimension loop, which starts at 1.
        connectionGraph: array indexed as `[dimension, activeLeaf1ndex, leaf]`.

    Returns:
        The uint32 scalar accumulated in element 5 of the loop state.

    NOTE(review): the total is a uint32 scalar, so it would wrap above
    2**32 - 1 -- confirm the intended input range.
    """

    # Identity branch for `jax.lax.cond` calls that need a no-op alternative.
    def doNothing(argument):
        return argument

    # Outer loop condition: element 6 of the state is activeLeaf1ndex.
    def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
        comparand = comparisonValues[6]
        return comparand > 0

    # Outer loop body: depending on leafBelow[0] and activeLeaf1ndex, either
    # search for gaps, add to the total, or do neither; `dao` runs in every case.
    def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        _0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues

        sentinel = leafBelow.at[0].get().astype(jax.numpy.uint32)

        allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
                            lambda argumentX: dao(findGapsDo(argumentX)),
                            lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
                            allValues)

        return allValues

    # True when (leafBelow[0] == 1 and the active leaf is within leavesTotal) or the active leaf is <= 1.
    def findGapsCondition(leafBelowSentinel, activeLeafNumber):
        return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)

    # Collect the candidate gaps for the active leaf into gapsWhere and update
    # countDimensionsGapped and activeGap1ndex accordingly.
    def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        # Dimension loop condition: the last element is the dimension counter.
        def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
            return comparisonValues[-1] <= dimensionsTotal

        # Dimension loop body: handle one dimension, then advance the counter.
        def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
            # The leaf is "unconstrained" in a dimension when the graph maps it to itself.
            def ifLeafIsUnconstrainedCondition(comparand):
                return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)

            # Unconstrained branch: only increment the unconstrained counter (element 3).
            def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                unconstrained_unconstrainedLeaf = unconstrainedValues[3]
                unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
                return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)

            # Constrained branch: follow connectee leaves through the graph until
            # the walk returns to the active leaf, recording each one as a potential gap.
            def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                    return comparisonValues[-1] != activeLeaf1ndex

                def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                    countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee = countGapsDoValues

                    # Record the connectee; the write position only advances the
                    # first time this connectee is seen (its count is still 0).
                    countGapsPotentialGaps = countGapsPotentialGaps.at[countGapsGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
                    countGapsGap1ndexLowerBound = jax.numpy.where(jax.numpy.equal(countGapsCountDimensionsGapped[countGapsLeaf1ndexConnectee], 0), countGapsGap1ndexLowerBound + 1, countGapsGap1ndexLowerBound)
                    countGapsCountDimensionsGapped = countGapsCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
                    # Next connectee: graph entry for the leaf below the current connectee.
                    # `dimensionNumber` is read from the enclosing loop body's scope.
                    countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.uint32)

                    return (countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee)

                unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues

                leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.uint32)

                countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
                countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
                unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues

                return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)

            dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values

            ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
            ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
            dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues

            # Keep this increment after the cond above: the branch closures read
            # `dimensionNumber` from this scope.
            dimensionNumber = 1 + dimensionNumber
            return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)

        # True when the leaf was unconstrained in every dimension.
        def almostUselessCondition(comparand):
            return comparand == dimensionsTotal

        # Append every leaf number from the starting counter up to
        # activeLeaf1ndex - 1 to the potential-gap list.
        def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
                return comparisonValues[-1] < activeLeaf1ndex

            def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
                leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
                gapNumberLowerBound = 1 + gapNumberLowerBound
                leafNumber = 1 + leafNumber
                return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
            return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)

        # Filter loop condition; reads `gap1ndexCeiling` from the enclosing scope
        # at the time the loop is invoked (after the updates further down).
        def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            return comparisonValues[-1] < gap1ndexCeiling

        # Filter loop body: compact the potential gaps in place. A candidate is
        # kept (write position advances) only if its count equals
        # dimensionsTotal - unconstrainedLeaf; its count is then reset to 0.
        def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
            gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
            activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.uint32)
            gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
            index = 1 + index
            return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)

        _0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues

        unconstrainedLeaf = jax.numpy.uint32(0)
        dimension1ndex = jax.numpy.uint32(1)
        # Both the ceiling and the active gap index start at the previous leaf's gap range start.
        gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.uint32)
        activeGap1ndex = gap1ndexCeiling
        for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
        for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
        countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
        del dimension1ndex

        leaf1ndex = jax.numpy.uint32(0)
        for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
        for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
        gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
        del leaf1ndex

        indexMiniGap = activeGap1ndex
        miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
        miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
        countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
        del indexMiniGap

        return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)

    # True when the active leaf has moved past leavesTotal and leafBelow[0] == 1.
    def incrementCondition(leafBelowSentinel, activeLeafNumber):
        return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)

    # Add leavesTotal to the running total (element 5); everything else is passed through.
    def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        foldingsSubTotal = allValues[5]
        foldingsSubTotal = leavesTotal + foldingsSubTotal
        return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])

    # Backtrack while the active gap index equals the previous leaf's range
    # start, then (if a leaf is still active) place it at the next gap.
    def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
            comparand = backtrackingValues[2]
            return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))

        # Step back one leaf and relink its neighbours in leafBelow / leafAbove to each other.
        def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
            backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues

            activeLeafNumber = activeLeafNumber - 1
            backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
            backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())

            return (backtrackAbove, backtrackBelow, activeLeafNumber)

        def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
            return activeLeafNumber > 0

        # Take the next gap (decrementing the gap index), link the active leaf
        # between that gap's leaf and its former below-neighbour, record the
        # gap index as this leaf's range start, and advance to the next leaf.
        def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
            activeGapNumber = activeGapNumber - 1
            placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
            placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
            placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
            placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
            placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)

            activeLeafNumber = 1 + activeLeafNumber
            return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)

        leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues

        whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
        whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
        leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues

        if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
        if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
        leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values

        return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)

    # Dynamic values
    # Initial loop state: zeroed work arrays, total 0, active leaf 1, active gap 0.
    A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)

    foldingsTotal = jax.numpy.uint32(0)
    l = jax.numpy.uint32(1)
    g = jax.numpy.uint32(0)

    foldingsValues = (A, B, count, gapter, gap, foldingsTotal, l, g)
    foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
    return foldingsValues[5]

# Compile once per (leavesTotal, dimensionsTotal) pair: both are static
# arguments because they determine array shapes and loop bounds.
foldingsJAX = jax.jit(foldingsJAX, static_argnums=(0, 1))
@@ -0,0 +1,313 @@
1
+ from mapFolding import validateListDimensions, getLeavesTotal
2
+ from typing import List, Tuple
3
+ import jax
4
+ import jaxtyping
5
+
6
# Integer dtype used for the task indices, the connection graph and its
# intermediates, and the per-leaf work arrays (A, B, count, gapter).
dtypeDefault = jax.numpy.int32
# Integer dtype used for the `gap` work array, which holds
# leavesTotal * leavesTotal + 1 entries.
dtypeMaximum = jax.numpy.int32
8
+
9
def countFolds(listDimensions: List[int]):
    """Calculate foldings across multiple devices using pmap.

    The work is split into `n` tasks (one per task index in `range(n)`, where
    `n` is the value returned by `getLeavesTotal`); each task is computed by
    `foldingsTask` and the per-task results are summed.

    Parameters:
        listDimensions: the dimensions of the map, as accepted by
            `validateListDimensions`.

    Returns:
        A JAX scalar holding the sum of all task results.
    """
    p = validateListDimensions(listDimensions)
    n = getLeavesTotal(p)
    mapShape = tuple(p)

    # Get number of devices (GPUs/TPUs)
    deviceCount = jax.device_count()

    if deviceCount > 1:
        # Split work across devices: round up so every task has a slot, which
        # may leave padding slots with task indices >= n at the end.
        tasksPerDevice = (n + deviceCount - 1) // deviceCount
        paddedTaskCount = tasksPerDevice * deviceCount

        # Row-major layout: device k receives task indices
        # [k * tasksPerDevice, (k + 1) * tasksPerDevice).
        arrayTaskIndices = jax.numpy.arange(paddedTaskCount, dtype=dtypeDefault)
        arrayTaskIndices = arrayTaskIndices.reshape((deviceCount, tasksPerDevice))

        # Create pmapped function
        parallelFoldingsTask = jax.pmap(lambda x: jax.vmap(lambda y: foldingsTask(mapShape, y))(x))

        # Run computation across devices
        arrayResults = parallelFoldingsTask(arrayTaskIndices)

        # Sum valid results (ignore padding). Flattening restores task order,
        # so the first n entries are exactly the real tasks. Slicing columns
        # instead would also discard valid results on every non-final device.
        return jax.numpy.sum(arrayResults.reshape(-1)[:n])
    else:
        # Fall back to a single vmap over all tasks if only one device is available
        arrayTaskIndices = jax.numpy.arange(n, dtype=dtypeDefault)
        batchedFoldingsTask = jax.vmap(lambda x: foldingsTask(mapShape, x))
        return jax.numpy.sum(batchedFoldingsTask(arrayTaskIndices))
39
+
40
def foldingsTask(p, taskIndex) -> jaxtyping.UInt32:
    """Compute the partial foldings total for one task index of a map with dimensions `p`.

    First builds the leaf connection graph from `p`, then runs one
    `jax.lax.while_loop` over the state tuple
    `(leafAbove, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere,
    foldingsSubTotal, activeLeaf1ndex, activeGap1ndex)`. Work is restricted by
    `taskIndex`: when `activeLeaf1ndex == leavesTotal`, a connectee leaf is only
    recorded as a potential gap if `connectee % leavesTotal == taskIndex`.

    Parameters:
        p: the map dimensions, convertible by `jax.numpy.asarray`.
        taskIndex: the task index compared against `connectee % leavesTotal`.

    Returns:
        Element 5 of the final loop state (the accumulated sub-total).

    NOTE(review): `leavesTotal` is a JAX value (result of `jax.numpy.prod`) and
    is used to size arrays; confirm that it is concrete wherever this function
    is traced (e.g. under `jax.pmap`).
    NOTE(review): the return annotation says UInt32 but the scalars are created
    as int32.
    """
    arrayDimensions = jax.numpy.asarray(p, dtype=dtypeDefault)
    leavesTotal = jax.numpy.prod(arrayDimensions)
    dimensionsTotal = jax.numpy.size(arrayDimensions)

    """How to build a leaf connection graph, also called a "Cartesian Product Decomposition"
    or a "Dimensional Product Mapping", with sentinels:
    Step 1: find the cumulative product of the map's dimensions"""
    # cumulativeProduct[0] stays 1; cumulativeProduct[k] is the product of the first k dimensions.
    cumulativeProduct = jax.numpy.ones(dimensionsTotal + 1, dtype=dtypeDefault)
    cumulativeProduct = cumulativeProduct.at[1:].set(jax.numpy.cumprod(arrayDimensions))

    """Step 2: for each dimension, create a coordinate system """
    """coordinateSystem[dimension1ndex][leaf1ndex] holds the dimension1ndex-th coordinate of leaf leaf1ndex"""
    # Row 0 and column 0 are left as zeros; real entries start at index 1.
    coordinateSystem = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1), dtype=dtypeDefault)

    # Create mesh of indices for vectorized computation
    dimension1ndices, leaf1ndices = jax.numpy.meshgrid(
        jax.numpy.arange(1, dimensionsTotal + 1),
        jax.numpy.arange(1, leavesTotal + 1),
        indexing='ij'
    )

    # Compute all coordinates at once using broadcasting
    coordinateSystem = coordinateSystem.at[1:, 1:].set(
        ((leaf1ndices - 1) // cumulativeProduct.at[dimension1ndices - 1].get()) %
        arrayDimensions.at[dimension1ndices - 1].get() + 1
    )
    del dimension1ndices, leaf1ndices

    """Step 3: create a huge empty connection graph"""
    connectionGraph = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1, leavesTotal + 1), dtype=dtypeDefault)

    # Create 3D mesh of indices for vectorized computation
    dimension1ndices, activeLeaf1ndices, connectee1ndices = jax.numpy.meshgrid(
        jax.numpy.arange(1, dimensionsTotal + 1),
        jax.numpy.arange(1, leavesTotal + 1),
        jax.numpy.arange(1, leavesTotal + 1),
        indexing='ij'
    )

    # Create masks for valid indices: only connectees up to the active leaf get a value
    maskActiveConnectee = connectee1ndices <= activeLeaf1ndices

    # Calculate coordinate parity comparison
    coordsParity = (coordinateSystem.at[dimension1ndices, activeLeaf1ndices].get() & 1) == \
                   (coordinateSystem.at[dimension1ndices, connectee1ndices].get() & 1)

    # Compute distance conditions
    isFirstCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == 1
    isLastCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == \
                  arrayDimensions.at[dimension1ndices - 1].get()
    exceedsActive = connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get() > activeLeaf1ndices

    # Compute connection values for even and odd parities
    # Same parity: the connectee itself at the first coordinate, otherwise one stride lower.
    evenParityValues = jax.numpy.where(
        isFirstCoord,
        connectee1ndices,
        connectee1ndices - cumulativeProduct.at[dimension1ndices - 1].get()
    )

    # Different parity: the connectee itself at the last coordinate or when one
    # stride higher would pass the active leaf, otherwise one stride higher.
    oddParityValues = jax.numpy.where(
        jax.numpy.logical_or(isLastCoord, exceedsActive),
        connectee1ndices,
        connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get()
    )

    # Combine based on parity and valid indices
    connectionValues = jax.numpy.where(
        coordsParity,
        evenParityValues,
        oddParityValues
    )

    # Update only valid connections
    connectionGraph = connectionGraph.at[dimension1ndices, activeLeaf1ndices, connectee1ndices].set(
        jax.numpy.where(maskActiveConnectee, connectionValues, 0)
    )

    # Identity branch for `jax.lax.cond` calls that need a no-op alternative.
    def doNothing(argument):
        return argument

    # Outer loop condition: element 6 of the state is activeLeaf1ndex.
    def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
        comparand = comparisonValues[6]
        return comparand > 0

    # Outer loop body: depending on leafBelow[0] and activeLeaf1ndex, either
    # search for gaps, add to the sub-total, or do neither; `dao` runs in every case.
    def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        _0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues

        sentinel = leafBelow.at[0].get().astype(jax.numpy.int32)

        allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
                            lambda argumentX: dao(findGapsDo(argumentX)),
                            lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
                            allValues)

        return allValues

    # True when (leafBelow[0] == 1 and the active leaf is within leavesTotal) or the active leaf is <= 1.
    def findGapsCondition(leafBelowSentinel, activeLeafNumber):
        return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)

    # Collect the candidate gaps for the active leaf into gapsWhere and update
    # countDimensionsGapped and activeGap1ndex accordingly.
    def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        # Dimension loop condition: the last element is the dimension counter.
        def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
            return comparisonValues[-1] <= dimensionsTotal

        # Dimension loop body: handle one dimension, then advance the counter.
        def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
            # The leaf is "unconstrained" in a dimension when the graph maps it to itself.
            def ifLeafIsUnconstrainedCondition(comparand):
                return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)

            # Unconstrained branch: only increment the unconstrained counter (element 3).
            def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                unconstrained_unconstrainedLeaf = unconstrainedValues[3]
                unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
                return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)

            # Constrained branch: follow connectee leaves through the graph until
            # the walk returns to the active leaf.
            def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                    return comparisonValues[-1] != activeLeaf1ndex

                def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                    # if taskDivisions == False or activeLeaf1ndex != leavesTotal or leaf1ndexConnectee % leavesTotal == taskIndex:
                    # Record the connectee unless this is the last leaf and the
                    # connectee belongs to a different task index.
                    def taskDivisionComparison():
                        return jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
                        # return taskDivisions == False or jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))

                    # Record the connectee; the write position only advances the
                    # first time this connectee is seen (its count is still 0).
                    def taskDivisionDo(taskDivisionDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
                        taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound = taskDivisionDoValues

                        taskDivisionPotentialGaps = taskDivisionPotentialGaps.at[taskDivisionGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
                        taskDivisionGap1ndexLowerBound = jax.numpy.where(
                            jax.numpy.equal(taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].get(), 0), taskDivisionGap1ndexLowerBound + 1, taskDivisionGap1ndexLowerBound)
                        taskDivisionCountDimensionsGapped = taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)

                        return (taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound)

                    # The helpers above read `countGapsLeaf1ndexConnectee` from this
                    # scope, so it must be bound before the cond and only advanced after it.
                    countGapsLeaf1ndexConnectee = countGapsDoValues[3]
                    taskDivisionValues = (countGapsDoValues[0], countGapsDoValues[1], countGapsDoValues[2])
                    taskDivisionValues = jax.lax.cond(taskDivisionComparison(), taskDivisionDo, doNothing, taskDivisionValues)

                    # Next connectee: graph entry for the leaf below the current connectee.
                    countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.int32)

                    return (taskDivisionValues[0], taskDivisionValues[1], taskDivisionValues[2], countGapsLeaf1ndexConnectee)

                unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues

                leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.int32)

                countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
                countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
                unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues

                return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)

            dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values

            ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
            ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
            dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues

            # Keep this increment after the cond above: the branch closures read
            # `dimensionNumber` from this scope.
            dimensionNumber = 1 + dimensionNumber
            return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)

        # True when the leaf was unconstrained in every dimension.
        def almostUselessCondition(comparand):
            return comparand == dimensionsTotal

        # Append every leaf number from the starting counter up to
        # activeLeaf1ndex - 1 to the potential-gap list.
        def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
                return comparisonValues[-1] < activeLeaf1ndex

            def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
                leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
                leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
                gapNumberLowerBound = 1 + gapNumberLowerBound
                leafNumber = 1 + leafNumber
                return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
            return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)

        # Filter loop condition; reads `gap1ndexCeiling` from the enclosing scope
        # at the time the loop is invoked (after the updates further down).
        def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            return comparisonValues[-1] < gap1ndexCeiling

        # Filter loop body: compact the potential gaps in place. A candidate is
        # kept (write position advances) only if its count equals
        # dimensionsTotal - unconstrainedLeaf; its count is then reset to 0.
        def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
            gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
            activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.int32)
            gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
            index = 1 + index
            return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)

        _0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues

        unconstrainedLeaf = jax.numpy.int32(0)
        dimension1ndex = jax.numpy.int32(1)
        # Both the ceiling and the active gap index start at the previous leaf's gap range start.
        gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.int32)
        activeGap1ndex = gap1ndexCeiling
        for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
        for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
        countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
        del dimension1ndex

        leaf1ndex = jax.numpy.int32(0)
        for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
        for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
        gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
        del leaf1ndex

        indexMiniGap = activeGap1ndex
        miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
        miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
        countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
        del indexMiniGap

        return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)

    # True when the active leaf has moved past leavesTotal and leafBelow[0] == 1.
    def incrementCondition(leafBelowSentinel, activeLeafNumber):
        return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)

    # Add leavesTotal to the running sub-total (element 5); everything else is passed through.
    def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        foldingsSubTotal = allValues[5]
        foldingsSubTotal = leavesTotal + foldingsSubTotal
        return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])

    # Backtrack while the active gap index equals the previous leaf's range
    # start, then (if a leaf is still active) place it at the next gap.
    def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
        def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
            comparand = backtrackingValues[2]
            return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))

        # Step back one leaf and relink its neighbours in leafBelow / leafAbove to each other.
        def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
            backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues

            activeLeafNumber = activeLeafNumber - 1
            backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
            backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())

            return (backtrackAbove, backtrackBelow, activeLeafNumber)

        def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
            return activeLeafNumber > 0

        # Take the next gap (decrementing the gap index), link the active leaf
        # between that gap's leaf and its former below-neighbour, record the
        # gap index as this leaf's range start, and advance to the next leaf.
        def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
            placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
            activeGapNumber = activeGapNumber - 1
            placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
            placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
            placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
            placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
            placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)

            activeLeafNumber = 1 + activeLeafNumber
            return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)

        leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues

        whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
        whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
        leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues

        if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
        if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
        leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values

        return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)

    # Dynamic values
    # Initial loop state: zeroed work arrays, sub-total 0, active leaf 1, active gap 0.
    A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
    gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)

    foldingsSubTotal = jax.numpy.int32(0)
    l = jax.numpy.int32(1)
    g = jax.numpy.int32(0)

    foldingsValues = (A, B, count, gapter, gap, foldingsSubTotal, l, g)
    foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
    return foldingsValues[5]
mapFolding/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ """Test concept: Import priority levels. Larger priority values should be imported before smaller priority values.
2
+ This seems to be a little silly: no useful information is encoded in the priority value, so I don't know if a
3
+ new import should have a lower or higher priority.
4
+ Crazy concept: Python doesn't cram at least two import roles into one system, call it `import` and tell us how
5
+ awesome Python is. Alternatively, I learn about the secret system for mapping physical names to logical names."""
6
+
7
+ # TODO Across the entire package, restructure computationDivisions.
8
+ # test modules need updating still
9
+
10
+ from .theSSOT import *
11
+ from .beDRY import getTaskDivisions, makeConnectionGraph, outfitFoldings, setCPUlimit
12
+ from .beDRY import getLeavesTotal, parseDimensions, validateListDimensions
13
+ from .startHere import countFolds
14
+ from .oeis import oeisIDfor_n, getOEISids, clearOEIScache
15
+
16
# Names exported by `from mapFolding import *`. The helpers imported from
# `.beDRY` above are not listed here; they remain importable by explicit name.
__all__ = [
    'clearOEIScache',
    'countFolds',
    'getOEISids',
    'oeisIDfor_n',
]
mapFolding/babbage.py ADDED
@@ -0,0 +1,12 @@
1
+ from mapFolding.lovelace import countFoldsCompiled
2
+ from numpy import integer
3
+ from numpy.typing import NDArray
4
+ from typing import Any, Tuple
5
+ import numba
6
+ import numpy
7
+
8
+ @numba.jit(cache=True)
9
+ def _countFolds(connectionGraph: NDArray[integer[Any]], foldsTotal: NDArray[integer[Any]], mapShape: Tuple[int, ...], my: NDArray[integer[Any]], gapsWhere: NDArray[integer[Any]], the: NDArray[integer[Any]], track: NDArray[integer[Any]]):
10
+ # TODO learn if I really must change this jitted function to get the super jit to recompile
11
+ # print('babbage')
12
+ return countFoldsCompiled(connectionGraph, foldsTotal, my, gapsWhere, the, track)