mapFolding 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_other.py CHANGED
@@ -1,12 +1,14 @@
1
- from pathlib import Path
2
- from typing import List, Optional, Dict, Any, Union
1
+ import pathlib
3
2
  from tests.conftest import *
4
3
  from tests.pythons_idiotic_namespace import *
4
+ from typing import List, Optional
5
+ import itertools
6
+ import numba
7
+ import numpy
5
8
  import pytest
9
+ import random
6
10
  import sys
7
11
  import unittest.mock
8
- import numpy
9
- import numba
10
12
 
11
13
  @pytest.mark.parametrize("listDimensions,expected_intInnit,expected_parseListDimensions,expected_validateListDimensions,expected_getLeavesTotal", [
12
14
  (None, ValueError, ValueError, ValueError, ValueError), # None instead of list
@@ -65,7 +67,7 @@ def test_getLeavesTotal_edge_cases() -> None:
65
67
  ])
66
68
  def test_countFolds_writeFoldsTotal(
67
69
  listDimensionsTestFunctionality: List[int],
68
- pathTempTesting: Path,
70
+ pathTempTesting: pathlib.Path,
69
71
  mockFoldingFunction,
70
72
  foldsValue: int,
71
73
  writeFoldsTarget: Optional[str]
@@ -82,7 +84,7 @@ def test_countFolds_writeFoldsTotal(
82
84
  mock_countFolds = mockFoldingFunction(foldsValue, listDimensionsTestFunctionality)
83
85
 
84
86
  with unittest.mock.patch("mapFolding.babbage._countFolds", side_effect=mock_countFolds):
85
- returned = countFolds(listDimensionsTestFunctionality, writeFoldsTotal=pathWriteTarget)
87
+ returned = countFolds(listDimensionsTestFunctionality, pathishWriteFoldsTotal=pathWriteTarget)
86
88
 
87
89
  standardComparison(foldsValue, lambda: returned) # Check return value
88
90
  standardComparison(str(foldsValue), lambda: (pathTempTesting / filenameFoldsTotalExpected).read_text()) # Check file content
@@ -97,18 +99,19 @@ def test_oopsieKwargsie() -> None:
97
99
  for testName, testFunction in makeTestSuiteOopsieKwargsie(oopsieKwargsie).items():
98
100
  testFunction()
99
101
 
100
- # TODO mock CPU counts?
101
- # @pytest.mark.parametrize("CPUlimit, expectedLimit", [
102
- # (None, numba.config.NUMBA_DEFAULT_NUM_THREADS),
103
- # (False, numba.config.NUMBA_DEFAULT_NUM_THREADS),
104
- # (True, 1),
105
- # (4, 4),
106
- # (0.5, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS // 2)),
107
- # (-0.5, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS // 2)),
108
- # (-2, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS - 2)),
109
- # ])
110
- # def test_setCPUlimit(CPUlimit, expectedLimit) -> None:
111
- # standardComparison(expectedLimit, setCPUlimit, CPUlimit)
102
+ @pytest.mark.parametrize("CPUlimit, expectedLimit", [
103
+ (None, numba.config.NUMBA_DEFAULT_NUM_THREADS), # type: ignore
104
+ (False, numba.config.NUMBA_DEFAULT_NUM_THREADS), # type: ignore
105
+ (True, 1),
106
+ (4, 4),
107
+ (0.5, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS // 2)), # type: ignore
108
+ (-0.5, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS // 2)), # type: ignore
109
+ (-2, max(1, numba.config.NUMBA_DEFAULT_NUM_THREADS - 2)), # type: ignore
110
+ (0, numba.config.NUMBA_DEFAULT_NUM_THREADS), # type: ignore
111
+ (1, 1),
112
+ ])
113
+ def test_setCPUlimit(CPUlimit, expectedLimit) -> None:
114
+ standardComparison(expectedLimit, setCPUlimit, CPUlimit)
112
115
 
113
116
  def test_makeConnectionGraph_nonNegative(listDimensionsTestFunctionality: List[int]) -> None:
114
117
  connectionGraph = makeConnectionGraph(listDimensionsTestFunctionality)
@@ -119,80 +122,147 @@ def test_makeConnectionGraph_datatype(listDimensionsTestFunctionality: List[int]
119
122
  connectionGraph = makeConnectionGraph(listDimensionsTestFunctionality, datatype=datatype)
120
123
  assert connectionGraph.dtype == datatype, f"Expected datatype {datatype}, but got {connectionGraph.dtype}."
121
124
 
122
- # @pytest.mark.parametrize("computationDivisions,CPUlimit,datatypeOverrides", [
123
- # (None, None, {}), # Basic case
124
- # ("maximum", True, {"datatypeDefault": numpy.int32}), # Max divisions, min CPU, custom dtype
125
- # ("cpu", 4, {"datatypeLarge": numpy.int64}), # CPU-based divisions, fixed CPU limit
126
- # (3, 0.5, {}), # Fixed divisions, fractional CPU
127
- # ])
128
- # def test_outfitCountFolds(
129
- # listDimensionsTestFunctionality: List[int],
130
- # computationDivisions: Optional[Union[int, str]],
131
- # CPUlimit: Optional[Union[bool, float, int]],
132
- # datatypeOverrides: Dict[str, Any]
133
- # ) -> None:
134
- # """Test outfitCountFolds as a nexus of configuration and initialization.
135
-
136
- # Strategy:
137
- # 1. Validate structure against computationState TypedDict
138
- # 2. Compare with direct function calls
139
- # 3. Verify enum-based indexing
140
- # 4. Check datatypes and shapes
141
- # """
142
- # # Get initialized state
125
+
126
+ """5 parameters
127
+ listDimensionsTestFunctionality
128
+
129
+ computationDivisions
130
+ None
131
+ random: int, first included: 2, first excluded: leavesTotal
132
+ maximum
133
+ cpu
134
+
135
+ CPUlimit
136
+ None
137
+ True
138
+ False
139
+ 0
140
+ 1
141
+ -1
142
+ random: 0 < float < 1
143
+ random: -1 < float < 0
144
+ random: int, first included: 2, first excluded: (min(leavesTotal, 16) - 1)
145
+ random: int, first included: -1 * (min(leavesTotal, 16) - 1), first excluded: -1
146
+
147
+ datatypeDefault
148
+ None
149
+ numpy.int64
150
+ numpy.intc
151
+ numpy.uint16
152
+
153
+ datatypeLarge
154
+ None
155
+ numpy.int64
156
+ numpy.intp
157
+ numpy.uint32
158
+
159
+ """
160
+
161
+ @pytest.fixture
162
+ def parameterIterator():
163
+ """Generate random combinations of parameters for outfitCountFolds testing."""
164
+ parameterSets = {
165
+ 'computationDivisions': [
166
+ None,
167
+ 'maximum',
168
+ 'cpu',
169
+ ],
170
+ 'CPUlimit': [
171
+ None, True, False, 0, 1, -1,
172
+ ],
173
+ 'datatypeDefault': [
174
+ None,
175
+ numpy.int64,
176
+ numpy.intc,
177
+ numpy.uint16
178
+ ],
179
+ 'datatypeLarge': [
180
+ None,
181
+ numpy.int64,
182
+ numpy.intp,
183
+ numpy.uint32
184
+ ]
185
+ }
186
+
187
+ def makeParametersDynamic(listDimensions):
188
+ """Add context-dependent parameter values."""
189
+ parametersDynamic = parameterSets.copy()
190
+ leavesTotal = getLeavesTotal(listDimensions)
191
+ concurrencyLimit = min(leavesTotal, 16)
192
+
193
+ # Add dynamic computationDivisions
194
+ parametersDynamic['computationDivisions'].extend(
195
+ [random.randint(2, leavesTotal-1) for iterator in range(3)]
196
+ )
197
+
198
+ # Add dynamic CPUlimit values
199
+ parameterDynamicCPU = [
200
+ random.random(), # 0 to 1
201
+ -random.random(), # -1 to 0
202
+ ]
203
+ parameterDynamicCPU.extend(
204
+ [random.randint(2, concurrencyLimit-1) for iterator in range(2)]
205
+ )
206
+ parameterDynamicCPU.extend(
207
+ [random.randint(-concurrencyLimit+1, -2) for iterator in range(2)]
208
+ )
209
+ parametersDynamic['CPUlimit'].extend(parameterDynamicCPU)
210
+
211
+ return parametersDynamic
212
+
213
+ def generateCombinations(listDimensions):
214
+ parametersDynamic = makeParametersDynamic(listDimensions)
215
+ parameterKeys = list(parametersDynamic.keys())
216
+ parameterValues = [parametersDynamic[key] for key in parameterKeys]
217
+
218
+ # Shuffle each parameter list
219
+ for valueList in parameterValues:
220
+ random.shuffle(valueList)
221
+
222
+ # Use zip_longest to iterate, filling with None when shorter lists are exhausted
223
+ for combination in itertools.zip_longest(*parameterValues, fillvalue=None):
224
+ yield dict(zip(parameterKeys, combination))
225
+
226
+ return generateCombinations
227
+ # Must mock the set cpu count to avoid errors on GitHub
228
+ # def test_outfitCountFolds_basic(listDimensionsTestFunctionality, parameterIterator):
229
+ # """Basic validation of outfitCountFolds return value structure."""
230
+ # parameters = next(parameterIterator(listDimensionsTestFunctionality))
231
+
143
232
  # stateInitialized = outfitCountFolds(
144
233
  # listDimensionsTestFunctionality,
145
- # computationDivisions=computationDivisions,
146
- # CPUlimit=CPUlimit,
147
- # **datatypeOverrides
234
+ # **{k: v for k, v in parameters.items() if v is not None}
148
235
  # )
149
236
 
150
- # # 1. TypedDict structure validation
151
- # for keyRequired in computationState.__annotations__:
152
- # assert keyRequired in stateInitialized, f"Missing required key: {keyRequired}"
153
- # assert stateInitialized[keyRequired] is not None, f"Key has None value: {keyRequired}"
237
+ # # Basic structure tests
238
+ # assert isinstance(stateInitialized, dict)
239
+ # assert len(stateInitialized) == 7 # 6 ndarray + 1 tuple
154
240
 
155
- # # Type checking
156
- # expectedType = computationState.__annotations__[keyRequired]
157
- # assert isinstance(stateInitialized[keyRequired], expectedType), \
158
- # f"Type mismatch for {keyRequired}: expected {expectedType}, got {type(stateInitialized[keyRequired])}"
241
+ # # Check for specific keys
242
+ # requiredKeys = set(computationState.__annotations__.keys())
243
+ # assert set(stateInitialized.keys()) == requiredKeys
159
244
 
160
- # # 2. Compare with direct function calls
161
- # directMapShape = tuple(sorted(validateListDimensions(listDimensionsTestFunctionality)))
162
- # assert stateInitialized['mapShape'] == directMapShape
245
+ # # Check types more carefully
246
+ # for key, value in stateInitialized.items():
247
+ # if key == 'mapShape':
248
+ # assert isinstance(value, tuple)
249
+ # assert all(isinstance(dim, int) for dim in value)
250
+ # else:
251
+ # assert isinstance(value, numpy.ndarray), f"{key} should be ndarray but is {type(value)}"
252
+ # assert issubclass(value.dtype.type, numpy.integer), \
253
+ # f"{key} should have integer dtype but has {value.dtype}"
163
254
 
164
- # directConnectionGraph = makeConnectionGraph(
165
- # directMapShape,
166
- # datatype=datatypeOverrides.get('datatypeDefault', dtypeDefault)
167
- # )
168
- # assert numpy.array_equal(stateInitialized['connectionGraph'], directConnectionGraph)
169
-
170
- # # 3. Enum-based indexing validation
171
- # for arrayName, indexEnum in [
172
- # ('my', indexMy),
173
- # ('the', indexThe),
174
- # ('track', indexTrack)
175
- # ]:
176
- # array = stateInitialized[arrayName]
177
- # assert array.shape[0] >= len(indexEnum), \
178
- # f"Array {arrayName} too small for enum {indexEnum.__name__}"
179
-
180
- # # Test each enum index
181
- # for enumMember in indexEnum:
182
- # assert array[enumMember.value] >= 0, \
183
- # f"Negative value at {arrayName}[{enumMember.name}]"
184
-
185
- # # 4. Special value checks
186
- # assert stateInitialized['my'][indexMy.leaf1ndex.value] == 1, \
187
- # "Initial leaf index should be 1"
188
-
189
- # # 5. Shape consistency
190
- # leavesTotal = getLeavesTotal(listDimensionsTestFunctionality)
191
- # assert stateInitialized['foldsSubTotals'].shape == (leavesTotal,), \
192
- # "foldsSubTotals shape mismatch"
193
- # assert stateInitialized['gapsWhere'].shape == (leavesTotal * leavesTotal + 1,), \
194
- # "gapsWhere shape mismatch"
195
- # assert stateInitialized['track'].shape == (len(indexTrack), leavesTotal + 1), \
196
- # "track shape mismatch"
197
-
198
- # TODO test `outfitCountFolds`; no negative values in arrays; compare datatypes to the typeddict; compare the connection graph to making a graph
255
+ def test_pathJobDEFAULT_colab():
256
+ """Test that pathJobDEFAULT is set correctly when running in Google Colab."""
257
+ # Mock sys.modules to simulate running in Colab
258
+ with unittest.mock.patch.dict('sys.modules', {'google.colab': unittest.mock.MagicMock()}):
259
+ # Force reload of theSSOT to trigger Colab path logic
260
+ import importlib
261
+ import mapFolding.theSSOT
262
+ importlib.reload(mapFolding.theSSOT)
263
+
264
+ # Check that path was set to Colab-specific value
265
+ assert mapFolding.theSSOT.pathJobDEFAULT == pathlib.Path("/content/drive/MyDrive") / "jobs"
266
+
267
+ # Reload one more time to restore original state
268
+ importlib.reload(mapFolding.theSSOT)
mapFolding/JAX/taskJAX.py DELETED
@@ -1,313 +0,0 @@
1
- from mapFolding import validateListDimensions, getLeavesTotal
2
- from typing import List, Tuple
3
- import jax
4
- import jaxtyping
5
-
6
- dtypeDefault = jax.numpy.int32
7
- dtypeMaximum = jax.numpy.int32
8
-
9
- def countFolds(listDimensions: List[int]):
10
- """Calculate foldings across multiple devices using pmap"""
11
- p = validateListDimensions(listDimensions)
12
- n = getLeavesTotal(p)
13
-
14
- # Get number of devices (GPUs/TPUs)
15
- deviceCount = jax.device_count()
16
-
17
- if deviceCount > 1:
18
- # Split work across devices
19
- tasksPerDevice = (n + deviceCount - 1) // deviceCount
20
- paddedTaskCount = tasksPerDevice * deviceCount
21
-
22
- # Create padded array of task indices
23
- arrayTaskIndices = jax.numpy.arange(paddedTaskCount, dtype=dtypeDefault)
24
- arrayTaskIndices = arrayTaskIndices.reshape((deviceCount, tasksPerDevice))
25
-
26
- # Create pmapped function
27
- parallelFoldingsTask = jax.pmap(lambda x: jax.vmap(lambda y: foldingsTask(tuple(p), y))(x))
28
-
29
- # Run computation across devices
30
- arrayResults = parallelFoldingsTask(arrayTaskIndices)
31
-
32
- # Sum valid results (ignore padding)
33
- return jax.numpy.sum(arrayResults[:, :min(tasksPerDevice, n - tasksPerDevice * (deviceCount-1))])
34
- else:
35
- # Fall back to sequential execution if no multiple devices available
36
- arrayTaskIndices = jax.numpy.arange(n, dtype=dtypeDefault)
37
- batchedFoldingsTask = jax.vmap(lambda x: foldingsTask(tuple(p), x))
38
- return jax.numpy.sum(batchedFoldingsTask(arrayTaskIndices))
39
-
40
- def foldingsTask(p, taskIndex) -> jaxtyping.UInt32:
41
- arrayDimensions = jax.numpy.asarray(p, dtype=dtypeDefault)
42
- leavesTotal = jax.numpy.prod(arrayDimensions)
43
- dimensionsTotal = jax.numpy.size(arrayDimensions)
44
-
45
- """How to build a leaf connection graph, also called a "Cartesian Product Decomposition"
46
- or a "Dimensional Product Mapping", with sentinels:
47
- Step 1: find the cumulative product of the map's dimensions"""
48
- cumulativeProduct = jax.numpy.ones(dimensionsTotal + 1, dtype=dtypeDefault)
49
- cumulativeProduct = cumulativeProduct.at[1:].set(jax.numpy.cumprod(arrayDimensions))
50
-
51
- """Step 2: for each dimension, create a coordinate system """
52
- """coordinateSystem[dimension1ndex][leaf1ndex] holds the dimension1ndex-th coordinate of leaf leaf1ndex"""
53
- coordinateSystem = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1), dtype=dtypeDefault)
54
-
55
- # Create mesh of indices for vectorized computation
56
- dimension1ndices, leaf1ndices = jax.numpy.meshgrid(
57
- jax.numpy.arange(1, dimensionsTotal + 1),
58
- jax.numpy.arange(1, leavesTotal + 1),
59
- indexing='ij'
60
- )
61
-
62
- # Compute all coordinates at once using broadcasting
63
- coordinateSystem = coordinateSystem.at[1:, 1:].set(
64
- ((leaf1ndices - 1) // cumulativeProduct.at[dimension1ndices - 1].get()) %
65
- arrayDimensions.at[dimension1ndices - 1].get() + 1
66
- )
67
- del dimension1ndices, leaf1ndices
68
-
69
- """Step 3: create a huge empty connection graph"""
70
- connectionGraph = jax.numpy.zeros((dimensionsTotal + 1, leavesTotal + 1, leavesTotal + 1), dtype=dtypeDefault)
71
-
72
- # Create 3D mesh of indices for vectorized computation
73
- dimension1ndices, activeLeaf1ndices, connectee1ndices = jax.numpy.meshgrid(
74
- jax.numpy.arange(1, dimensionsTotal + 1),
75
- jax.numpy.arange(1, leavesTotal + 1),
76
- jax.numpy.arange(1, leavesTotal + 1),
77
- indexing='ij'
78
- )
79
-
80
- # Create masks for valid indices
81
- maskActiveConnectee = connectee1ndices <= activeLeaf1ndices
82
-
83
- # Calculate coordinate parity comparison
84
- coordsParity = (coordinateSystem.at[dimension1ndices, activeLeaf1ndices].get() & 1) == \
85
- (coordinateSystem.at[dimension1ndices, connectee1ndices].get() & 1)
86
-
87
- # Compute distance conditions
88
- isFirstCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == 1
89
- isLastCoord = coordinateSystem.at[dimension1ndices, connectee1ndices].get() == \
90
- arrayDimensions.at[dimension1ndices - 1].get()
91
- exceedsActive = connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get() > activeLeaf1ndices
92
-
93
- # Compute connection values for even and odd parities
94
- evenParityValues = jax.numpy.where(
95
- isFirstCoord,
96
- connectee1ndices,
97
- connectee1ndices - cumulativeProduct.at[dimension1ndices - 1].get()
98
- )
99
-
100
- oddParityValues = jax.numpy.where(
101
- jax.numpy.logical_or(isLastCoord, exceedsActive),
102
- connectee1ndices,
103
- connectee1ndices + cumulativeProduct.at[dimension1ndices - 1].get()
104
- )
105
-
106
- # Combine based on parity and valid indices
107
- connectionValues = jax.numpy.where(
108
- coordsParity,
109
- evenParityValues,
110
- oddParityValues
111
- )
112
-
113
- # Update only valid connections
114
- connectionGraph = connectionGraph.at[dimension1ndices, activeLeaf1ndices, connectee1ndices].set(
115
- jax.numpy.where(maskActiveConnectee, connectionValues, 0)
116
- )
117
-
118
- def doNothing(argument):
119
- return argument
120
-
121
- def while_activeLeaf1ndex_greaterThan_0(comparisonValues: Tuple):
122
- comparand = comparisonValues[6]
123
- return comparand > 0
124
-
125
- def countFoldings(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
126
- _0, leafBelow, _2, _3, _4, _5, activeLeaf1ndex, _7 = allValues
127
-
128
- sentinel = leafBelow.at[0].get().astype(jax.numpy.int32)
129
-
130
- allValues = jax.lax.cond(findGapsCondition(sentinel, activeLeaf1ndex),
131
- lambda argumentX: dao(findGapsDo(argumentX)),
132
- lambda argumentY: jax.lax.cond(incrementCondition(sentinel, activeLeaf1ndex), lambda argumentZ: dao(incrementDo(argumentZ)), dao, argumentY),
133
- allValues)
134
-
135
- return allValues
136
-
137
- def findGapsCondition(leafBelowSentinel, activeLeafNumber):
138
- return jax.numpy.logical_or(jax.numpy.logical_and(leafBelowSentinel == 1, activeLeafNumber <= leavesTotal), activeLeafNumber <= 1)
139
-
140
- def findGapsDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
141
- def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1(comparisonValues: Tuple):
142
- return comparisonValues[-1] <= dimensionsTotal
143
-
144
- def for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
145
- def ifLeafIsUnconstrainedCondition(comparand):
146
- return jax.numpy.equal(connectionGraph[comparand, activeLeaf1ndex, activeLeaf1ndex], activeLeaf1ndex)
147
-
148
- def ifLeafIsUnconstrainedDo(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
149
- unconstrained_unconstrainedLeaf = unconstrainedValues[3]
150
- unconstrained_unconstrainedLeaf = 1 + unconstrained_unconstrainedLeaf
151
- return (unconstrainedValues[0], unconstrainedValues[1], unconstrainedValues[2], unconstrained_unconstrainedLeaf)
152
-
153
- def ifLeafIsUnconstrainedElse(unconstrainedValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
154
- def while_leaf1ndexConnectee_notEquals_activeLeaf1ndex(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
155
- return comparisonValues[-1] != activeLeaf1ndex
156
-
157
- def countGaps(countGapsDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
158
- # if taskDivisions == False or activeLeaf1ndex != leavesTotal or leaf1ndexConnectee % leavesTotal == taskIndex:
159
- def taskDivisionComparison():
160
- return jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
161
- # return taskDivisions == False or jax.numpy.logical_or(activeLeaf1ndex != leavesTotal, jax.numpy.equal(countGapsLeaf1ndexConnectee % leavesTotal, taskIndex))
162
-
163
- def taskDivisionDo(taskDivisionDoValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
164
- taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound = taskDivisionDoValues
165
-
166
- taskDivisionPotentialGaps = taskDivisionPotentialGaps.at[taskDivisionGap1ndexLowerBound].set(countGapsLeaf1ndexConnectee)
167
- taskDivisionGap1ndexLowerBound = jax.numpy.where(
168
- jax.numpy.equal(taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].get(), 0), taskDivisionGap1ndexLowerBound + 1, taskDivisionGap1ndexLowerBound)
169
- taskDivisionCountDimensionsGapped = taskDivisionCountDimensionsGapped.at[countGapsLeaf1ndexConnectee].add(1)
170
-
171
- return (taskDivisionCountDimensionsGapped, taskDivisionPotentialGaps, taskDivisionGap1ndexLowerBound)
172
-
173
- countGapsLeaf1ndexConnectee = countGapsDoValues[3]
174
- taskDivisionValues = (countGapsDoValues[0], countGapsDoValues[1], countGapsDoValues[2])
175
- taskDivisionValues = jax.lax.cond(taskDivisionComparison(), taskDivisionDo, doNothing, taskDivisionValues)
176
-
177
- countGapsLeaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, leafBelow.at[countGapsLeaf1ndexConnectee].get()].get().astype(jax.numpy.int32)
178
-
179
- return (taskDivisionValues[0], taskDivisionValues[1], taskDivisionValues[2], countGapsLeaf1ndexConnectee)
180
-
181
- unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf = unconstrainedValues
182
-
183
- leaf1ndexConnectee = connectionGraph.at[dimensionNumber, activeLeaf1ndex, activeLeaf1ndex].get().astype(jax.numpy.int32)
184
-
185
- countGapsValues = (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee)
186
- countGapsValues = jax.lax.while_loop(while_leaf1ndexConnectee_notEquals_activeLeaf1ndex, countGaps, countGapsValues)
187
- unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, leaf1ndexConnectee = countGapsValues
188
-
189
- return (unconstrained_countDimensionsGapped, unconstrained_gapsWhere, unconstrained_gap1ndexCeiling, unconstrained_unconstrainedLeaf)
190
-
191
- dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
192
-
193
- ifLeafIsUnconstrainedValues = (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf)
194
- ifLeafIsUnconstrainedValues = jax.lax.cond(ifLeafIsUnconstrainedCondition(dimensionNumber), ifLeafIsUnconstrainedDo, ifLeafIsUnconstrainedElse, ifLeafIsUnconstrainedValues)
195
- dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf = ifLeafIsUnconstrainedValues
196
-
197
- dimensionNumber = 1 + dimensionNumber
198
- return (dimensions_countDimensionsGapped, dimensions_gapsWhere, dimensions_gap1ndexCeiling, dimensions_unconstrainedLeaf, dimensionNumber)
199
-
200
- def almostUselessCondition(comparand):
201
- return comparand == dimensionsTotal
202
-
203
- def almostUselessConditionDo(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
204
- def for_leaf1ndex_in_range_activeLeaf1ndex(comparisonValues):
205
- return comparisonValues[-1] < activeLeaf1ndex
206
-
207
- def for_leaf1ndex_in_range_activeLeaf1ndex_do(for_leaf1ndex_in_range_activeLeaf1ndexValues: Tuple[jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
208
- leafInRangePotentialGaps, gapNumberLowerBound, leafNumber = for_leaf1ndex_in_range_activeLeaf1ndexValues
209
- leafInRangePotentialGaps = leafInRangePotentialGaps.at[gapNumberLowerBound].set(leafNumber)
210
- gapNumberLowerBound = 1 + gapNumberLowerBound
211
- leafNumber = 1 + leafNumber
212
- return (leafInRangePotentialGaps, gapNumberLowerBound, leafNumber)
213
- return jax.lax.while_loop(for_leaf1ndex_in_range_activeLeaf1ndex, for_leaf1ndex_in_range_activeLeaf1ndex_do, for_leaf1ndex_in_range_activeLeaf1ndexValues)
214
-
215
- def for_range_from_activeGap1ndex_to_gap1ndexCeiling(comparisonValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
216
- return comparisonValues[-1] < gap1ndexCeiling
217
-
218
- def miniGapDo(gapToGapValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
219
- gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index = gapToGapValues
220
- gapToGapPotentialGaps = gapToGapPotentialGaps.at[activeGapNumber].set(gapToGapPotentialGaps.at[index].get())
221
- activeGapNumber = jax.numpy.where(jax.numpy.equal(gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].get(), dimensionsTotal - unconstrainedLeaf), activeGapNumber + 1, activeGapNumber).astype(jax.numpy.int32)
222
- gapToGapCountDimensionsGapped = gapToGapCountDimensionsGapped.at[gapToGapPotentialGaps.at[index].get()].set(0)
223
- index = 1 + index
224
- return (gapToGapCountDimensionsGapped, gapToGapPotentialGaps, activeGapNumber, index)
225
-
226
- _0, leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
227
-
228
- unconstrainedLeaf = jax.numpy.int32(0)
229
- dimension1ndex = jax.numpy.int32(1)
230
- gap1ndexCeiling = gapRangeStart.at[activeLeaf1ndex - 1].get().astype(jax.numpy.int32)
231
- activeGap1ndex = gap1ndexCeiling
232
- for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = (countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex)
233
- for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values = jax.lax.while_loop(for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1_do, for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values)
234
- countDimensionsGapped, gapsWhere, gap1ndexCeiling, unconstrainedLeaf, dimension1ndex = for_dimension1ndex_in_range_1_to_dimensionsTotalPlus1Values
235
- del dimension1ndex
236
-
237
- leaf1ndex = jax.numpy.int32(0)
238
- for_leaf1ndex_in_range_activeLeaf1ndexValues = (gapsWhere, gap1ndexCeiling, leaf1ndex)
239
- for_leaf1ndex_in_range_activeLeaf1ndexValues = jax.lax.cond(almostUselessCondition(unconstrainedLeaf), almostUselessConditionDo, doNothing, for_leaf1ndex_in_range_activeLeaf1ndexValues)
240
- gapsWhere, gap1ndexCeiling, leaf1ndex = for_leaf1ndex_in_range_activeLeaf1ndexValues
241
- del leaf1ndex
242
-
243
- indexMiniGap = activeGap1ndex
244
- miniGapValues = (countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap)
245
- miniGapValues = jax.lax.while_loop(for_range_from_activeGap1ndex_to_gap1ndexCeiling, miniGapDo, miniGapValues)
246
- countDimensionsGapped, gapsWhere, activeGap1ndex, indexMiniGap = miniGapValues
247
- del indexMiniGap
248
-
249
- return (allValues[0], leafBelow, countDimensionsGapped, gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
250
-
251
- def incrementCondition(leafBelowSentinel, activeLeafNumber):
252
- return jax.numpy.logical_and(activeLeafNumber > leavesTotal, leafBelowSentinel == 1)
253
-
254
- def incrementDo(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
255
- foldingsSubTotal = allValues[5]
256
- foldingsSubTotal = leavesTotal + foldingsSubTotal
257
- return (allValues[0], allValues[1], allValues[2], allValues[3], allValues[4], foldingsSubTotal, allValues[6], allValues[7])
258
-
259
- def dao(allValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32, jaxtyping.UInt32]):
260
- def whileBacktrackingCondition(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
261
- comparand = backtrackingValues[2]
262
- return jax.numpy.logical_and(comparand > 0, jax.numpy.equal(activeGap1ndex, gapRangeStart.at[comparand - 1].get()))
263
-
264
- def whileBacktrackingDo(backtrackingValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32]):
265
- backtrackAbove, backtrackBelow, activeLeafNumber = backtrackingValues
266
-
267
- activeLeafNumber = activeLeafNumber - 1
268
- backtrackBelow = backtrackBelow.at[backtrackAbove.at[activeLeafNumber].get()].set(backtrackBelow.at[activeLeafNumber].get())
269
- backtrackAbove = backtrackAbove.at[backtrackBelow.at[activeLeafNumber].get()].set(backtrackAbove.at[activeLeafNumber].get())
270
-
271
- return (backtrackAbove, backtrackBelow, activeLeafNumber)
272
-
273
- def if_activeLeaf1ndex_greaterThan_0(activeLeafNumber):
274
- return activeLeafNumber > 0
275
-
276
- def if_activeLeaf1ndex_greaterThan_0_do(leafPlacementValues: Tuple[jaxtyping.Array, jaxtyping.Array, jaxtyping.Array, jaxtyping.UInt32, jaxtyping.UInt32]):
277
- placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber = leafPlacementValues
278
- activeGapNumber = activeGapNumber - 1
279
- placeLeafAbove = placeLeafAbove.at[activeLeafNumber].set(gapsWhere.at[activeGapNumber].get())
280
- placeLeafBelow = placeLeafBelow.at[activeLeafNumber].set(placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].get())
281
- placeLeafBelow = placeLeafBelow.at[placeLeafAbove.at[activeLeafNumber].get()].set(activeLeafNumber)
282
- placeLeafAbove = placeLeafAbove.at[placeLeafBelow.at[activeLeafNumber].get()].set(activeLeafNumber)
283
- placeGapRangeStart = placeGapRangeStart.at[activeLeafNumber].set(activeGapNumber)
284
-
285
- activeLeafNumber = 1 + activeLeafNumber
286
- return (placeLeafAbove, placeLeafBelow, placeGapRangeStart, activeLeafNumber, activeGapNumber)
287
-
288
- leafAbove, leafBelow, _2, gapRangeStart, gapsWhere, _5, activeLeaf1ndex, activeGap1ndex = allValues
289
-
290
- whileBacktrackingValues = (leafAbove, leafBelow, activeLeaf1ndex)
291
- whileBacktrackingValues = jax.lax.while_loop(whileBacktrackingCondition, whileBacktrackingDo, whileBacktrackingValues)
292
- leafAbove, leafBelow, activeLeaf1ndex = whileBacktrackingValues
293
-
294
- if_activeLeaf1ndex_greaterThan_0_values = (leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex)
295
- if_activeLeaf1ndex_greaterThan_0_values = jax.lax.cond(if_activeLeaf1ndex_greaterThan_0(activeLeaf1ndex), if_activeLeaf1ndex_greaterThan_0_do, doNothing, if_activeLeaf1ndex_greaterThan_0_values)
296
- leafAbove, leafBelow, gapRangeStart, activeLeaf1ndex, activeGap1ndex = if_activeLeaf1ndex_greaterThan_0_values
297
-
298
- return (leafAbove, leafBelow, allValues[2], gapRangeStart, gapsWhere, allValues[5], activeLeaf1ndex, activeGap1ndex)
299
-
300
- # Dynamic values
301
- A = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
302
- B = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
303
- count = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
304
- gapter = jax.numpy.zeros(leavesTotal + 1, dtype=dtypeDefault)
305
- gap = jax.numpy.zeros(leavesTotal * leavesTotal + 1, dtype=dtypeMaximum)
306
-
307
- foldingsSubTotal = jax.numpy.int32(0)
308
- l = jax.numpy.int32(1)
309
- g = jax.numpy.int32(0)
310
-
311
- foldingsValues = (A, B, count, gapter, gap, foldingsSubTotal, l, g)
312
- foldingsValues = jax.lax.while_loop(while_activeLeaf1ndex_greaterThan_0, countFoldings, foldingsValues)
313
- return foldingsValues[5]