mapFolding 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mapFolding/JAX/taskJAX.py DELETED
@@ -1,313 +0,0 @@
1
- from mapFolding import validateListDimensions, getLeavesTotal
2
- from typing import List, Tuple
3
- import jax
4
- import jaxtyping
5
-
6
- dtypeDefault = jax.numpy.int32
7
- dtypeMaximum = jax.numpy.int32
8
-
9
- def countFolds(listDimensions: List[int]):
10
- """Calculate foldings across multiple devices using pmap"""
11
- p = validateListDimensions(listDimensions)
12
- n = getLeavesTotal(p)
13
-
14
- # Get number of devices (GPUs/TPUs)
15
- deviceCount = jax.device_count()
16
-
17
- if deviceCount > 1:
18
- # Split work across devices
19
- tasksPerDevice = (n + deviceCount - 1) // deviceCount
20
- paddedTaskCount = tasksPerDevice * deviceCount
21
-
22
- # Create padded array of task indices
23
- arrayTaskIndices = jax.numpy.arange(paddedTaskCount, dtype=dtypeDefault)
24
- arrayTaskIndices = arrayTaskIndices.reshape((deviceCount, tasksPerDevice))
25
-
26
- # Create pmapped function
27
- parallelFoldingsTask = jax.pmap(lambda x: jax.vmap(lambda y: foldingsTask(tuple(p), y))(x))
28
-
29
- # Run computation across devices
30
- arrayResults = parallelFoldingsTask(arrayTaskIndices)
31
-
32
- # Sum valid results (ignore padding)
33
- return jax.numpy.sum(arrayResults[:, :min(tasksPerDevice, n - tasksPerDevice * (deviceCount-1))])
34
- else:
35
- # Fall back to sequential execution if no multiple devices available
36
- arrayTaskIndices = jax.numpy.arange(n, dtype=dtypeDefault)
37
- batchedFoldingsTask = jax.vmap(lambda x: foldingsTask(tuple(p), x))
38
- return jax.numpy.sum(batchedFoldingsTask(arrayTaskIndices))
39
-
40
- def foldingsTask(p, taskIndex) -> jaxtyping.UInt32:
41
- arrayDimensions = jax.numpy.asarray(p, dtype=dtypeDefault)
42
- leavesTotal = jax.numpy.prod(arrayDimensions)
43
- dimensionsTotal = jax.numpy.size(arrayDimensions)
44
-
45
- """How to build a leaf connection graph, also called a "Cartesian Product Decomposition"
46
- or a "Dimensional Product Mapping", with sentinels:
47
- Step 1: find the cumulative product of the map's dimensions"""
48
- cumulativeProduct = jax.numpy.ones(dimensionsTotal + 1, dtype=dtypeDefault)
49
- cumulativeProduct = cumulativeProduct.at[1:].set(jax.numpy.cumprod(arrayDimensions))
50
-
51
- """Step 2: for each dimension, create a coordinate system """
52
- """coordinateSystem[dimension1ndex][leaf1ndex] holds the dimension1ndex-th coordinate of leaf leaf1ndex"""
53
- coordinateSystem = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1), dtype=dtypeDefault)
54
-
55
- # Create mesh of indices for vectorized computation
56
- dimension1ndices, leaf1ndices = jax.numpy.meshgrid(
57
- jax.numpy.arange(1, dimensionsTotal + 1),
58
- jax.numpy.arange(1, leavesTotal + 1),
59
- indexing='ij'
60
- )
61
-
62
- # Compute all coordinates at once using broadcasting
63
- coordinateSystem = coordinateSystem.at[1:, 1:].set(
64
- ((leaf1ndices - 1) // cumulativeProduct.at[dimension1ndices - 1].get()) %
65
- arrayDimensions.at[dimension1ndices - 1].get() + 1
66
- )
67
- del dimension1ndices, leaf1ndices
68
-
69
- """Step 3: create a huge empty connection graph"""
70
- connectionGraph = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1, leavesTotal + 1), dtype=dtypeDefault)
71
-
72
- # Create 3D mesh of indices for vectorized computation
73
- dimension1ndices, activeLeaf1ndices, connectee1ndices = jax.numpy.meshgrid(
74
- jax.numpy.arange(1, dimensionsTotal + 1),
75
- jax.numpy.arange(1, leavesTotal + 1),
76
- jax.numpy.arange(1, leavesTotal + 1),
77
- indexing='ij'
78
- )
79
-
80
- # Create masks for valid indices
81
- maskActiveConnectee = connectee1ndices <= activeLeaf1ndices
82
-
83
- # Calculate coordinate parity comparison
84
- coordsParity = (coordinateSystem.at[dimension1ndices, activeLeaf1ndices].get() & 1) == \
85
- (coordinateSystem.at[dimension1ndices, connectee1ndices].get() & 1)
86
-
87
- # Compute distance conditions
88
- isFirstCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == 1
89
- isLastCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == \
90
- arrayDimensions.at[dimension1ndices - 1].get()
91
- exceedsActive = connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get() > activeLeaf1ndices
92
-
93
- # Compute connection values for even and odd parities
94
- evenParityValues = jax.numpy.where(
95
- isFirstCoord,
96
- connectee1ndices,
97
- connectee1ndices - cumulativeProduct.at[dimension1ndices - 1].get()
98
- )
99
-
100
- oddParityValues = jax.numpy.where(
101
- jax.numpy.logical_or(isLastCoord, exceedsActive),
102
- connectee1ndices,
103
- connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get()
104
- )
105
-
106
- # Combine based on parity and valid indices
107
- connectionValues = jax.numpy.where(
108
- coordsParity,
109
- evenParityValues,
110
- oddParityValues
111
- )
112
-
113
- # Update only valid connections
114
- connectionGraph = connectionGraph.at[dimension1ndices, activeLeaf1ndices, connectee1ndices].set(
115
- jax.numpy.where(maskActiveConnectee, connectionValues, 0)
116
- )
117
-
118
- def doNothing(argument):
119
- return argument
120
-
121
- def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
122
- comparand = comparisonValues[6]
123
- return comparand > 0
124
-
125
- def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
126
- _0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues
127
-
128
- sentinel = leafBelow.at[0].get().astype(jax.numpy.int32)
129
-
130
- allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
131
- lambda argumentX: dao(findGapsDo(argumentX)),
132
- lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
133
- allValues)
134
-
135
- return allValues
136
-
137
- def findGapsCondition(leafBelowSentinel, activeLeafNumber):
138
- return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)
139
-
140
- def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
141
- def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
142
- return comparisonValues[-1] <= dimensionsTotal
143
-
144
- def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
145
- def ifLeafIsUnconstrainedCondition(comparand):
146
- return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)
147
-
148
- def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
149
- unconstrained_unconstrainedLeaf = unconstrainedValues[3]
150
- unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
151
- return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)
152
-
153
- def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
154
- def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
155
- return comparisonValues[-1] != activeLeaf1ndex
156
-
157
- def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
158
- # if taskDivisions == False or activeLeaf1ndex != leavesTotal or leaf1ndexConnectee % leavesTotal == taskIndex:
159
- def taskDivisionComparison():
160
- return jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
161
- # return taskDivisions == False or jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
162
-
163
- def taskDivisionDo(taskDivisionDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
164
- taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound = taskDivisionDoValues
165
-
166
- taskDivisionPotentialGaps = taskDivisionPotentialGaps.at[taskDivisionGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
167
- taskDivisionGap1ndexLowerBound = jax.numpy.where(
168
- jax.numpy.equal(taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].get(), 0), taskDivisionGap1ndexLowerBound + 1, taskDivisionGap1ndexLowerBound)
169
- taskDivisionCountDimensionsGapped = taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
170
-
171
- return (taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound)
172
-
173
- countGapsLeaf1ndexConnectee = countGapsDoValues[3]
174
- taskDivisionValues = (countGapsDoValues[0], countGapsDoValues[1], countGapsDoValues[2])
175
- taskDivisionValues = jax.lax.cond(taskDivisionComparison(), taskDivisionDo, doNothing, taskDivisionValues)
176
-
177
- countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.int32)
178
-
179
- return (taskDivisionValues[0], taskDivisionValues[1], taskDivisionValues[2], countGapsLeaf1ndexConnectee)
180
-
181
- unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues
182
-
183
- leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.int32)
184
-
185
- countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
186
- countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
187
- unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues
188
-
189
- return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)
190
-
191
- dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
192
-
193
- ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
194
- ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
195
- dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues
196
-
197
- dimensionNumber = 1 + dimensionNumber
198
- return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)
199
-
200
- def almostUselessCondition(comparand):
201
- return comparand == dimensionsTotal
202
-
203
- def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
204
- def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
205
- return comparisonValues[-1] < activeLeaf1ndex
206
-
207
- def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
208
- leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
209
- leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
210
- gapNumberLowerBound = 1 + gapNumberLowerBound
211
- leafNumber = 1 + leafNumber
212
- return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
213
- return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)
214
-
215
- def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
216
- return comparisonValues[-1] < gap1ndexCeiling
217
-
218
- def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
219
- gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
220
- gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
221
- activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.int32)
222
- gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
223
- index = 1 + index
224
- return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)
225
-
226
- _0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
227
-
228
- unconstrainedLeaf = jax.numpy.int32(0)
229
- dimension1ndex = jax.numpy.int32(1)
230
- gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.int32)
231
- activeGap1ndex = gap1ndexCeiling
232
- for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
233
- for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
234
- countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
235
- del dimension1ndex
236
-
237
- leaf1ndex = jax.numpy.int32(0)
238
- for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
239
- for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
240
- gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
241
- del leaf1ndex
242
-
243
- indexMiniGap = activeGap1ndex
244
- miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
245
- miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
246
- countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
247
- del indexMiniGap
248
-
249
- return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
250
-
251
- def incrementCondition(leafBelowSentinel, activeLeafNumber):
252
- return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)
253
-
254
- def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
255
- foldingsSubTotal = allValues[5]
256
- foldingsSubTotal = leavesTotal + foldingsSubTotal
257
- return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])
258
-
259
- def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
260
- def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
261
- comparand = backtrackingValues[2]
262
- return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))
263
-
264
- def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
265
- backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues
266
-
267
- activeLeafNumber = activeLeafNumber - 1
268
- backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
269
- backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())
270
-
271
- return (backtrackAbove, backtrackBelow, activeLeafNumber)
272
-
273
- def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
274
- return activeLeafNumber > 0
275
-
276
- def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
277
- placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
278
- activeGapNumber = activeGapNumber - 1
279
- placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
280
- placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
281
- placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
282
- placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
283
- placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)
284
-
285
- activeLeafNumber = 1 + activeLeafNumber
286
- return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)
287
-
288
- leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
289
-
290
- whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
291
- whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
292
- leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues
293
-
294
- if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
295
- if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
296
- leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values
297
-
298
- return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
299
-
300
- # Dynamic values
301
- A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
302
- B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
303
- count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
304
- gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
305
- gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)
306
-
307
- foldingsSubTotal = jax.numpy.int32(0)
308
- l = jax.numpy.int32(1)
309
- g = jax.numpy.int32(0)
310
-
311
- foldingsValues = (A, B, count, gapter, gap, foldingsSubTotal, l, g)
312
- foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
313
- return foldingsValues[5]
@@ -1,44 +0,0 @@
1
- import numba
2
-
3
- @numba.jit((numba.int64[:, :, ::1], numba.int64[::1], numba.int64[::1], numba.int64[::1], numba.int64[:, ::1]), parallel=False, boundscheck=False, error_model='numpy', fastmath=True, looplift=False, nogil=True, nopython=True)
4
- def countInitialize(connectionGraph, gapsWhere, my, the, track):
5
- while my[6] > 0:
6
- if my[6] <= 1 or track[1, 0] == 1:
7
- my[1] = 0
8
- my[3] = track[3, my[6] - 1]
9
- my[0] = 1
10
- while my[0] <= the[0]:
11
- if connectionGraph[my[0], my[6], my[6]] == my[6]:
12
- my[1] += 1
13
- else:
14
- my[7] = connectionGraph[my[0], my[6], my[6]]
15
- while my[7] != my[6]:
16
- gapsWhere[my[3]] = my[7]
17
- if track[2, my[7]] == 0:
18
- my[3] += 1
19
- track[2, my[7]] += 1
20
- my[7] = connectionGraph[my[0], my[6], track[1, my[7]]]
21
- my[0] += 1
22
- if my[1] == the[0]:
23
- my[4] = 0
24
- while my[4] < my[6]:
25
- gapsWhere[my[3]] = my[4]
26
- my[3] += 1
27
- my[4] += 1
28
- my[5] = my[2]
29
- while my[5] < my[3]:
30
- gapsWhere[my[2]] = gapsWhere[my[5]]
31
- if track[2, gapsWhere[my[5]]] == the[0] - my[1]:
32
- my[2] += 1
33
- track[2, gapsWhere[my[5]]] = 0
34
- my[5] += 1
35
- if my[6] > 0:
36
- my[2] -= 1
37
- track[0, my[6]] = gapsWhere[my[2]]
38
- track[1, my[6]] = track[1, track[0, my[6]]]
39
- track[1, track[0, my[6]]] = my[6]
40
- track[0, track[1, my[6]]] = my[6]
41
- track[3, my[6]] = my[2]
42
- my[6] += 1
43
- if my[2] > 0:
44
- return
@@ -1,49 +0,0 @@
1
- import numba
2
-
3
- @numba.jit((numba.int64[:, :, ::1], numba.int64[::1], numba.int64[::1], numba.int64[::1], numba.int64[::1], numba.int64[:, ::1]), parallel=True, boundscheck=False, error_model='numpy', fastmath=True, looplift=False, nogil=True, nopython=True)
4
- def countParallel(connectionGraph, foldsSubTotals, gapsWherePARALLEL, myPARALLEL, the, trackPARALLEL):
5
- for indexSherpa in numba.prange(the[2]):
6
- gapsWhere = gapsWherePARALLEL.copy()
7
- my = myPARALLEL.copy()
8
- my[8] = indexSherpa
9
- track = trackPARALLEL.copy()
10
- while my[6] > 0:
11
- if my[6] <= 1 or track[1, 0] == 1:
12
- if my[6] > the[1]:
13
- foldsSubTotals[my[8]] += the[1]
14
- else:
15
- my[1] = 0
16
- my[3] = track[3, my[6] - 1]
17
- my[0] = 1
18
- while my[0] <= the[0]:
19
- if connectionGraph[my[0], my[6], my[6]] == my[6]:
20
- my[1] += 1
21
- else:
22
- my[7] = connectionGraph[my[0], my[6], my[6]]
23
- while my[7] != my[6]:
24
- if my[6] != the[2] or my[7] % the[2] == my[8]:
25
- gapsWhere[my[3]] = my[7]
26
- if track[2, my[7]] == 0:
27
- my[3] += 1
28
- track[2, my[7]] += 1
29
- my[7] = connectionGraph[my[0], my[6], track[1, my[7]]]
30
- my[0] += 1
31
- my[5] = my[2]
32
- while my[5] < my[3]:
33
- gapsWhere[my[2]] = gapsWhere[my[5]]
34
- if track[2, gapsWhere[my[5]]] == the[0] - my[1]:
35
- my[2] += 1
36
- track[2, gapsWhere[my[5]]] = 0
37
- my[5] += 1
38
- while my[6] > 0 and my[2] == track[3, my[6] - 1]:
39
- my[6] -= 1
40
- track[1, track[0, my[6]]] = track[1, my[6]]
41
- track[0, track[1, my[6]]] = track[0, my[6]]
42
- if my[6] > 0:
43
- my[2] -= 1
44
- track[0, my[6]] = gapsWhere[my[2]]
45
- track[1, my[6]] = track[1, track[0, my[6]]]
46
- track[1, track[0, my[6]]] = my[6]
47
- track[0, track[1, my[6]]] = my[6]
48
- track[3, my[6]] = my[2]
49
- my[6] += 1
@@ -1,43 +0,0 @@
1
- import numba
2
-
3
- @numba.jit((numba.int64[:, :, ::1], numba.int64[::1], numba.int64[::1], numba.int64[::1], numba.int64[::1], numba.int64[:, ::1]), parallel=False, boundscheck=False, error_model='numpy', fastmath=True, looplift=False, nogil=True, nopython=True)
4
- def countSequential(connectionGraph, foldsSubTotals, gapsWhere, my, the, track):
5
- while my[6] > 0:
6
- if my[6] <= 1 or track[1, 0] == 1:
7
- if my[6] > the[1]:
8
- foldsSubTotals[my[8]] += the[1]
9
- else:
10
- my[1] = 0
11
- my[3] = track[3, my[6] - 1]
12
- my[0] = 1
13
- while my[0] <= the[0]:
14
- if connectionGraph[my[0], my[6], my[6]] == my[6]:
15
- my[1] += 1
16
- else:
17
- my[7] = connectionGraph[my[0], my[6], my[6]]
18
- while my[7] != my[6]:
19
- gapsWhere[my[3]] = my[7]
20
- if track[2, my[7]] == 0:
21
- my[3] += 1
22
- track[2, my[7]] += 1
23
- my[7] = connectionGraph[my[0], my[6], track[1, my[7]]]
24
- my[0] += 1
25
- my[5] = my[2]
26
- while my[5] < my[3]:
27
- gapsWhere[my[2]] = gapsWhere[my[5]]
28
- if track[2, gapsWhere[my[5]]] == the[0] - my[1]:
29
- my[2] += 1
30
- track[2, gapsWhere[my[5]]] = 0
31
- my[5] += 1
32
- while my[6] > 0 and my[2] == track[3, my[6] - 1]:
33
- my[6] -= 1
34
- track[1, track[0, my[6]]] = track[1, my[6]]
35
- track[0, track[1, my[6]]] = track[0, my[6]]
36
- if my[6] > 0:
37
- my[2] -= 1
38
- track[0, my[6]] = gapsWhere[my[2]]
39
- track[1, my[6]] = track[1, track[0, my[6]]]
40
- track[1, track[0, my[6]]] = my[6]
41
- track[0, track[1, my[6]]] = my[6]
42
- track[3, my[6]] = my[2]
43
- my[6] += 1
@@ -1,35 +0,0 @@
1
- mapFolding/__init__.py,sha256=wnf2EzHR2unVha6-Y0gRoSPaE4PDdT4VngINa_dfT2E,337
2
- mapFolding/babbage.py,sha256=51fO7lwcTsTvSMwzKW1G2nGslGoEQt19IgnqZi8znao,2222
3
- mapFolding/beDRY.py,sha256=XawGabR1vhzOfdA46HSXmisA5EmxisTKdA3D98KDeac,13699
4
- mapFolding/countInitialize.py,sha256=pIeH52OwDMfuHXT2T4BbPmMm6r7zJnGc-e0QVQCKyDc,1824
5
- mapFolding/countParallel.py,sha256=1sLGIlMj_xZ4bFkG1srOPcDUCrSKc1q3x2QN_8l_sgY,2451
6
- mapFolding/countSequential.py,sha256=QSXwK3o8YBcxNrir_wGMXgqp38hXYTJanYXFLxUPCPo,1993
7
- mapFolding/importSelector.py,sha256=OY_LuUrLW5SFV6qM1tSgI2Rnfi5Bj3Fhdrkryo0WycE,392
8
- mapFolding/inlineAfunction.py,sha256=KO2snTNSGX-4urRtTOYqAZBCsBCaMfr5bo6rNZR9MPA,5102
9
- mapFolding/lovelace.py,sha256=iu7anbA_TacIAjc4EKkeBVxIJKAMdrYgvR4evzMZ1WY,15193
10
- mapFolding/oeis.py,sha256=_-fLGc1ybZ2eFxoiBrSmojMexeg6ROxtrLaBF2BzMn4,12144
11
- mapFolding/startHere.py,sha256=or7QhxgMls2hvP_I2eTBP5tffLrc3SMiE5Gz_Ik2aJY,4328
12
- mapFolding/theSSOT.py,sha256=3Zty4rYWOqrwivuCaKA71R0HM4rjmvtkL_Bsn4ZhwFo,2318
13
- mapFolding/JAX/lunnanJAX.py,sha256=xMZloN47q-MVfjdYOM1hi9qR4OnLq7qALmGLMraevQs,14819
14
- mapFolding/JAX/taskJAX.py,sha256=yJNeH0rL6EhJ6ppnATHF0Zf81CDMC10bnPnimVxE1hc,20037
15
- mapFolding/benchmarks/benchmarking.py,sha256=HD_0NSvuabblg94ftDre6LFnXShTe8MYj3hIodW-zV0,3076
16
- mapFolding/reference/flattened.py,sha256=X9nvRzg7YDcpCtSDTL4YiidjshlX9rg2e6JVCY6i2u0,16547
17
- mapFolding/reference/hunterNumba.py,sha256=0giUyqAFzP-XKcq3Kz8wIWCK0BVFhjABVJ1s-w4Jhu0,7109
18
- mapFolding/reference/irvineJavaPort.py,sha256=Sj-63Z-OsGuDoEBXuxyjRrNmmyl0d7Yz_XuY7I47Oyg,4250
19
- mapFolding/reference/lunnan.py,sha256=XEcql_gxvCCghb6Or3qwmPbn4IZUbZTaSmw_fUjRxZE,5037
20
- mapFolding/reference/lunnanNumpy.py,sha256=HqDgSwTOZA-G0oophOEfc4zs25Mv4yw2aoF1v8miOLk,4653
21
- mapFolding/reference/lunnanWhile.py,sha256=7NY2IKO5XBgol0aWWF_Fi-7oTL9pvu_z6lB0TF1uVHk,4063
22
- mapFolding/reference/rotatedEntryPoint.py,sha256=z0QyDQtnMvXNj5ntWzzJUQUMFm1-xHGLVhtYzwmczUI,11530
23
- mapFolding/reference/total_countPlus1vsPlusN.py,sha256=usenM8Yn_G1dqlPl7NKKkcnbohBZVZBXTQRm2S3_EDA,8106
24
- tests/__init__.py,sha256=eg9smg-6VblOr0kisM40CpGnuDtU2JgEEWGDTFVOlW8,57
25
- tests/conftest.py,sha256=AWB3m_jxMlkmOmGvk2ApJEk2ro5v8gmmJDcyLwN1oow,13761
26
- tests/pythons_idiotic_namespace.py,sha256=oOLDBergQqqhGuRpsXUnFD-R_6AlJipNKYHw-kk_OKw,33
27
- tests/test_oeis.py,sha256=vxnwO-cSR68htkyMh9QMVv-lvxBo6qlwPg1Rbx4JylY,7963
28
- tests/test_other.py,sha256=amhsy7VWzpuW_slBOTFPhC7e4o4k6Yp4xweNK1VHZnc,11906
29
- tests/test_tasks.py,sha256=Nwe4iuSjwGZvsw5CXCcic7tkBxgM5JX9mrGZMDYhAwE,1785
30
- tests/test_temporary.py,sha256=4FIEc9KGRpNsgU_eh8mXG49PSPqo8WLeZEyFI4Dpy3U,1127
31
- mapFolding-0.2.4.dist-info/METADATA,sha256=w1OgxNLylmuYfEKUsSIChm_8jLtjV63_OB04n8Btjm8,6543
32
- mapFolding-0.2.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
33
- mapFolding-0.2.4.dist-info/entry_points.txt,sha256=F3OUeZR1XDTpoH7k3wXuRb3KF_kXTTeYhu5AGK1SiOQ,146
34
- mapFolding-0.2.4.dist-info/top_level.txt,sha256=1gP2vFaqPwHujGwb3UjtMlLEGN-943VSYFR7V4gDqW8,17
35
- mapFolding-0.2.4.dist-info/RECORD,,
tests/test_temporary.py DELETED
@@ -1,25 +0,0 @@
1
- from tests.conftest import *
2
- from typing import Dict, List, Tuple
3
- import importlib
4
- import pytest
5
-
6
- @pytest.fixture(scope="session", autouse=True)
7
- def runSecondSetAfterAll(request: pytest.FixtureRequest):
8
- """Run after all other tests complete."""
9
- def toggleAndRerun():
10
- import mapFolding.importSelector
11
- import mapFolding.babbage
12
- mapFolding.importSelector.useLovelace = not mapFolding.importSelector.useLovelace
13
- importlib.reload(mapFolding.importSelector)
14
- importlib.reload(mapFolding.babbage)
15
-
16
- request.addfinalizer(toggleAndRerun)
17
-
18
- @pytest.mark.order(after="runSecondSetAfterAll")
19
- def test_myabilitytodealwithbs(oeisID: str):
20
- for n in settingsOEIS[oeisID]['valuesTestValidation']:
21
- standardComparison(settingsOEIS[oeisID]['valuesKnown'][n], oeisIDfor_n, oeisID, n)
22
-
23
- @pytest.mark.order(after="runSecondSetAfterAll")
24
- def test_eff_em_el(listDimensionsTest_countFolds: List[int], foldsTotalKnown: Dict[Tuple[int, ...], int]) -> None:
25
- standardComparison(foldsTotalKnown[tuple(listDimensionsTest_countFolds)], countFolds, listDimensionsTest_countFolds, None, 'maximum')