pytest-balance 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ """Duration-based test distribution for pytest."""
2
+
3
+ __version__ = "0.1.0"
@@ -0,0 +1,5 @@
"""Allow running pytest-balance as a module: python -m pytest_balance"""

from pytest_balance.cli import main

# Only launch the CLI when executed via ``python -m pytest_balance``; a plain
# import of this module (documentation tools, doctest collection, etc.) must
# not start the command-line program as a side effect.
if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,68 @@
1
+ """Longest Processing Time First (LPT) partitioning algorithm."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import heapq
6
+
7
+ from pytest_balance.algorithms.partitioner import Scope, group_by_scope
8
+ from pytest_balance.store.models import DurationEstimate
9
+ from pytest_balance.store.reader import default_estimate
10
+
11
+
12
def partition(durations: dict[str, float], n: int) -> list[list[str]]:
    """Distribute items over ``n`` buckets with the LPT heuristic.

    Items are taken longest-first, with equal durations ordered by item ID.
    Each item goes to the bucket with the smallest running total, and the
    lowest bucket index wins ties. This greedily keeps the largest bucket
    total (the makespan) small, and the result is fully deterministic.

    Args:
        durations: Mapping of item ID to estimated duration.
        n: Number of buckets.

    Returns:
        Exactly ``n`` lists of item IDs, each in the order items were assigned.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be >= 1, got {n}")

    assignment: list[list[str]] = [[] for _ in range(n)]
    if not durations:
        return assignment

    def longest_first(entry: tuple[str, float]) -> tuple[float, str]:
        # Descending duration, then ascending ID for a stable tie-break.
        item, cost = entry
        return -cost, item

    # Min-heap of (running total, bucket index); the index breaks load ties.
    loads = [(0.0, bucket) for bucket in range(n)]
    heapq.heapify(loads)

    for item, cost in sorted(durations.items(), key=longest_first):
        load, bucket = heapq.heappop(loads)
        assignment[bucket].append(item)
        heapq.heappush(loads, (load + cost, bucket))

    return assignment
43
+
44
+
45
def compute_order(
    collection: list[str],
    estimates: dict[str, DurationEstimate],
    scope: Scope,
) -> list[int]:
    """Return indices of `collection` ordered LPT scope-adjacent.

    Tests are grouped by `scope`. Each group's duration is the sum of its
    tests' estimates; tests missing from `estimates` use
    ``default_estimate(estimates)``. Groups are ordered by descending
    estimated duration with a lexicographic tie-break on ``scope_id``, and
    the tests of one group are emitted consecutively in collection order.

    Duplicate test IDs in `collection` are handled: every position appears
    exactly once, so the result is always a permutation of
    ``range(len(collection))``. Pure function: deterministic, no side effects
    on the inputs.
    """
    if not collection:
        return []

    fallback = default_estimate(estimates)
    groups = group_by_scope(collection, scope)
    for group in groups:
        group.estimated_duration = sum(
            estimates.get(tid, fallback).estimate for tid in group.test_ids
        )
    groups.sort(key=lambda g: (-g.estimated_duration, g.scope_id))

    # A plain {tid: idx} map would keep only the last position of a duplicated
    # ID. Instead, hand out each ID's positions in ascending order.
    # group_by_scope appends every occurrence to a single group in input
    # order, so the k-th occurrence receives the k-th position.
    positions: dict[str, list[int]] = {}
    for idx, tid in enumerate(collection):
        positions.setdefault(tid, []).append(idx)
    cursors = {tid: iter(idxs) for tid, idxs in positions.items()}

    return [next(cursors[tid]) for group in groups for tid in group.test_ids]
@@ -0,0 +1,62 @@
1
+ """Scope-aware test grouping for balanced partitioning."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+ from collections import OrderedDict
7
+ from dataclasses import dataclass, field
8
+
9
+
10
class Scope(enum.Enum):
    """Granularity at which tests are kept together when distributing them.

    Each member's value is its lowercase name.
    """

    # Every test is its own unit.
    TEST = "test"
    # Methods share a unit with their ``file::Class``; module-level test
    # functions stay on their own.
    CLASS = "class"
    # Tests share a unit with everything else in the same file.
    MODULE = "module"
    # Tests share a unit keyed by a trailing ``@name`` suffix on the node ID.
    GROUP = "group"
15
+
16
+
17
@dataclass
class TestGroup:
    """A set of test node IDs that are scheduled together as one unit."""

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Without this flag, importing the class into a test module makes pytest
    # try to collect it and emit a PytestCollectionWarning.
    __test__ = False

    # Key shared by all member tests (as produced by the scope extraction).
    scope_id: str
    # Member node IDs, in the order they were added.
    test_ids: list[str] = field(default_factory=list)
    # Summed duration estimate of the members; assigned by callers.
    estimated_duration: float = 0.0
22
+
23
+
24
def extract_scope(test_id: str, scope: Scope) -> str:
    """Extract the scope key from a pytest node ID.

    Args:
        test_id: A node ID such as ``"tests/test_x.py::TestA::test_b[1]"``,
            optionally followed by an ``@group`` suffix.
        scope: Granularity to extract.

    Returns:
        - ``Scope.TEST``: the full node ID.
        - ``Scope.GROUP``: the text after the last ``@``, provided that ``@``
          comes after the last ``]`` (so an ``@`` inside a parametrize ID is
          ignored). Otherwise the full node ID.
        - ``Scope.MODULE``: the file part (everything before the first ``::``).
        - ``Scope.CLASS``: ``"<file>::<Class>"`` when the ID has at least
          three ``::``-separated components. Otherwise the full node ID, so
          module-level test functions form their own groups.

        Any other scope returns the full node ID.

    A ``::`` inside the parametrize suffix (``[...]``) is not treated as a
    component separator.
    """
    if scope == Scope.TEST:
        return test_id

    if scope == Scope.GROUP:
        bracket_pos = test_id.rfind("]")
        at_pos = test_id.rfind("@")
        if at_pos > bracket_pos:
            return test_id[at_pos + 1 :]
        return test_id

    # Split on "::" only before the parametrize suffix. Python class and
    # function names cannot contain "[", so the first "[" after the file part
    # opens the parametrize ID. Searching from the first "::" ignores any "["
    # in the file path. Parametrize values such as "a::b" would otherwise
    # produce bogus components.
    first_sep = test_id.find("::")
    param_start = test_id.find("[", first_sep) if first_sep != -1 else -1
    head = test_id if param_start == -1 else test_id[:param_start]
    parts = head.split("::")

    if scope == Scope.MODULE:
        return parts[0]

    if scope == Scope.CLASS:
        if len(parts) >= 3:
            return f"{parts[0]}::{parts[1]}"
        return test_id

    return test_id
48
+
49
+
50
def group_by_scope(test_ids: list[str], scope: Scope) -> list[TestGroup]:
    """Bucket ``test_ids`` into TestGroups keyed by their scope.

    Groups are returned in the order their first member occurs in
    ``test_ids``, and each group lists its members in input order. An empty
    (falsy) input yields an empty list.
    """
    if not test_ids:
        return []

    # Plain dicts preserve insertion order, which fixes the group order.
    by_key: dict[str, TestGroup] = {}
    for node_id in test_ids:
        key = extract_scope(node_id, scope)
        bucket = by_key.get(key)
        if bucket is None:
            bucket = by_key[key] = TestGroup(scope_id=key)
        bucket.test_ids.append(node_id)

    return list(by_key.values())
File without changes
@@ -0,0 +1,209 @@
1
+ """CI environment auto-detection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import warnings
7
+ from dataclasses import dataclass
8
+
9
+
10
+ @dataclass(frozen=True, slots=True)
11
+ class CIContext:
12
+ provider: str
13
+ node_index: int
14
+ node_total: int
15
+ run_id: str
16
+ branch: str | None = None
17
+
18
+
19
def detect_ci(
    explicit_index: int | None = None,
    explicit_total: int | None = None,
) -> CIContext | None:
    """Determine this run's shard position from CI env vars and/or overrides.

    CI providers are probed via ``_detect_provider``. When both
    ``explicit_index`` and ``explicit_total`` are given, they override the
    detected shard position. If a provider was detected, its provider label,
    run ID and branch are kept. Otherwise a ``provider="explicit"`` context
    with ``run_id="local"`` is created. Supplying only one of the two
    overrides cannot be applied and is ignored with a ``UserWarning``.

    Args:
        explicit_index: Optional 0-based shard index override.
        explicit_total: Optional total shard count override.

    Returns:
        The validated context, or ``None`` when neither a provider nor a
        complete explicit override is available.

    Raises:
        ValueError: If the resulting index/total are out of range.
    """
    ctx = _detect_provider()

    if explicit_index is not None and explicit_total is not None:
        if ctx is not None:
            # Keep provider metadata; only the shard position is overridden.
            ctx = CIContext(
                provider=ctx.provider,
                node_index=explicit_index,
                node_total=explicit_total,
                run_id=ctx.run_id,
                branch=ctx.branch,
            )
        else:
            ctx = CIContext(
                provider="explicit",
                node_index=explicit_index,
                node_total=explicit_total,
                run_id="local",
            )
    elif explicit_index is not None or explicit_total is not None:
        # A half-specified override used to be dropped silently, leaving the
        # user unaware that auto-detection (or no sharding) is in effect.
        warnings.warn(
            "Explicit node index and node total must be given together; "
            f"ignoring index={explicit_index!r}, total={explicit_total!r}",
            UserWarning,
            stacklevel=2,
        )

    if ctx is None:
        return None

    _validate(ctx)
    return ctx
47
+
48
+
49
def _detect_provider() -> CIContext | None:
    """Return the context from the first detector that recognises the environment.

    Detectors are tried in a fixed order: GitHub, GitLab, CircleCI, Azure,
    Buildkite, then the generic PYTEST_BALANCE_* fallback. Once one returns a
    context, the remaining detectors are not called. Returns ``None`` when
    none match.
    """
    detectors = (
        _detect_github,
        _detect_gitlab,
        _detect_circleci,
        _detect_azure,
        _detect_buildkite,
        _detect_generic,
    )
    # Lazily evaluated, so probing stops at the first hit.
    candidates = (probe() for probe in detectors)
    return next((found for found in candidates if found is not None), None)
62
+
63
+
64
def _detect_github() -> CIContext | None:
    """Build a context for GitHub Actions (GITHUB_ACTIONS == "true").

    The shard position is read from PYTEST_BALANCE_NODE_INDEX and
    PYTEST_BALANCE_NODE_TOTAL and used as-is. The run ID is
    ``"<GITHUB_RUN_ID>-<GITHUB_RUN_ATTEMPT>"``, and the branch comes from
    GITHUB_REF_NAME. Returns ``None`` when not on GitHub or when either shard
    variable is missing. Non-integer values emit a UserWarning and also
    yield ``None``.
    """
    env = os.environ
    if env.get("GITHUB_ACTIONS") != "true":
        return None

    raw_index = env.get("PYTEST_BALANCE_NODE_INDEX")
    raw_total = env.get("PYTEST_BALANCE_NODE_TOTAL")
    if raw_index is None or raw_total is None:
        return None

    try:
        shard = int(raw_index)
        shards = int(raw_total)
    except ValueError:
        warnings.warn(
            f"Non-integer balance env var: NODE_INDEX={raw_index!r}, NODE_TOTAL={raw_total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    run = env.get("GITHUB_RUN_ID", "unknown")
    attempt = env.get("GITHUB_RUN_ATTEMPT", "1")
    return CIContext(
        provider="github",
        node_index=shard,
        node_total=shards,
        run_id=f"{run}-{attempt}",
        branch=env.get("GITHUB_REF_NAME"),
    )
84
+
85
+
86
def _detect_gitlab() -> CIContext | None:
    """Build a context for GitLab CI (GITLAB_CI == "true").

    CI_NODE_INDEX is treated as 1-based and shifted to a 0-based index, and
    CI_NODE_TOTAL gives the shard count. The run ID comes from CI_PIPELINE_ID
    and the branch from CI_COMMIT_REF_NAME. Returns ``None`` when not on
    GitLab or when either node variable is missing. Non-integer values emit
    a UserWarning and also yield ``None``.
    """
    env = os.environ
    if env.get("GITLAB_CI") != "true":
        return None

    raw_index, raw_total = env.get("CI_NODE_INDEX"), env.get("CI_NODE_TOTAL")
    if raw_index is None or raw_total is None:
        return None

    try:
        shard = int(raw_index) - 1
        shards = int(raw_total)
    except ValueError:
        warnings.warn(
            f"Non-integer CI env var: CI_NODE_INDEX={raw_index!r}, CI_NODE_TOTAL={raw_total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    return CIContext(
        provider="gitlab",
        node_index=shard,
        node_total=shards,
        run_id=env.get("CI_PIPELINE_ID", "unknown"),
        branch=env.get("CI_COMMIT_REF_NAME"),
    )
109
+
110
+
111
def _detect_circleci() -> CIContext | None:
    """Build a context for CircleCI (CIRCLECI == "true").

    CIRCLE_NODE_INDEX is used as-is (0-based) and CIRCLE_NODE_TOTAL gives the
    shard count. The run ID comes from CIRCLE_BUILD_NUM and the branch from
    CIRCLE_BRANCH. Returns ``None`` when not on CircleCI or when either node
    variable is missing. Non-integer values emit a UserWarning and also
    yield ``None``.
    """
    env = os.environ
    if env.get("CIRCLECI") != "true":
        return None

    raw_index = env.get("CIRCLE_NODE_INDEX")
    raw_total = env.get("CIRCLE_NODE_TOTAL")
    if raw_index is None or raw_total is None:
        return None

    try:
        shard = int(raw_index)
        shards = int(raw_total)
    except ValueError:
        warnings.warn(
            f"Non-integer CI env var: CIRCLE_NODE_INDEX={raw_index!r}, CIRCLE_NODE_TOTAL={raw_total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    return CIContext(
        provider="circleci",
        node_index=shard,
        node_total=shards,
        run_id=env.get("CIRCLE_BUILD_NUM", "unknown"),
        branch=env.get("CIRCLE_BRANCH"),
    )
134
+
135
+
136
def _detect_azure() -> CIContext | None:
    """Build a context for Azure Pipelines (TF_BUILD == "True").

    SYSTEM_JOBPOSITIONINPHASE is treated as 1-based and shifted to a 0-based
    index, and SYSTEM_TOTALJOBSINPHASE gives the shard count. The run ID
    comes from BUILD_BUILDID.

    Azure reports BUILD_SOURCEBRANCH as a full ref (e.g. ``refs/heads/main``),
    whereas the other detectors report plain branch names. The
    ``refs/heads/`` prefix is therefore stripped. Other refs (tags, pull
    requests) are kept verbatim.

    Returns ``None`` when not on Azure or when either job variable is
    missing. Non-integer values emit a UserWarning and also yield ``None``.
    """
    if os.environ.get("TF_BUILD") != "True":
        return None
    index = os.environ.get("SYSTEM_JOBPOSITIONINPHASE")
    total = os.environ.get("SYSTEM_TOTALJOBSINPHASE")
    if index is None or total is None:
        return None
    try:
        index_int, total_int = int(index) - 1, int(total)
    except ValueError:
        warnings.warn(
            f"Non-integer CI env var: SYSTEM_JOBPOSITIONINPHASE={index!r},"
            f" SYSTEM_TOTALJOBSINPHASE={total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    branch = os.environ.get("BUILD_SOURCEBRANCH")
    if branch is not None:
        branch = branch.removeprefix("refs/heads/")

    return CIContext(
        "azure",
        index_int,
        total_int,
        os.environ.get("BUILD_BUILDID", "unknown"),
        branch,
    )
160
+
161
+
162
def _detect_buildkite() -> CIContext | None:
    """Build a context for Buildkite (BUILDKITE == "true").

    BUILDKITE_PARALLEL_JOB is used as-is (0-based) and
    BUILDKITE_PARALLEL_JOB_COUNT gives the shard count. The run ID comes from
    BUILDKITE_BUILD_ID and the branch from BUILDKITE_BRANCH. Returns ``None``
    when not on Buildkite or when either parallelism variable is missing.
    Non-integer values emit a UserWarning and also yield ``None``.
    """
    env = os.environ
    if env.get("BUILDKITE") != "true":
        return None

    raw_index = env.get("BUILDKITE_PARALLEL_JOB")
    raw_total = env.get("BUILDKITE_PARALLEL_JOB_COUNT")
    if raw_index is None or raw_total is None:
        return None

    try:
        shard = int(raw_index)
        shards = int(raw_total)
    except ValueError:
        warnings.warn(
            f"Non-integer CI env var: BUILDKITE_PARALLEL_JOB={raw_index!r},"
            f" BUILDKITE_PARALLEL_JOB_COUNT={raw_total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    return CIContext(
        provider="buildkite",
        node_index=shard,
        node_total=shards,
        run_id=env.get("BUILDKITE_BUILD_ID", "unknown"),
        branch=env.get("BUILDKITE_BRANCH"),
    )
186
+
187
+
188
def _detect_generic() -> CIContext | None:
    """Fallback detector driven solely by PYTEST_BALANCE_NODE_* variables.

    Both PYTEST_BALANCE_NODE_INDEX and PYTEST_BALANCE_NODE_TOTAL must be set;
    the index is used as-is. The context carries ``provider="unknown"``,
    ``run_id="generic"`` and no branch. Non-integer values emit a UserWarning
    and yield ``None``.
    """
    env = os.environ
    raw_index = env.get("PYTEST_BALANCE_NODE_INDEX")
    raw_total = env.get("PYTEST_BALANCE_NODE_TOTAL")
    if raw_index is None or raw_total is None:
        return None

    try:
        shard = int(raw_index)
        shards = int(raw_total)
    except ValueError:
        warnings.warn(
            f"Non-integer balance env var: NODE_INDEX={raw_index!r}, NODE_TOTAL={raw_total!r}",
            UserWarning,
            stacklevel=2,
        )
        return None

    return CIContext(
        provider="unknown",
        node_index=shard,
        node_total=shards,
        run_id="generic",
    )
203
+
204
+
205
+ def _validate(ctx: CIContext) -> None:
206
+ if ctx.node_total < 1:
207
+ raise ValueError(f"node_total must be >= 1, got {ctx.node_total}")
208
+ if not (0 <= ctx.node_index < ctx.node_total):
209
+ raise ValueError(f"node_index {ctx.node_index} out of range for {ctx.node_total} nodes")
@@ -0,0 +1,46 @@
1
+ """CI-level test splitting using LPT and scope-aware grouping."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pytest_balance.algorithms.lpt import partition
6
+ from pytest_balance.algorithms.partitioner import Scope, group_by_scope
7
+ from pytest_balance.store.models import DurationEstimate
8
+ from pytest_balance.store.reader import default_estimate
9
+
10
+
11
def split_tests(
    tests: list[str],
    estimates: dict[str, DurationEstimate],
    node_index: int,
    node_total: int,
    scope: Scope,
) -> list[str]:
    """Return the subset of `tests` that shard `node_index` should run.

    Tests are grouped by `scope`, and each group's duration is the sum of its
    tests' estimates. Tests missing from `estimates` use
    ``default_estimate(estimates)``. Whole groups are spread across
    `node_total` shards with LPT (`partition`). The tests of the groups
    assigned to `node_index` are returned with groups in first-occurrence
    order and each group's tests in input order.

    Args:
        tests: Collected test node IDs.
        estimates: Per-test duration estimates.
        node_index: 0-based index of this shard.
        node_total: Number of shards.
        scope: Grouping granularity; a group is never split across shards.

    Returns:
        The selected test IDs; an empty list when `tests` is empty.

    Raises:
        ValueError: If `node_total` < 1 or `node_index` is outside
            ``[0, node_total)``.
    """
    if not tests:
        return []

    if node_total < 1:
        raise ValueError(f"node_total must be >= 1, got {node_total}")
    if not 0 <= node_index < node_total:
        # Without this check a negative index silently selects a bucket from
        # the end of the list, and a too-large one raises a bare IndexError.
        raise ValueError(
            f"node_index {node_index} out of range for {node_total} nodes"
        )

    fallback = default_estimate(estimates)
    groups = group_by_scope(tests, scope)

    group_durations: dict[str, float] = {}
    for group in groups:
        group.estimated_duration = sum(
            estimates.get(tid, fallback).estimate for tid in group.test_ids
        )
        # scope_ids are unique (group_by_scope keys groups by them), so they
        # can serve directly as the items LPT distributes.
        group_durations[group.scope_id] = group.estimated_duration

    selected_scopes = set(partition(group_durations, node_total)[node_index])

    return [
        tid
        for group in groups
        if group.scope_id in selected_scopes
        for tid in group.test_ids
    ]