mapFolding 0.3.11__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. mapFolding/__init__.py +44 -32
  2. mapFolding/basecamp.py +50 -50
  3. mapFolding/beDRY.py +336 -336
  4. mapFolding/oeis.py +262 -262
  5. mapFolding/reference/flattened.py +294 -293
  6. mapFolding/reference/hunterNumba.py +126 -126
  7. mapFolding/reference/irvineJavaPort.py +99 -99
  8. mapFolding/reference/jax.py +153 -153
  9. mapFolding/reference/lunnan.py +148 -148
  10. mapFolding/reference/lunnanNumpy.py +115 -115
  11. mapFolding/reference/lunnanWhile.py +114 -114
  12. mapFolding/reference/rotatedEntryPoint.py +183 -183
  13. mapFolding/reference/total_countPlus1vsPlusN.py +203 -203
  14. mapFolding/someAssemblyRequired/__init__.py +2 -1
  15. mapFolding/someAssemblyRequired/getLLVMforNoReason.py +12 -12
  16. mapFolding/someAssemblyRequired/makeJob.py +48 -48
  17. mapFolding/someAssemblyRequired/synthesizeModuleJAX.py +17 -17
  18. mapFolding/someAssemblyRequired/synthesizeNumba.py +345 -803
  19. mapFolding/someAssemblyRequired/synthesizeNumbaGeneralized.py +371 -0
  20. mapFolding/someAssemblyRequired/synthesizeNumbaJob.py +150 -0
  21. mapFolding/someAssemblyRequired/synthesizeNumbaModules.py +75 -0
  22. mapFolding/syntheticModules/__init__.py +0 -0
  23. mapFolding/syntheticModules/numba_countInitialize.py +2 -2
  24. mapFolding/syntheticModules/numba_countParallel.py +3 -3
  25. mapFolding/syntheticModules/numba_countSequential.py +28 -28
  26. mapFolding/syntheticModules/numba_doTheNeedful.py +6 -6
  27. mapFolding/theDao.py +168 -169
  28. mapFolding/theSSOT.py +190 -162
  29. mapFolding/theSSOTnumba.py +91 -75
  30. mapFolding-0.4.0.dist-info/METADATA +122 -0
  31. mapFolding-0.4.0.dist-info/RECORD +41 -0
  32. tests/conftest.py +238 -128
  33. tests/test_oeis.py +80 -80
  34. tests/test_other.py +137 -224
  35. tests/test_tasks.py +21 -21
  36. tests/test_types.py +2 -2
  37. mapFolding-0.3.11.dist-info/METADATA +0 -155
  38. mapFolding-0.3.11.dist-info/RECORD +0 -39
  39. tests/conftest_tmpRegistry.py +0 -62
  40. tests/conftest_uniformTests.py +0 -53
  41. {mapFolding-0.3.11.dist-info → mapFolding-0.4.0.dist-info}/LICENSE +0 -0
  42. {mapFolding-0.3.11.dist-info → mapFolding-0.4.0.dist-info}/WHEEL +0 -0
  43. {mapFolding-0.3.11.dist-info → mapFolding-0.4.0.dist-info}/entry_points.txt +0 -0
  44. {mapFolding-0.3.11.dist-info → mapFolding-0.4.0.dist-info}/top_level.txt +0 -0
@@ -9,200 +9,200 @@ dtypeMedium = jax.numpy.uint32
9
9
  dtypeMaximum = jax.numpy.uint32
10
10
 
11
11
  def countFolds(listDimensions: List[int]) -> int:
12
- listDimensionsPositive: List[int] = validateListDimensions(listDimensions)
12
+ listDimensionsPositive: List[int] = validateListDimensions(listDimensions)
13
13
 
14
- n: int = getLeavesTotal(listDimensionsPositive)
15
- d: int = len(listDimensions)
16
- import numpy
17
- D: numpy.ndarray = makeConnectionGraph(listDimensionsPositive)
18
- connectionGraph = jax.numpy.asarray(D, dtype=dtypeMedium)
19
- del listDimensionsPositive
14
+ n: int = getLeavesTotal(listDimensionsPositive)
15
+ d: int = len(listDimensions)
16
+ import numpy
17
+ D: numpy.ndarray = makeConnectionGraph(listDimensionsPositive)
18
+ connectionGraph = jax.numpy.asarray(D, dtype=dtypeMedium)
19
+ del listDimensionsPositive
20
20
 
21
- return foldingsJAX(n, d, connectionGraph)
21
+ return foldingsJAX(n, d, connectionGraph)
22
22
 
23
23
  def foldingsJAX(leavesTotal: jaxtyping.UInt32, dimensionsTotal: jaxtyping.UInt32, connectionGraph: jaxtyping.Array) -> jaxtyping.UInt32:
24
24
 
25
- def doNothing(argument):
26
- return argument
25
+ def doNothing(argument):
26
+ return argument
27
27
 
28
- def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
29
- comparand = comparisonValues[6]
30
- return comparand > 0
28
+ def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
29
+ comparand = comparisonValues[6]
30
+ return comparand > 0
31
31
 
32
- def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
33
- _0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues
32
+ def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
33
+ _0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues
34
34
 
35
- sentinel = leafBelow.at[0].get().astype(jax.numpy.uint32)
35
+ sentinel = leafBelow.at[0].get().astype(jax.numpy.uint32)
36
36
 
37
- allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
38
- lambda argumentX: dao(findGapsDo(argumentX)),
39
- lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
40
- allValues)
37
+ allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
38
+ lambda argumentX: dao(findGapsDo(argumentX)),
39
+ lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
40
+ allValues)
41
41
 
42
- return allValues
42
+ return allValues
43
43
 
44
- def findGapsCondition(leafBelowSentinel, activeLeafNumber):
45
- return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)
44
+ def findGapsCondition(leafBelowSentinel, activeLeafNumber):
45
+ return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)
46
46
 
47
- def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
48
- def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
49
- return comparisonValues[-1] <= dimensionsTotal
47
+ def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
48
+ def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
49
+ return comparisonValues[-1] <= dimensionsTotal
50
50
 
51
- def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
52
- def ifLeafIsUnconstrainedCondition(comparand):
53
- return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)
51
+ def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
52
+ def ifLeafIsUnconstrainedCondition(comparand):
53
+ return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)
54
54
 
55
- def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
56
- unconstrained_unconstrainedLeaf = unconstrainedValues[3]
57
- unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
58
- return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)
55
+ def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
56
+ unconstrained_unconstrainedLeaf = unconstrainedValues[3]
57
+ unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
58
+ return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)
59
59
 
60
- def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
61
- def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
62
- return comparisonValues[-1] != activeLeaf1ndex
60
+ def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
61
+ def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
62
+ return comparisonValues[-1] != activeLeaf1ndex
63
63
 
64
- def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
65
- countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee = countGapsDoValues
64
+ def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
65
+ countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee = countGapsDoValues
66
66
 
67
- countGapsPotentialGaps = countGapsPotentialGaps.at[countGapsGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
68
- countGapsGap1ndexLowerBound = jax.numpy.where(jax.numpy.equal(countGapsCountDimensionsGapped[countGapsLeaf1ndexConnectee], 0), countGapsGap1ndexLowerBound + 1, countGapsGap1ndexLowerBound)
69
- countGapsCountDimensionsGapped = countGapsCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
70
- countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.uint32)
67
+ countGapsPotentialGaps = countGapsPotentialGaps.at[countGapsGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
68
+ countGapsGap1ndexLowerBound = jax.numpy.where(jax.numpy.equal(countGapsCountDimensionsGapped[countGapsLeaf1ndexConnectee], 0), countGapsGap1ndexLowerBound + 1, countGapsGap1ndexLowerBound)
69
+ countGapsCountDimensionsGapped = countGapsCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
70
+ countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.uint32)
71
71
 
72
- return (countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee)
72
+ return (countGapsCountDimensionsGapped, countGapsPotentialGaps, countGapsGap1ndexLowerBound, countGapsLeaf1ndexConnectee)
73
73
 
74
- unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues
74
+ unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues
75
75
 
76
- leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.uint32)
76
+ leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.uint32)
77
77
 
78
- countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
79
- countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
80
- unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues
78
+ countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
79
+ countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
80
+ unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues
81
81
 
82
- return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)
82
+ return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)
83
83
 
84
- dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
84
+ dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
85
85
 
86
- ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
87
- ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
88
- dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues
86
+ ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
87
+ ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
88
+ dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues
89
89
 
90
- dimensionNumber = 1 + dimensionNumber
91
- return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)
90
+ dimensionNumber = 1 + dimensionNumber
91
+ return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)
92
92
 
93
- def almostUselessCondition(comparand):
94
- return comparand == dimensionsTotal
93
+ def almostUselessCondition(comparand):
94
+ return comparand == dimensionsTotal
95
95
 
96
- def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
97
- def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
98
- return comparisonValues[-1] < activeLeaf1ndex
96
+ def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
97
+ def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
98
+ return comparisonValues[-1] < activeLeaf1ndex
99
99
 
100
- def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
101
- leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
102
- leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
103
- gapNumberLowerBound = 1 + gapNumberLowerBound
104
- leafNumber = 1 + leafNumber
105
- return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
106
- return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)
100
+ def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
101
+ leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
102
+ leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
103
+ gapNumberLowerBound = 1 + gapNumberLowerBound
104
+ leafNumber = 1 + leafNumber
105
+ return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
106
+ return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)
107
107
 
108
- def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
109
- return comparisonValues[-1] < gap1ndexCeiling
108
+ def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
109
+ return comparisonValues[-1] < gap1ndexCeiling
110
110
 
111
- def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
112
- gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
113
- gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
114
- activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.uint32)
115
- gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
116
- index = 1 + index
117
- return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)
111
+ def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
112
+ gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
113
+ gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
114
+ activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.uint32)
115
+ gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
116
+ index = 1 + index
117
+ return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)
118
118
 
119
- _0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
119
+ _0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
120
120
 
121
- unconstrainedLeaf = jax.numpy.uint32(0)
122
- dimension1ndex = jax.numpy.uint32(1)
123
- gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.uint32)
124
- activeGap1ndex = gap1ndexCeiling
125
- for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
126
- for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
127
- countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
128
- del dimension1ndex
121
+ unconstrainedLeaf = jax.numpy.uint32(0)
122
+ dimension1ndex = jax.numpy.uint32(1)
123
+ gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.uint32)
124
+ activeGap1ndex = gap1ndexCeiling
125
+ for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
126
+ for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
127
+ countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
128
+ del dimension1ndex
129
129
 
130
- leaf1ndex = jax.numpy.uint32(0)
131
- for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
132
- for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
133
- gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
134
- del leaf1ndex
130
+ leaf1ndex = jax.numpy.uint32(0)
131
+ for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
132
+ for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
133
+ gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
134
+ del leaf1ndex
135
135
 
136
- indexMiniGap = activeGap1ndex
137
- miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
138
- miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
139
- countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
140
- del indexMiniGap
136
+ indexMiniGap = activeGap1ndex
137
+ miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
138
+ miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
139
+ countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
140
+ del indexMiniGap
141
141
 
142
- return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
142
+ return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
143
143
 
144
- def incrementCondition(leafBelowSentinel, activeLeafNumber):
145
- return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)
144
+ def incrementCondition(leafBelowSentinel, activeLeafNumber):
145
+ return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)
146
146
 
147
- def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
148
- foldingsSubTotal = allValues[5]
149
- foldingsSubTotal = leavesTotal + foldingsSubTotal
150
- return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])
151
-
152
- def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
153
- def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
154
- comparand = backtrackingValues[2]
155
- return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))
147
+ def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
148
+ foldingsSubTotal = allValues[5]
149
+ foldingsSubTotal = leavesTotal + foldingsSubTotal
150
+ return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])
151
+
152
+ def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
153
+ def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
154
+ comparand = backtrackingValues[2]
155
+ return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))
156
156
 
157
- def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
158
- backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues
157
+ def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
158
+ backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues
159
159
 
160
- activeLeafNumber = activeLeafNumber - 1
161
- backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
162
- backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())
160
+ activeLeafNumber = activeLeafNumber - 1
161
+ backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
162
+ backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())
163
163
 
164
- return (backtrackAbove, backtrackBelow, activeLeafNumber)
164
+ return (backtrackAbove, backtrackBelow, activeLeafNumber)
165
165
 
166
- def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
167
- return activeLeafNumber > 0
166
+ def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
167
+ return activeLeafNumber > 0
168
168
 
169
- def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
170
- placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
171
- activeGapNumber = activeGapNumber - 1
172
- placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
173
- placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
174
- placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
175
- placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
176
- placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)
177
-
178
- activeLeafNumber = 1 + activeLeafNumber
179
- return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)
180
-
181
- leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
182
-
183
- whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
184
- whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
185
- leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues
186
-
187
- if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
188
- if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
189
- leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values
190
-
191
- return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
192
-
193
- # Dynamic values
194
- A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
195
- B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
196
- count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
197
- gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
198
- gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)
199
-
200
- foldingsTotal = jax.numpy.uint32(0)
201
- l = jax.numpy.uint32(1)
202
- g = jax.numpy.uint32(0)
203
-
204
- foldingsValues = (A, B, count, gapter, gap, foldingsTotal, l, g)
205
- foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
206
- return foldingsValues[5]
169
+ def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
170
+ placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
171
+ activeGapNumber = activeGapNumber - 1
172
+ placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
173
+ placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
174
+ placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
175
+ placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
176
+ placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)
177
+
178
+ activeLeafNumber = 1 + activeLeafNumber
179
+ return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)
180
+
181
+ leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
182
+
183
+ whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
184
+ whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
185
+ leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues
186
+
187
+ if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
188
+ if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
189
+ leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values
190
+
191
+ return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
192
+
193
+ # Dynamic values
194
+ A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
195
+ B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
196
+ count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
197
+ gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeMedium)
198
+ gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)
199
+
200
+ foldingsTotal = jax.numpy.uint32(0)
201
+ l = jax.numpy.uint32(1)
202
+ g = jax.numpy.uint32(0)
203
+
204
+ foldingsValues = (A, B, count, gapter, gap, foldingsTotal, l, g)
205
+ foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
206
+ return foldingsValues[5]
207
207
 
208
208
  foldingsJAX = jax.jit(foldingsJAX, static_argnums=(0, 1))