mapFolding 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mapFolding/__init__.py +1 -1
- mapFolding/babbage.py +9 -4
- mapFolding/beDRY.py +47 -6
- mapFolding/benchmarks/benchmarking.py +3 -2
- mapFolding/importSelector.py +7 -0
- mapFolding/lovelace.py +92 -96
- mapFolding/{JAX/lunnanJAX.py → reference/jax.py} +2 -0
- mapFolding/someAssemblyRequired/inlineAfunction.py +152 -0
- mapFolding/someAssemblyRequired/jobsAndTasks.py +47 -0
- mapFolding/someAssemblyRequired/makeNuitkaSource.py +99 -0
- mapFolding/someAssemblyRequired/makeNumbaJob.py +121 -0
- mapFolding/startHere.py +8 -28
- mapFolding/theSSOT.py +13 -5
- {mapFolding-0.2.3.dist-info → mapFolding-0.2.5.dist-info}/METADATA +8 -6
- mapFolding-0.2.5.dist-info/RECORD +33 -0
- tests/conftest.py +8 -1
- tests/test_other.py +158 -88
- mapFolding/JAX/taskJAX.py +0 -313
- mapFolding/benchmarks/test_benchmarks.py +0 -74
- mapFolding-0.2.3.dist-info/RECORD +0 -30
- {mapFolding-0.2.3.dist-info → mapFolding-0.2.5.dist-info}/WHEEL +0 -0
- {mapFolding-0.2.3.dist-info → mapFolding-0.2.5.dist-info}/entry_points.txt +0 -0
- {mapFolding-0.2.3.dist-info → mapFolding-0.2.5.dist-info}/top_level.txt +0 -0
tests/test_other.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
|
-
|
|
2
|
-
from typing import List, Optional, Dict, Any, Union
|
|
1
|
+
import pathlib
|
|
3
2
|
from tests.conftest import *
|
|
4
3
|
from tests.pythons_idiotic_namespace import *
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
import itertools
|
|
6
|
+
import numba
|
|
7
|
+
import numpy
|
|
5
8
|
import pytest
|
|
9
|
+
import random
|
|
6
10
|
import sys
|
|
7
11
|
import unittest.mock
|
|
8
|
-
import numpy
|
|
9
|
-
import numba
|
|
10
12
|
|
|
11
13
|
@pytest.mark.parametrize("listDimensions,expected_intInnit,expected_parseListDimensions,expected_validateListDimensions,expected_getLeavesTotal", [
|
|
12
14
|
(None, ValueError, ValueError, ValueError, ValueError), # None instead of list
|
|
@@ -65,7 +67,7 @@ def test_getLeavesTotal_edge_cases() -> None:
|
|
|
65
67
|
])
|
|
66
68
|
def test_countFolds_writeFoldsTotal(
|
|
67
69
|
listDimensionsTestFunctionality: List[int],
|
|
68
|
-
pathTempTesting: Path,
|
|
70
|
+
pathTempTesting: pathlib.Path,
|
|
69
71
|
mockFoldingFunction,
|
|
70
72
|
foldsValue: int,
|
|
71
73
|
writeFoldsTarget: Optional[str]
|
|
@@ -82,7 +84,7 @@ def test_countFolds_writeFoldsTotal(
|
|
|
82
84
|
mock_countFolds = mockFoldingFunction(foldsValue, listDimensionsTestFunctionality)
|
|
83
85
|
|
|
84
86
|
with unittest.mock.patch("mapFolding.babbage._countFolds", side_effect=mock_countFolds):
|
|
85
|
-
returned = countFolds(listDimensionsTestFunctionality,
|
|
87
|
+
returned = countFolds(listDimensionsTestFunctionality, pathishWriteFoldsTotal=pathWriteTarget)
|
|
86
88
|
|
|
87
89
|
standardComparison(foldsValue, lambda: returned) # Check return value
|
|
88
90
|
standardComparison(str(foldsValue), lambda: (pathTempTesting / filenameFoldsTotalExpected).read_text()) # Check file content
|
|
@@ -97,18 +99,19 @@ def test_oopsieKwargsie() -> None:
|
|
|
97
99
|
for testName, testFunction in makeTestSuiteOopsieKwargsie(oopsieKwargsie).items():
|
|
98
100
|
testFunction()
|
|
99
101
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
102
|
+
@pytest.mark.parametrize("CPUlimit, expectedLimit", [
|
|
103
|
+
(None, numba.config.NUMBA_DEFAULT_NUM_THREADS), # type: ignore
|
|
104
|
+
(False, numba.config.NUMBA_DEFAULT_NUM_THREADS), # type: ignore
|
|
105
|
+
(True, 1),
|
|
106
|
+
(4, 4),
|
|
107
|
+
(0.5, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS // 2)), # type: ignore
|
|
108
|
+
(-0.5, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS // 2)), # type: ignore
|
|
109
|
+
(-2, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS - 2)), # type: ignore
|
|
110
|
+
(0, numba.config.NUMBA_DEFAULT_NUM_THREADS), # type: ignore
|
|
111
|
+
(1, 1),
|
|
112
|
+
])
|
|
113
|
+
def test_setCPUlimit(CPUlimit, expectedLimit) -> None:
|
|
114
|
+
standardComparison(expectedLimit, setCPUlimit, CPUlimit)
|
|
112
115
|
|
|
113
116
|
def test_makeConnectionGraph_nonNegative(listDimensionsTestFunctionality: List[int]) -> None:
|
|
114
117
|
connectionGraph = makeConnectionGraph(listDimensionsTestFunctionality)
|
|
@@ -119,80 +122,147 @@ def test_makeConnectionGraph_datatype(listDimensionsTestFunctionality: List[int]
|
|
|
119
122
|
connectionGraph = makeConnectionGraph(listDimensionsTestFunctionality, datatype=datatype)
|
|
120
123
|
assert connectionGraph.dtype == datatype, f"Expected datatype {datatype}, but got {connectionGraph.dtype}."
|
|
121
124
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
125
|
+
|
|
126
|
+
"""5 parameters
|
|
127
|
+
listDimensionsTestFunctionality
|
|
128
|
+
|
|
129
|
+
computationDivisions
|
|
130
|
+
None
|
|
131
|
+
random: int, first included: 2, first excluded: leavesTotal
|
|
132
|
+
maximum
|
|
133
|
+
cpu
|
|
134
|
+
|
|
135
|
+
CPUlimit
|
|
136
|
+
None
|
|
137
|
+
True
|
|
138
|
+
False
|
|
139
|
+
0
|
|
140
|
+
1
|
|
141
|
+
-1
|
|
142
|
+
random: 0 < float < 1
|
|
143
|
+
random: -1 < float < 0
|
|
144
|
+
random: int, first included: 2, first excluded: (min(leavesTotal, 16) - 1)
|
|
145
|
+
random: int, first included: -1 * (min(leavesTotal, 16) - 1), first excluded: -1
|
|
146
|
+
|
|
147
|
+
datatypeDefault
|
|
148
|
+
None
|
|
149
|
+
numpy.int64
|
|
150
|
+
numpy.intc
|
|
151
|
+
numpy.uint16
|
|
152
|
+
|
|
153
|
+
datatypeLarge
|
|
154
|
+
None
|
|
155
|
+
numpy.int64
|
|
156
|
+
numpy.intp
|
|
157
|
+
numpy.uint32
|
|
158
|
+
|
|
159
|
+
"""
|
|
160
|
+
|
|
161
|
+
@pytest.fixture
|
|
162
|
+
def parameterIterator():
|
|
163
|
+
"""Generate random combinations of parameters for outfitCountFolds testing."""
|
|
164
|
+
parameterSets = {
|
|
165
|
+
'computationDivisions': [
|
|
166
|
+
None,
|
|
167
|
+
'maximum',
|
|
168
|
+
'cpu',
|
|
169
|
+
],
|
|
170
|
+
'CPUlimit': [
|
|
171
|
+
None, True, False, 0, 1, -1,
|
|
172
|
+
],
|
|
173
|
+
'datatypeDefault': [
|
|
174
|
+
None,
|
|
175
|
+
numpy.int64,
|
|
176
|
+
numpy.intc,
|
|
177
|
+
numpy.uint16
|
|
178
|
+
],
|
|
179
|
+
'datatypeLarge': [
|
|
180
|
+
None,
|
|
181
|
+
numpy.int64,
|
|
182
|
+
numpy.intp,
|
|
183
|
+
numpy.uint32
|
|
184
|
+
]
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
def makeParametersDynamic(listDimensions):
|
|
188
|
+
"""Add context-dependent parameter values."""
|
|
189
|
+
parametersDynamic = parameterSets.copy()
|
|
190
|
+
leavesTotal = getLeavesTotal(listDimensions)
|
|
191
|
+
concurrencyLimit = min(leavesTotal, 16)
|
|
192
|
+
|
|
193
|
+
# Add dynamic computationDivisions
|
|
194
|
+
parametersDynamic['computationDivisions'].extend(
|
|
195
|
+
[random.randint(2, leavesTotal-1) for iterator in range(3)]
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
# Add dynamic CPUlimit values
|
|
199
|
+
parameterDynamicCPU = [
|
|
200
|
+
random.random(), # 0 to 1
|
|
201
|
+
-random.random(), # -1 to 0
|
|
202
|
+
]
|
|
203
|
+
parameterDynamicCPU.extend(
|
|
204
|
+
[random.randint(2, concurrencyLimit-1) for iterator in range(2)]
|
|
205
|
+
)
|
|
206
|
+
parameterDynamicCPU.extend(
|
|
207
|
+
[random.randint(-concurrencyLimit+1, -2) for iterator in range(2)]
|
|
208
|
+
)
|
|
209
|
+
parametersDynamic['CPUlimit'].extend(parameterDynamicCPU)
|
|
210
|
+
|
|
211
|
+
return parametersDynamic
|
|
212
|
+
|
|
213
|
+
def generateCombinations(listDimensions):
|
|
214
|
+
parametersDynamic = makeParametersDynamic(listDimensions)
|
|
215
|
+
parameterKeys = list(parametersDynamic.keys())
|
|
216
|
+
parameterValues = [parametersDynamic[key] for key in parameterKeys]
|
|
217
|
+
|
|
218
|
+
# Shuffle each parameter list
|
|
219
|
+
for valueList in parameterValues:
|
|
220
|
+
random.shuffle(valueList)
|
|
221
|
+
|
|
222
|
+
# Use zip_longest to iterate, filling with None when shorter lists are exhausted
|
|
223
|
+
for combination in itertools.zip_longest(*parameterValues, fillvalue=None):
|
|
224
|
+
yield dict(zip(parameterKeys, combination))
|
|
225
|
+
|
|
226
|
+
return generateCombinations
|
|
227
|
+
# Must mock the set cpu count to avoid errors on GitHub
|
|
228
|
+
# def test_outfitCountFolds_basic(listDimensionsTestFunctionality, parameterIterator):
|
|
229
|
+
# """Basic validation of outfitCountFolds return value structure."""
|
|
230
|
+
# parameters = next(parameterIterator(listDimensionsTestFunctionality))
|
|
231
|
+
|
|
143
232
|
# stateInitialized = outfitCountFolds(
|
|
144
233
|
# listDimensionsTestFunctionality,
|
|
145
|
-
#
|
|
146
|
-
# CPUlimit=CPUlimit,
|
|
147
|
-
# **datatypeOverrides
|
|
234
|
+
# **{k: v for k, v in parameters.items() if v is not None}
|
|
148
235
|
# )
|
|
149
236
|
|
|
150
|
-
# #
|
|
151
|
-
#
|
|
152
|
-
#
|
|
153
|
-
# assert stateInitialized[keyRequired] is not None, f"Key has None value: {keyRequired}"
|
|
237
|
+
# # Basic structure tests
|
|
238
|
+
# assert isinstance(stateInitialized, dict)
|
|
239
|
+
# assert len(stateInitialized) == 7 # 6 ndarray + 1 tuple
|
|
154
240
|
|
|
155
|
-
#
|
|
156
|
-
#
|
|
157
|
-
#
|
|
158
|
-
# f"Type mismatch for {keyRequired}: expected {expectedType}, got {type(stateInitialized[keyRequired])}"
|
|
241
|
+
# # Check for specific keys
|
|
242
|
+
# requiredKeys = set(computationState.__annotations__.keys())
|
|
243
|
+
# assert set(stateInitialized.keys()) == requiredKeys
|
|
159
244
|
|
|
160
|
-
# #
|
|
161
|
-
#
|
|
162
|
-
#
|
|
245
|
+
# # Check types more carefully
|
|
246
|
+
# for key, value in stateInitialized.items():
|
|
247
|
+
# if key == 'mapShape':
|
|
248
|
+
# assert isinstance(value, tuple)
|
|
249
|
+
# assert all(isinstance(dim, int) for dim in value)
|
|
250
|
+
# else:
|
|
251
|
+
# assert isinstance(value, numpy.ndarray), f"{key} should be ndarray but is {type(value)}"
|
|
252
|
+
# assert issubclass(value.dtype.type, numpy.integer), \
|
|
253
|
+
# f"{key} should have integer dtype but has {value.dtype}"
|
|
163
254
|
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
#
|
|
167
|
-
|
|
168
|
-
#
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
#
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
#
|
|
177
|
-
|
|
178
|
-
# f"Array {arrayName} too small for enum {indexEnum.__name__}"
|
|
179
|
-
|
|
180
|
-
# # Test each enum index
|
|
181
|
-
# for enumMember in indexEnum:
|
|
182
|
-
# assert array[enumMember.value] >= 0, \
|
|
183
|
-
# f"Negative value at {arrayName}[{enumMember.name}]"
|
|
184
|
-
|
|
185
|
-
# # 4. Special value checks
|
|
186
|
-
# assert stateInitialized['my'][indexMy.leaf1ndex.value] == 1, \
|
|
187
|
-
# "Initial leaf index should be 1"
|
|
188
|
-
|
|
189
|
-
# # 5. Shape consistency
|
|
190
|
-
# leavesTotal = getLeavesTotal(listDimensionsTestFunctionality)
|
|
191
|
-
# assert stateInitialized['foldsSubTotals'].shape == (leavesTotal,), \
|
|
192
|
-
# "foldsSubTotals shape mismatch"
|
|
193
|
-
# assert stateInitialized['gapsWhere'].shape == (leavesTotal * leavesTotal + 1,), \
|
|
194
|
-
# "gapsWhere shape mismatch"
|
|
195
|
-
# assert stateInitialized['track'].shape == (len(indexTrack), leavesTotal + 1), \
|
|
196
|
-
# "track shape mismatch"
|
|
197
|
-
|
|
198
|
-
# TODO test `outfitCountFolds`; no negative values in arrays; compare datatypes to the typeddict; compare the connection graph to making a graph
|
|
255
|
+
def test_pathJobDEFAULT_colab():
|
|
256
|
+
"""Test that pathJobDEFAULT is set correctly when running in Google Colab."""
|
|
257
|
+
# Mock sys.modules to simulate running in Colab
|
|
258
|
+
with unittest.mock.patch.dict('sys.modules', {'google.colab': unittest.mock.MagicMock()}):
|
|
259
|
+
# Force reload of theSSOT to trigger Colab path logic
|
|
260
|
+
import importlib
|
|
261
|
+
import mapFolding.theSSOT
|
|
262
|
+
importlib.reload(mapFolding.theSSOT)
|
|
263
|
+
|
|
264
|
+
# Check that path was set to Colab-specific value
|
|
265
|
+
assert mapFolding.theSSOT.pathJobDEFAULT == pathlib.Path("/content/drive/MyDrive") / "jobs"
|
|
266
|
+
|
|
267
|
+
# Reload one more time to restore original state
|
|
268
|
+
importlib.reload(mapFolding.theSSOT)
|
mapFolding/JAX/taskJAX.py
DELETED
|
@@ -1,313 +0,0 @@
|
|
|
1
|
-
from mapFolding import validateListDimensions, getLeavesTotal
|
|
2
|
-
from typing import List, Tuple
|
|
3
|
-
import jax
|
|
4
|
-
import jaxtyping
|
|
5
|
-
|
|
6
|
-
dtypeDefault = jax.numpy.int32
|
|
7
|
-
dtypeMaximum = jax.numpy.int32
|
|
8
|
-
|
|
9
|
-
def countFolds(listDimensions: List[int]):
|
|
10
|
-
"""Calculate foldings across multiple devices using pmap"""
|
|
11
|
-
p = validateListDimensions(listDimensions)
|
|
12
|
-
n = getLeavesTotal(p)
|
|
13
|
-
|
|
14
|
-
# Get number of devices (GPUs/TPUs)
|
|
15
|
-
deviceCount = jax.device_count()
|
|
16
|
-
|
|
17
|
-
if deviceCount > 1:
|
|
18
|
-
# Split work across devices
|
|
19
|
-
tasksPerDevice = (n + deviceCount - 1) // deviceCount
|
|
20
|
-
paddedTaskCount = tasksPerDevice * deviceCount
|
|
21
|
-
|
|
22
|
-
# Create padded array of task indices
|
|
23
|
-
arrayTaskIndices = jax.numpy.arange(paddedTaskCount, dtype=dtypeDefault)
|
|
24
|
-
arrayTaskIndices = arrayTaskIndices.reshape((deviceCount, tasksPerDevice))
|
|
25
|
-
|
|
26
|
-
# Create pmapped function
|
|
27
|
-
parallelFoldingsTask = jax.pmap(lambda x: jax.vmap(lambda y: foldingsTask(tuple(p), y))(x))
|
|
28
|
-
|
|
29
|
-
# Run computation across devices
|
|
30
|
-
arrayResults = parallelFoldingsTask(arrayTaskIndices)
|
|
31
|
-
|
|
32
|
-
# Sum valid results (ignore padding)
|
|
33
|
-
return jax.numpy.sum(arrayResults[:, :min(tasksPerDevice, n - tasksPerDevice * (deviceCount-1))])
|
|
34
|
-
else:
|
|
35
|
-
# Fall back to sequential execution if no multiple devices available
|
|
36
|
-
arrayTaskIndices = jax.numpy.arange(n, dtype=dtypeDefault)
|
|
37
|
-
batchedFoldingsTask = jax.vmap(lambda x: foldingsTask(tuple(p), x))
|
|
38
|
-
return jax.numpy.sum(batchedFoldingsTask(arrayTaskIndices))
|
|
39
|
-
|
|
40
|
-
def foldingsTask(p, taskIndex) -> jaxtyping.UInt32:
|
|
41
|
-
arrayDimensions = jax.numpy.asarray(p, dtype=dtypeDefault)
|
|
42
|
-
leavesTotal = jax.numpy.prod(arrayDimensions)
|
|
43
|
-
dimensionsTotal = jax.numpy.size(arrayDimensions)
|
|
44
|
-
|
|
45
|
-
"""How to build a leaf connection graph, also called a "Cartesian Product Decomposition"
|
|
46
|
-
or a "Dimensional Product Mapping", with sentinels:
|
|
47
|
-
Step 1: find the cumulative product of the map's dimensions"""
|
|
48
|
-
cumulativeProduct = jax.numpy.ones(dimensionsTotal + 1, dtype=dtypeDefault)
|
|
49
|
-
cumulativeProduct = cumulativeProduct.at[1:].set(jax.numpy.cumprod(arrayDimensions))
|
|
50
|
-
|
|
51
|
-
"""Step 2: for each dimension, create a coordinate system """
|
|
52
|
-
"""coordinateSystem[dimension1ndex][leaf1ndex] holds the dimension1ndex-th coordinate of leaf leaf1ndex"""
|
|
53
|
-
coordinateSystem = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1), dtype=dtypeDefault)
|
|
54
|
-
|
|
55
|
-
# Create mesh of indices for vectorized computation
|
|
56
|
-
dimension1ndices, leaf1ndices = jax.numpy.meshgrid(
|
|
57
|
-
jax.numpy.arange(1, dimensionsTotal + 1),
|
|
58
|
-
jax.numpy.arange(1, leavesTotal + 1),
|
|
59
|
-
indexing='ij'
|
|
60
|
-
)
|
|
61
|
-
|
|
62
|
-
# Compute all coordinates at once using broadcasting
|
|
63
|
-
coordinateSystem = coordinateSystem.at[1:, 1:].set(
|
|
64
|
-
((leaf1ndices - 1) // cumulativeProduct.at[dimension1ndices - 1].get()) %
|
|
65
|
-
arrayDimensions.at[dimension1ndices - 1].get() + 1
|
|
66
|
-
)
|
|
67
|
-
del dimension1ndices, leaf1ndices
|
|
68
|
-
|
|
69
|
-
"""Step 3: create a huge empty connection graph"""
|
|
70
|
-
connectionGraph = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1, leavesTotal + 1), dtype=dtypeDefault)
|
|
71
|
-
|
|
72
|
-
# Create 3D mesh of indices for vectorized computation
|
|
73
|
-
dimension1ndices, activeLeaf1ndices, connectee1ndices = jax.numpy.meshgrid(
|
|
74
|
-
jax.numpy.arange(1, dimensionsTotal + 1),
|
|
75
|
-
jax.numpy.arange(1, leavesTotal + 1),
|
|
76
|
-
jax.numpy.arange(1, leavesTotal + 1),
|
|
77
|
-
indexing='ij'
|
|
78
|
-
)
|
|
79
|
-
|
|
80
|
-
# Create masks for valid indices
|
|
81
|
-
maskActiveConnectee = connectee1ndices <= activeLeaf1ndices
|
|
82
|
-
|
|
83
|
-
# Calculate coordinate parity comparison
|
|
84
|
-
coordsParity = (coordinateSystem.at[dimension1ndices, activeLeaf1ndices].get() & 1) == \
|
|
85
|
-
(coordinateSystem.at[dimension1ndices, connectee1ndices].get() & 1)
|
|
86
|
-
|
|
87
|
-
# Compute distance conditions
|
|
88
|
-
isFirstCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == 1
|
|
89
|
-
isLastCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == \
|
|
90
|
-
arrayDimensions.at[dimension1ndices - 1].get()
|
|
91
|
-
exceedsActive = connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get() > activeLeaf1ndices
|
|
92
|
-
|
|
93
|
-
# Compute connection values for even and odd parities
|
|
94
|
-
evenParityValues = jax.numpy.where(
|
|
95
|
-
isFirstCoord,
|
|
96
|
-
connectee1ndices,
|
|
97
|
-
connectee1ndices - cumulativeProduct.at[dimension1ndices - 1].get()
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
oddParityValues = jax.numpy.where(
|
|
101
|
-
jax.numpy.logical_or(isLastCoord, exceedsActive),
|
|
102
|
-
connectee1ndices,
|
|
103
|
-
connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get()
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
# Combine based on parity and valid indices
|
|
107
|
-
connectionValues = jax.numpy.where(
|
|
108
|
-
coordsParity,
|
|
109
|
-
evenParityValues,
|
|
110
|
-
oddParityValues
|
|
111
|
-
)
|
|
112
|
-
|
|
113
|
-
# Update only valid connections
|
|
114
|
-
connectionGraph = connectionGraph.at[dimension1ndices, activeLeaf1ndices, connectee1ndices].set(
|
|
115
|
-
jax.numpy.where(maskActiveConnectee, connectionValues, 0)
|
|
116
|
-
)
|
|
117
|
-
|
|
118
|
-
def doNothing(argument):
|
|
119
|
-
return argument
|
|
120
|
-
|
|
121
|
-
def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
|
|
122
|
-
comparand = comparisonValues[6]
|
|
123
|
-
return comparand > 0
|
|
124
|
-
|
|
125
|
-
def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
126
|
-
_0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues
|
|
127
|
-
|
|
128
|
-
sentinel = leafBelow.at[0].get().astype(jax.numpy.int32)
|
|
129
|
-
|
|
130
|
-
allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
|
|
131
|
-
lambda argumentX: dao(findGapsDo(argumentX)),
|
|
132
|
-
lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
|
|
133
|
-
allValues)
|
|
134
|
-
|
|
135
|
-
return allValues
|
|
136
|
-
|
|
137
|
-
def findGapsCondition(leafBelowSentinel, activeLeafNumber):
|
|
138
|
-
return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)
|
|
139
|
-
|
|
140
|
-
def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
141
|
-
def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
|
|
142
|
-
return comparisonValues[-1] <= dimensionsTotal
|
|
143
|
-
|
|
144
|
-
def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
145
|
-
def ifLeafIsUnconstrainedCondition(comparand):
|
|
146
|
-
return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)
|
|
147
|
-
|
|
148
|
-
def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
149
|
-
unconstrained_unconstrainedLeaf = unconstrainedValues[3]
|
|
150
|
-
unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
|
|
151
|
-
return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)
|
|
152
|
-
|
|
153
|
-
def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
154
|
-
def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
155
|
-
return comparisonValues[-1] != activeLeaf1ndex
|
|
156
|
-
|
|
157
|
-
def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
158
|
-
# if taskDivisions == False or activeLeaf1ndex != leavesTotal or leaf1ndexConnectee % leavesTotal == taskIndex:
|
|
159
|
-
def taskDivisionComparison():
|
|
160
|
-
return jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
|
|
161
|
-
# return taskDivisions == False or jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
|
|
162
|
-
|
|
163
|
-
def taskDivisionDo(taskDivisionDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
|
|
164
|
-
taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound = taskDivisionDoValues
|
|
165
|
-
|
|
166
|
-
taskDivisionPotentialGaps = taskDivisionPotentialGaps.at[taskDivisionGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
|
|
167
|
-
taskDivisionGap1ndexLowerBound = jax.numpy.where(
|
|
168
|
-
jax.numpy.equal(taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].get(), 0), taskDivisionGap1ndexLowerBound + 1, taskDivisionGap1ndexLowerBound)
|
|
169
|
-
taskDivisionCountDimensionsGapped = taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
|
|
170
|
-
|
|
171
|
-
return (taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound)
|
|
172
|
-
|
|
173
|
-
countGapsLeaf1ndexConnectee = countGapsDoValues[3]
|
|
174
|
-
taskDivisionValues = (countGapsDoValues[0], countGapsDoValues[1], countGapsDoValues[2])
|
|
175
|
-
taskDivisionValues = jax.lax.cond(taskDivisionComparison(), taskDivisionDo, doNothing, taskDivisionValues)
|
|
176
|
-
|
|
177
|
-
countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.int32)
|
|
178
|
-
|
|
179
|
-
return (taskDivisionValues[0], taskDivisionValues[1], taskDivisionValues[2], countGapsLeaf1ndexConnectee)
|
|
180
|
-
|
|
181
|
-
unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues
|
|
182
|
-
|
|
183
|
-
leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.int32)
|
|
184
|
-
|
|
185
|
-
countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
|
|
186
|
-
countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
|
|
187
|
-
unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues
|
|
188
|
-
|
|
189
|
-
return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)
|
|
190
|
-
|
|
191
|
-
dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
|
|
192
|
-
|
|
193
|
-
ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
|
|
194
|
-
ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
|
|
195
|
-
dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues
|
|
196
|
-
|
|
197
|
-
dimensionNumber = 1 + dimensionNumber
|
|
198
|
-
return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)
|
|
199
|
-
|
|
200
|
-
def almostUselessCondition(comparand):
|
|
201
|
-
return comparand == dimensionsTotal
|
|
202
|
-
|
|
203
|
-
def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
204
|
-
def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
|
|
205
|
-
return comparisonValues[-1] < activeLeaf1ndex
|
|
206
|
-
|
|
207
|
-
def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
208
|
-
leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
|
|
209
|
-
leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
|
|
210
|
-
gapNumberLowerBound = 1 + gapNumberLowerBound
|
|
211
|
-
leafNumber = 1 + leafNumber
|
|
212
|
-
return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
|
|
213
|
-
return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)
|
|
214
|
-
|
|
215
|
-
def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
216
|
-
return comparisonValues[-1] < gap1ndexCeiling
|
|
217
|
-
|
|
218
|
-
def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
219
|
-
gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
|
|
220
|
-
gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
|
|
221
|
-
activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.int32)
|
|
222
|
-
gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
|
|
223
|
-
index = 1 + index
|
|
224
|
-
return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)
|
|
225
|
-
|
|
226
|
-
_0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
|
|
227
|
-
|
|
228
|
-
unconstrainedLeaf = jax.numpy.int32(0)
|
|
229
|
-
dimension1ndex = jax.numpy.int32(1)
|
|
230
|
-
gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.int32)
|
|
231
|
-
activeGap1ndex = gap1ndexCeiling
|
|
232
|
-
for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
|
|
233
|
-
for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
|
|
234
|
-
countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
|
|
235
|
-
del dimension1ndex
|
|
236
|
-
|
|
237
|
-
leaf1ndex = jax.numpy.int32(0)
|
|
238
|
-
for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
|
|
239
|
-
for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
|
|
240
|
-
gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
|
|
241
|
-
del leaf1ndex
|
|
242
|
-
|
|
243
|
-
indexMiniGap = activeGap1ndex
|
|
244
|
-
miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
|
|
245
|
-
miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
|
|
246
|
-
countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
|
|
247
|
-
del indexMiniGap
|
|
248
|
-
|
|
249
|
-
return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
|
|
250
|
-
|
|
251
|
-
def incrementCondition(leafBelowSentinel, activeLeafNumber):
|
|
252
|
-
return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)
|
|
253
|
-
|
|
254
|
-
def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
255
|
-
foldingsSubTotal = allValues[5]
|
|
256
|
-
foldingsSubTotal = leavesTotal + foldingsSubTotal
|
|
257
|
-
return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])
|
|
258
|
-
|
|
259
|
-
def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
260
|
-
def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
|
|
261
|
-
comparand = backtrackingValues[2]
|
|
262
|
-
return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))
|
|
263
|
-
|
|
264
|
-
def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
|
|
265
|
-
backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues
|
|
266
|
-
|
|
267
|
-
activeLeafNumber = activeLeafNumber - 1
|
|
268
|
-
backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
|
|
269
|
-
backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())
|
|
270
|
-
|
|
271
|
-
return (backtrackAbove, backtrackBelow, activeLeafNumber)
|
|
272
|
-
|
|
273
|
-
def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
|
|
274
|
-
return activeLeafNumber > 0
|
|
275
|
-
|
|
276
|
-
def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
|
|
277
|
-
placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
|
|
278
|
-
activeGapNumber = activeGapNumber - 1
|
|
279
|
-
placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
|
|
280
|
-
placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
|
|
281
|
-
placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
|
|
282
|
-
placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
|
|
283
|
-
placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)
|
|
284
|
-
|
|
285
|
-
activeLeafNumber = 1 + activeLeafNumber
|
|
286
|
-
return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)
|
|
287
|
-
|
|
288
|
-
leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
|
|
289
|
-
|
|
290
|
-
whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
|
|
291
|
-
whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
|
|
292
|
-
leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues
|
|
293
|
-
|
|
294
|
-
if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
|
|
295
|
-
if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
|
|
296
|
-
leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values
|
|
297
|
-
|
|
298
|
-
return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
|
|
299
|
-
|
|
300
|
-
# Dynamic values
|
|
301
|
-
A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
|
|
302
|
-
B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
|
|
303
|
-
count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
|
|
304
|
-
gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
|
|
305
|
-
gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)
|
|
306
|
-
|
|
307
|
-
foldingsSubTotal = jax.numpy.int32(0)
|
|
308
|
-
l = jax.numpy.int32(1)
|
|
309
|
-
g = jax.numpy.int32(0)
|
|
310
|
-
|
|
311
|
-
foldingsValues = (A, B, count, gapter, gap, foldingsSubTotal, l, g)
|
|
312
|
-
foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
|
|
313
|
-
return foldingsValues[5]
|