pytest-balance 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytest_balance/__init__.py +3 -0
- pytest_balance/__main__.py +5 -0
- pytest_balance/algorithms/__init__.py +0 -0
- pytest_balance/algorithms/lpt.py +68 -0
- pytest_balance/algorithms/partitioner.py +62 -0
- pytest_balance/ci/__init__.py +0 -0
- pytest_balance/ci/detect.py +209 -0
- pytest_balance/ci/splitter.py +46 -0
- pytest_balance/cli.py +367 -0
- pytest_balance/plugin.py +255 -0
- pytest_balance/py.typed +0 -0
- pytest_balance/report.py +64 -0
- pytest_balance/store/__init__.py +0 -0
- pytest_balance/store/merger.py +53 -0
- pytest_balance/store/models.py +41 -0
- pytest_balance/store/reader.py +94 -0
- pytest_balance/store/writer.py +53 -0
- pytest_balance/xdist/__init__.py +0 -0
- pytest_balance/xdist/hooks.py +37 -0
- pytest_balance/xdist/scheduler.py +154 -0
- pytest_balance-0.1.0.dist-info/METADATA +358 -0
- pytest_balance-0.1.0.dist-info/RECORD +25 -0
- pytest_balance-0.1.0.dist-info/WHEEL +4 -0
- pytest_balance-0.1.0.dist-info/entry_points.txt +5 -0
- pytest_balance-0.1.0.dist-info/licenses/LICENSE +21 -0
|
File without changes
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Longest Processing Time First (LPT) partitioning algorithm."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import heapq
|
|
6
|
+
|
|
7
|
+
from pytest_balance.algorithms.partitioner import Scope, group_by_scope
|
|
8
|
+
from pytest_balance.store.models import DurationEstimate
|
|
9
|
+
from pytest_balance.store.reader import default_estimate
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def partition(durations: dict[str, float], n: int) -> list[list[str]]:
    """Split item IDs into ``n`` buckets with the LPT (longest-first) heuristic.

    Items are visited from longest to shortest duration, with ties broken by
    item ID so the outcome is reproducible. Each item goes to the bucket whose
    running total is currently smallest; among equally loaded buckets the
    lowest bucket index wins.

    Args:
        durations: Mapping of item ID to estimated duration.
        n: Number of buckets; must be at least 1.

    Returns:
        Exactly ``n`` lists of item IDs, one per bucket (some may be empty).

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be >= 1, got {n}")

    result: list[list[str]] = [[] for _ in range(n)]
    if not durations:
        return result

    ordered = sorted(durations, key=lambda key: (-durations[key], key))

    # Heap entries are (running_load, bucket_index); the index doubles as the
    # tie-breaker when two buckets carry the same load.
    loads: list[tuple[float, int]] = [(0.0, slot) for slot in range(n)]
    heapq.heapify(loads)

    for key in ordered:
        load, slot = heapq.heappop(loads)
        result[slot].append(key)
        heapq.heappush(loads, (load + durations[key], slot))

    return result
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def compute_order(
    collection: list[str],
    estimates: dict[str, DurationEstimate],
    scope: Scope,
) -> list[int]:
    """Return indices of ``collection`` ordered LPT scope-adjacent.

    Tests are grouped by ``scope`` via ``group_by_scope``. Each group's
    estimated duration is the sum of its tests' ``.estimate`` values, using
    ``default_estimate(estimates)`` for tests without an entry. Groups are
    sorted by descending estimated duration with a lexicographic tie-break on
    ``scope_id``, and the tests of one group are emitted consecutively.

    The result is always a permutation of ``range(len(collection))``, even
    when the same test ID occurs more than once in ``collection``.
    Pure function: deterministic, no side effects on the inputs.

    Args:
        collection: Test node IDs in collection order.
        estimates: Known duration estimates keyed by test node ID.
        scope: Granularity at which tests are kept together.

    Returns:
        Indices into ``collection`` in the order the tests should run.
    """
    if not collection:
        return []

    fallback = default_estimate(estimates)
    groups = group_by_scope(collection, scope)
    for group in groups:
        group.estimated_duration = sum(
            estimates.get(tid, fallback).estimate for tid in group.test_ids
        )
    groups.sort(key=lambda g: (-g.estimated_duration, g.scope_id))

    # Record every position of each ID. A plain {tid: idx} map keeps only the
    # last one, so a duplicated ID would produce a repeated index and lose the
    # earlier position. Consuming positions in order keeps a true permutation.
    positions: dict[str, list[int]] = {}
    for idx, tid in enumerate(collection):
        positions.setdefault(tid, []).append(idx)
    cursors = {tid: iter(idxs) for tid, idxs in positions.items()}

    return [next(cursors[tid]) for group in groups for tid in group.test_ids]
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""Scope-aware test grouping for balanced partitioning."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import enum
|
|
6
|
+
from collections import OrderedDict
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Scope(enum.Enum):
    """Granularity at which tests are kept together when balancing.

    The key for each granularity is computed by ``extract_scope``; member
    values are the lowercase scope names.
    """

    # Each node ID is its own unit.
    TEST = "test"
    # Tests sharing a "<file>::<Class>" prefix are kept together.
    CLASS = "class"
    # Tests from the same file are kept together.
    MODULE = "module"
    # Tests sharing the same "@<group>" suffix are kept together.
    GROUP = "group"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class TestGroup:
    """A set of test node IDs that are scheduled as one unit.

    Attributes:
        scope_id: Key (from ``extract_scope``) shared by every member.
        test_ids: Member node IDs, in the order they were added.
        estimated_duration: Total estimated duration of the members;
            0.0 until a caller fills it in.
    """

    # The class name starts with "Test" and the class has an __init__, so any
    # test module that imports it would make pytest try to collect it and emit
    # a PytestCollectionWarning. Opt out of collection explicitly.
    # (No annotation, so the dataclass does not treat this as a field.)
    __test__ = False

    scope_id: str
    test_ids: list[str] = field(default_factory=list)
    estimated_duration: float = 0.0
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def extract_scope(test_id: str, scope: Scope) -> str:
    """Extract the scope key from a pytest node ID.

    Args:
        test_id: A pytest node ID such as
            ``"tests/test_a.py::TestC::test_x[p]"``, optionally ending in
            ``"@<group>"``.
        scope: Granularity of the key to return.

    Returns:
        * ``Scope.TEST``: the node ID itself.
        * ``Scope.GROUP``: the text after the last ``"@"`` when that ``"@"``
          comes after the last ``"]"``. An ``"@"`` inside parametrize brackets
          is therefore ignored. Otherwise the node ID itself.
        * ``Scope.MODULE``: the file part (everything before the first
          ``"::"``).
        * ``Scope.CLASS``: ``"<file>::<Class>"`` when the test sits inside a
          class; otherwise the node ID itself.
        * Any other value: the node ID itself.
    """
    if scope == Scope.TEST:
        return test_id

    if scope == Scope.GROUP:
        bracket_pos = test_id.rfind("]")
        at_pos = test_id.rfind("@")
        if at_pos > bracket_pos:
            return test_id[at_pos + 1 :]
        return test_id

    file_part, _, rest = test_id.partition("::")

    if scope == Scope.MODULE:
        return file_part

    if scope == Scope.CLASS:
        # Strip the parametrization suffix before splitting. Parameter IDs may
        # themselves contain "::" (e.g. "test_x[a::b]"), which would otherwise
        # be mistaken for a class boundary. Class/function names cannot
        # contain "[", so the first "[" after the file part starts the params.
        names = rest.split("[", 1)[0].split("::")
        if len(names) >= 2:
            return f"{file_part}::{names[0]}"
        return test_id

    return test_id
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def group_by_scope(test_ids: list[str], scope: Scope) -> list[TestGroup]:
    """Bucket test node IDs by their ``extract_scope`` key.

    Groups are returned in the order their first member was seen. Each
    group's ``test_ids`` keep the input order, duplicates included.

    Args:
        test_ids: Pytest node IDs, in collection order.
        scope: Granularity used to compute each test's group key.

    Returns:
        One ``TestGroup`` per distinct key; an empty list for empty input.
    """
    if not test_ids:
        return []

    # Built-in dicts keep insertion order, which gives first-seen group order.
    by_key: dict[str, TestGroup] = {}
    for node_id in test_ids:
        key = extract_scope(node_id, scope)
        bucket = by_key.get(key)
        if bucket is None:
            bucket = by_key[key] = TestGroup(scope_id=key)
        bucket.test_ids.append(node_id)
    return [*by_key.values()]
|
|
File without changes
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
"""CI environment auto-detection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import warnings
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True, slots=True)
|
|
11
|
+
class CIContext:
|
|
12
|
+
provider: str
|
|
13
|
+
node_index: int
|
|
14
|
+
node_total: int
|
|
15
|
+
run_id: str
|
|
16
|
+
branch: str | None = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def detect_ci(
    explicit_index: int | None = None,
    explicit_total: int | None = None,
) -> CIContext | None:
    """Determine which parallel node this process is.

    The environment is probed first via ``_detect_provider``. When BOTH
    ``explicit_index`` and ``explicit_total`` are given, they replace the
    detected node position. The detected provider, run ID and branch are
    kept. If nothing was detected, an ``"explicit"`` context with run ID
    ``"local"`` is built instead. Passing only one of the two overrides has
    no effect.

    Returns:
        The validated context, or ``None`` when neither the environment nor
        the arguments identify a node.

    Raises:
        ValueError: If the resulting node index/total are out of range.
    """
    detected = _detect_provider()

    if explicit_index is None or explicit_total is None:
        ctx = detected
    elif detected is None:
        ctx = CIContext(
            provider="explicit",
            node_index=explicit_index,
            node_total=explicit_total,
            run_id="local",
        )
    else:
        ctx = CIContext(
            provider=detected.provider,
            node_index=explicit_index,
            node_total=explicit_total,
            run_id=detected.run_id,
            branch=detected.branch,
        )

    if ctx is not None:
        _validate(ctx)
    return ctx
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _detect_provider() -> CIContext | None:
    """Return the context from the first detector that recognises the env.

    Detectors run in a fixed priority order, and the search stops at the
    first non-``None`` result. The generic ``PYTEST_BALANCE_NODE_*`` detector
    runs last as a fallback. Returns ``None`` if no detector matches.
    """
    detectors = (
        _detect_github,
        _detect_gitlab,
        _detect_circleci,
        _detect_azure,
        _detect_buildkite,
        _detect_generic,
    )
    # Generator keeps this lazy: later detectors are not called after a hit.
    candidates = (detect() for detect in detectors)
    return next((ctx for ctx in candidates if ctx is not None), None)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _detect_github() -> CIContext | None:
    """Detect a GitHub Actions node.

    Recognised by ``GITHUB_ACTIONS == "true"``. The node position comes from
    ``PYTEST_BALANCE_NODE_INDEX`` and ``PYTEST_BALANCE_NODE_TOTAL``, used
    as-is. Returns ``None`` if either is missing. The run ID is
    ``"<GITHUB_RUN_ID>-<GITHUB_RUN_ATTEMPT>"``, defaulting to ``"unknown"``
    and ``"1"``. The branch is ``GITHUB_REF_NAME``.

    Non-integer node values emit a ``UserWarning`` and yield ``None``.
    """
    env = os.environ
    if env.get("GITHUB_ACTIONS") != "true":
        return None

    raw_index = env.get("PYTEST_BALANCE_NODE_INDEX")
    raw_total = env.get("PYTEST_BALANCE_NODE_TOTAL")
    if raw_index is None or raw_total is None:
        return None

    try:
        node_index, node_total = int(raw_index), int(raw_total)
    except ValueError:
        warnings.warn(
            f"Non-integer balance env var: NODE_INDEX={raw_index!r}, NODE_TOTAL={raw_total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    run_id = "{}-{}".format(
        env.get("GITHUB_RUN_ID", "unknown"),
        env.get("GITHUB_RUN_ATTEMPT", "1"),
    )
    return CIContext(
        provider="github",
        node_index=node_index,
        node_total=node_total,
        run_id=run_id,
        branch=env.get("GITHUB_REF_NAME"),
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _detect_gitlab() -> CIContext | None:
    """Detect a GitLab CI ``parallel:`` job.

    Requires ``GITLAB_CI == "true"`` plus ``CI_NODE_INDEX`` and
    ``CI_NODE_TOTAL``. ``CI_NODE_INDEX`` is decremented by one to obtain the
    zero-based ``node_index``. The run ID is ``CI_PIPELINE_ID`` (``"unknown"``
    if unset) and the branch is ``CI_COMMIT_REF_NAME``.

    Returns ``None`` when not on GitLab or the node variables are missing.
    Non-integer values additionally emit a ``UserWarning``.
    """
    if os.environ.get("GITLAB_CI") != "true":
        return None

    index, total = os.environ.get("CI_NODE_INDEX"), os.environ.get("CI_NODE_TOTAL")
    if index is None or total is None:
        return None

    try:
        position = int(index)
        node_total = int(total)
    except ValueError:
        warnings.warn(
            f"Non-integer CI env var: CI_NODE_INDEX={index!r}, CI_NODE_TOTAL={total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    # GitLab numbers parallel jobs from 1; CIContext is zero-based.
    return CIContext(
        provider="gitlab",
        node_index=position - 1,
        node_total=node_total,
        run_id=os.environ.get("CI_PIPELINE_ID", "unknown"),
        branch=os.environ.get("CI_COMMIT_REF_NAME"),
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _detect_circleci() -> CIContext | None:
    """Detect a CircleCI parallel job.

    Requires ``CIRCLECI == "true"`` together with ``CIRCLE_NODE_INDEX`` (used
    as-is as the zero-based index) and ``CIRCLE_NODE_TOTAL``. The run ID is
    ``CIRCLE_BUILD_NUM`` (``"unknown"`` if unset) and the branch is
    ``CIRCLE_BRANCH``.

    Returns ``None`` when not on CircleCI or the node variables are missing.
    Non-integer values additionally emit a ``UserWarning``.
    """
    environ = os.environ
    if environ.get("CIRCLECI") != "true":
        return None

    index = environ.get("CIRCLE_NODE_INDEX")
    total = environ.get("CIRCLE_NODE_TOTAL")
    if index is None or total is None:
        return None

    try:
        parsed = (int(index), int(total))
    except ValueError:
        warnings.warn(
            f"Non-integer CI env var: CIRCLE_NODE_INDEX={index!r}, CIRCLE_NODE_TOTAL={total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    node_index, node_total = parsed
    return CIContext(
        provider="circleci",
        node_index=node_index,
        node_total=node_total,
        run_id=environ.get("CIRCLE_BUILD_NUM", "unknown"),
        branch=environ.get("CIRCLE_BRANCH"),
    )
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _detect_azure() -> CIContext | None:
    """Detect an Azure Pipelines parallel job.

    Requires ``TF_BUILD == "True"`` plus ``SYSTEM_JOBPOSITIONINPHASE`` and
    ``SYSTEM_TOTALJOBSINPHASE``. The job position is decremented by one to
    obtain the zero-based ``node_index``. The run ID is ``BUILD_BUILDID``
    (``"unknown"`` if unset).

    The branch comes from ``BUILD_SOURCEBRANCH``, which Azure reports as a full
    ref (e.g. ``"refs/heads/main"``). A leading ``refs/heads/`` or
    ``refs/tags/`` is stripped. The value then matches the short names read by
    the other detectors in this module (``GITHUB_REF_NAME``,
    ``CI_COMMIT_REF_NAME``, ``CIRCLE_BRANCH``, ``BUILDKITE_BRANCH``). Other
    refs (e.g. pull-request refs) are passed through unchanged.

    Returns ``None`` when not on Azure or the job variables are missing.
    Non-integer values additionally emit a ``UserWarning``.
    """
    if os.environ.get("TF_BUILD") != "True":
        return None
    index = os.environ.get("SYSTEM_JOBPOSITIONINPHASE")
    total = os.environ.get("SYSTEM_TOTALJOBSINPHASE")
    if index is None or total is None:
        return None
    try:
        # Azure numbers jobs from 1; CIContext is zero-based.
        index_int, total_int = int(index) - 1, int(total)
    except ValueError:
        warnings.warn(
            f"Non-integer CI env var: SYSTEM_JOBPOSITIONINPHASE={index!r},"
            f" SYSTEM_TOTALJOBSINPHASE={total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    branch = os.environ.get("BUILD_SOURCEBRANCH")
    if branch is not None:
        for prefix in ("refs/heads/", "refs/tags/"):
            if branch.startswith(prefix):
                branch = branch[len(prefix) :]
                break

    return CIContext(
        "azure",
        index_int,
        total_int,
        os.environ.get("BUILD_BUILDID", "unknown"),
        branch,
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _detect_buildkite() -> CIContext | None:
    """Detect a Buildkite parallel step.

    Requires ``BUILDKITE == "true"`` plus ``BUILDKITE_PARALLEL_JOB`` (used
    as-is as the zero-based index) and ``BUILDKITE_PARALLEL_JOB_COUNT``. The
    run ID is ``BUILDKITE_BUILD_ID`` (``"unknown"`` if unset) and the branch
    is ``BUILDKITE_BRANCH``.

    Returns ``None`` when not on Buildkite or the parallelism variables are
    missing. Non-integer values additionally emit a ``UserWarning``.
    """
    env = os.environ
    if env.get("BUILDKITE") != "true":
        return None

    job = env.get("BUILDKITE_PARALLEL_JOB")
    count = env.get("BUILDKITE_PARALLEL_JOB_COUNT")
    if job is None or count is None:
        return None

    try:
        job_no, job_count = int(job), int(count)
    except ValueError:
        warnings.warn(
            f"Non-integer CI env var: BUILDKITE_PARALLEL_JOB={job!r},"
            f" BUILDKITE_PARALLEL_JOB_COUNT={count!r}",
            UserWarning,
            stacklevel=2,
        )
        return None
    else:
        return CIContext(
            provider="buildkite",
            node_index=job_no,
            node_total=job_count,
            run_id=env.get("BUILDKITE_BUILD_ID", "unknown"),
            branch=env.get("BUILDKITE_BRANCH"),
        )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _detect_generic() -> CIContext | None:
    """Fallback detector driven only by ``PYTEST_BALANCE_NODE_*`` variables.

    No CI-specific flag is checked. It needs ``PYTEST_BALANCE_NODE_INDEX``
    (used as-is) and ``PYTEST_BALANCE_NODE_TOTAL``. The provider is reported
    as ``"unknown"`` and the run ID as ``"generic"``; no branch is recorded.

    Returns ``None`` if either variable is missing. Non-integer values
    additionally emit a ``UserWarning``.
    """
    index, total = (
        os.environ.get(name)
        for name in ("PYTEST_BALANCE_NODE_INDEX", "PYTEST_BALANCE_NODE_TOTAL")
    )
    if index is None or total is None:
        return None

    try:
        node_index = int(index)
        node_total = int(total)
    except ValueError:
        warnings.warn(
            f"Non-integer balance env var: NODE_INDEX={index!r}, NODE_TOTAL={total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    return CIContext(
        provider="unknown",
        node_index=node_index,
        node_total=node_total,
        run_id="generic",
    )
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _validate(ctx: CIContext) -> None:
    """Reject a context whose node position is impossible.

    Raises:
        ValueError: If ``node_total`` is below 1, or if ``node_index`` is not
            in the half-open range ``[0, node_total)``.
    """
    total = ctx.node_total
    if total < 1:
        raise ValueError(f"node_total must be >= 1, got {ctx.node_total}")

    index = ctx.node_index
    in_range = 0 <= index < total
    if not in_range:
        raise ValueError(f"node_index {ctx.node_index} out of range for {ctx.node_total} nodes")
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""CI-level test splitting using LPT and scope-aware grouping."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pytest_balance.algorithms.lpt import partition
|
|
6
|
+
from pytest_balance.algorithms.partitioner import Scope, group_by_scope
|
|
7
|
+
from pytest_balance.store.models import DurationEstimate
|
|
8
|
+
from pytest_balance.store.reader import default_estimate
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def split_tests(
    tests: list[str],
    estimates: dict[str, DurationEstimate],
    node_index: int,
    node_total: int,
    scope: Scope,
) -> list[str]:
    """Return the subset of ``tests`` that CI node ``node_index`` should run.

    Tests are grouped by ``scope`` so a group is never split across nodes.
    Each group's duration is the sum of its tests' ``.estimate`` values,
    using ``default_estimate(estimates)`` for tests without an entry. Groups
    are then distributed over ``node_total`` nodes with LPT ``partition``.

    Args:
        tests: Collected test node IDs, in collection order.
        estimates: Known duration estimates keyed by test node ID.
        node_index: Zero-based index of this node.
        node_total: Number of parallel nodes.
        scope: Granularity at which tests are kept together.

    Returns:
        The selected test IDs. Groups appear in order of their first
        appearance in ``tests``, and each group's tests keep collection order.

    Raises:
        ValueError: If ``node_total`` < 1 or ``node_index`` is outside
            ``[0, node_total)``.
    """
    # Validate up front. A negative index would otherwise silently select a
    # bucket counted from the end (buckets[-1]). A too-large one would only
    # surface as a bare IndexError.
    if node_total < 1:
        raise ValueError(f"node_total must be >= 1, got {node_total}")
    if not 0 <= node_index < node_total:
        raise ValueError(f"node_index {node_index} out of range for {node_total} nodes")

    if not tests:
        return []

    fallback = default_estimate(estimates)
    groups = group_by_scope(tests, scope)

    group_durations: dict[str, float] = {}
    for group in groups:
        group.estimated_duration = sum(
            estimates.get(tid, fallback).estimate for tid in group.test_ids
        )
        group_durations[group.scope_id] = group.estimated_duration

    # LPT-partition whole groups across nodes, then keep this node's share.
    buckets = partition(group_durations, node_total)
    selected_scopes = set(buckets[node_index])

    return [
        tid
        for group in groups
        if group.scope_id in selected_scopes
        for tid in group.test_ids
    ]
|