bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1141 @@
|
|
|
1
|
+
"""List partitioning for experimental item distribution.
|
|
2
|
+
|
|
3
|
+
This module provides the ListPartitioner class for partitioning items into
|
|
4
|
+
balanced experimental lists. Implements three strategies: random, balanced,
|
|
5
|
+
and stratified. Uses stand-off annotation (works with UUIDs only).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections import Counter, defaultdict
|
|
11
|
+
from typing import Any
|
|
12
|
+
from uuid import UUID
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from bead.dsl.evaluator import DSLEvaluator
|
|
17
|
+
from bead.lists.balancer import QuantileBalancer
|
|
18
|
+
from bead.lists.constraints import (
|
|
19
|
+
BalanceConstraint,
|
|
20
|
+
BatchBalanceConstraint,
|
|
21
|
+
BatchConstraint,
|
|
22
|
+
BatchCoverageConstraint,
|
|
23
|
+
BatchDiversityConstraint,
|
|
24
|
+
BatchMinOccurrenceConstraint,
|
|
25
|
+
ListConstraint,
|
|
26
|
+
QuantileConstraint,
|
|
27
|
+
SizeConstraint,
|
|
28
|
+
UniquenessConstraint,
|
|
29
|
+
)
|
|
30
|
+
from bead.lists.experiment_list import ExperimentList, MetadataValue
|
|
31
|
+
from bead.resources.constraints import ContextValue
|
|
32
|
+
|
|
33
|
+
# Type aliases for clarity
# Per-item property dict; property expressions evaluated by the partitioner
# see this dict as the DSL variable ``item`` (e.g. ``item['lm_prob']``).
type ItemMetadata = dict[str, Any]  # Arbitrary item properties
# Stand-off metadata store: the partitioner only handles UUIDs and looks up
# each item's properties here.
type MetadataDict = dict[UUID, ItemMetadata]  # Metadata indexed by UUID
# Per-list summary keyed by metric name (e.g. "size", "quantile_<expr>",
# "balance_<expr>").
type BalanceMetrics = dict[str, MetadataValue]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ListPartitioner:
|
|
40
|
+
"""Partitions items into balanced experimental lists.
|
|
41
|
+
|
|
42
|
+
Uses stand-off annotation: only stores UUIDs, not full item objects.
|
|
43
|
+
Requires item metadata dict for constraint checking and balancing.
|
|
44
|
+
|
|
45
|
+
Implements three partitioning strategies:
|
|
46
|
+
- Random: Simple round-robin after shuffling
|
|
47
|
+
- Balanced: Greedy algorithm to minimize constraint violations
|
|
48
|
+
- Stratified: Quantile-based stratification with balanced distribution
|
|
49
|
+
|
|
50
|
+
Parameters
|
|
51
|
+
----------
|
|
52
|
+
random_seed : int | None, default=None
|
|
53
|
+
Random seed for reproducibility.
|
|
54
|
+
|
|
55
|
+
Attributes
|
|
56
|
+
----------
|
|
57
|
+
random_seed : int | None
|
|
58
|
+
Random seed for reproducibility.
|
|
59
|
+
|
|
60
|
+
Examples
|
|
61
|
+
--------
|
|
62
|
+
>>> from uuid import uuid4
|
|
63
|
+
>>> partitioner = ListPartitioner(random_seed=42)
|
|
64
|
+
>>> items = [uuid4() for _ in range(100)]
|
|
65
|
+
>>> metadata = {uid: {"property": i} for i, uid in enumerate(items)}
|
|
66
|
+
>>> lists = partitioner.partition(items, n_lists=5, metadata=metadata)
|
|
67
|
+
>>> len(lists)
|
|
68
|
+
5
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
def __init__(self, random_seed: int | None = None) -> None:
    """Prepare the random generator and the DSL evaluator.

    Parameters
    ----------
    random_seed : int | None, default=None
        Seed forwarded to ``numpy.random.default_rng``; ``None`` lets
        NumPy draw fresh entropy.
    """
    # The raw seed is kept so the stratified strategy can hand it to
    # its QuantileBalancer.
    self.random_seed = random_seed
    self._rng = np.random.default_rng(seed=random_seed)
    # Evaluates constraint property expressions against item metadata.
    self.dsl_evaluator = DSLEvaluator()
|
|
75
|
+
|
|
76
|
+
def partition(
    self,
    items: list[UUID],
    n_lists: int,
    constraints: list[ListConstraint] | None = None,
    strategy: str = "balanced",
    metadata: MetadataDict | None = None,
) -> list[ExperimentList]:
    """Split item UUIDs into ``n_lists`` experimental lists.

    Parameters
    ----------
    items : list[UUID]
        UUIDs of the items to distribute.
    n_lists : int
        How many lists to produce; must be at least 1.
    constraints : list[ListConstraint] | None, default=None
        Constraints attached to every list and used for balancing.
        ``None`` (or an empty list) means no constraints.
    strategy : str, default="balanced"
        One of ``"balanced"``, ``"random"`` or ``"stratified"``.
    metadata : dict[UUID, dict[str, Any]] | None, default=None
        Per-item metadata consulted when evaluating constraints.
        ``None`` is treated as an empty mapping.

    Returns
    -------
    list[ExperimentList]
        The lists produced by the selected strategy.

    Raises
    ------
    ValueError
        If ``n_lists`` is below 1 or ``strategy`` is not recognised.
    """
    if n_lists < 1:
        raise ValueError(f"n_lists must be >= 1, got {n_lists}")

    constraint_list = constraints or []
    item_metadata = metadata or {}

    # Strategy name -> name of the private method implementing it.
    dispatch = {
        "balanced": "_partition_balanced",
        "random": "_partition_random",
        "stratified": "_partition_stratified",
    }
    handler_name = dispatch.get(strategy)
    if handler_name is None:
        raise ValueError(f"Unknown strategy: {strategy}")

    handler = getattr(self, handler_name)
    return handler(items, n_lists, constraint_list, item_metadata)
|
|
125
|
+
|
|
126
|
+
def _partition_random(
|
|
127
|
+
self,
|
|
128
|
+
items: list[UUID],
|
|
129
|
+
n_lists: int,
|
|
130
|
+
constraints: list[ListConstraint],
|
|
131
|
+
metadata: MetadataDict,
|
|
132
|
+
) -> list[ExperimentList]:
|
|
133
|
+
"""Partition items randomly.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
items : list[UUID]
|
|
138
|
+
Items to partition.
|
|
139
|
+
n_lists : int
|
|
140
|
+
Number of lists.
|
|
141
|
+
constraints : list[ListConstraint]
|
|
142
|
+
Constraints to attach to lists.
|
|
143
|
+
metadata : dict[UUID, dict[str, Any]]
|
|
144
|
+
Item metadata.
|
|
145
|
+
|
|
146
|
+
Returns
|
|
147
|
+
-------
|
|
148
|
+
list[ExperimentList]
|
|
149
|
+
Partitioned lists.
|
|
150
|
+
"""
|
|
151
|
+
# Initialize lists
|
|
152
|
+
lists = [
|
|
153
|
+
ExperimentList(
|
|
154
|
+
name=f"list_{i}",
|
|
155
|
+
list_number=i,
|
|
156
|
+
list_constraints=constraints,
|
|
157
|
+
)
|
|
158
|
+
for i in range(n_lists)
|
|
159
|
+
]
|
|
160
|
+
|
|
161
|
+
# Shuffle and distribute round robin
|
|
162
|
+
items_shuffled = np.array(items)
|
|
163
|
+
self._rng.shuffle(items_shuffled)
|
|
164
|
+
|
|
165
|
+
for i, item_id in enumerate(items_shuffled):
|
|
166
|
+
list_idx = i % n_lists
|
|
167
|
+
lists[list_idx].add_item(item_id)
|
|
168
|
+
|
|
169
|
+
# Compute balance metrics for each list
|
|
170
|
+
for exp_list in lists:
|
|
171
|
+
exp_list.balance_metrics = self._compute_balance_metrics(
|
|
172
|
+
exp_list, constraints, metadata
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
return lists
|
|
176
|
+
|
|
177
|
+
def _partition_balanced(
|
|
178
|
+
self,
|
|
179
|
+
items: list[UUID],
|
|
180
|
+
n_lists: int,
|
|
181
|
+
constraints: list[ListConstraint],
|
|
182
|
+
metadata: MetadataDict,
|
|
183
|
+
) -> list[ExperimentList]:
|
|
184
|
+
"""Partition items with balanced distribution.
|
|
185
|
+
|
|
186
|
+
Uses greedy algorithm to distribute items to minimize imbalance.
|
|
187
|
+
|
|
188
|
+
Parameters
|
|
189
|
+
----------
|
|
190
|
+
items : list[UUID]
|
|
191
|
+
Items to partition.
|
|
192
|
+
n_lists : int
|
|
193
|
+
Number of lists.
|
|
194
|
+
constraints : list[ListConstraint]
|
|
195
|
+
Constraints to satisfy.
|
|
196
|
+
metadata : dict[UUID, dict[str, Any]]
|
|
197
|
+
Item metadata.
|
|
198
|
+
|
|
199
|
+
Returns
|
|
200
|
+
-------
|
|
201
|
+
list[ExperimentList]
|
|
202
|
+
Partitioned lists.
|
|
203
|
+
"""
|
|
204
|
+
# Initialize lists
|
|
205
|
+
lists = [
|
|
206
|
+
ExperimentList(
|
|
207
|
+
name=f"list_{i}",
|
|
208
|
+
list_number=i,
|
|
209
|
+
list_constraints=constraints,
|
|
210
|
+
)
|
|
211
|
+
for i in range(n_lists)
|
|
212
|
+
]
|
|
213
|
+
|
|
214
|
+
# Shuffle items
|
|
215
|
+
items_shuffled = np.array(items)
|
|
216
|
+
self._rng.shuffle(items_shuffled)
|
|
217
|
+
|
|
218
|
+
# For each item, assign to list that best maintains balance
|
|
219
|
+
for item_id in items_shuffled:
|
|
220
|
+
best_list = self._find_best_list(item_id, lists, constraints, metadata)
|
|
221
|
+
best_list.add_item(item_id)
|
|
222
|
+
|
|
223
|
+
# Compute balance metrics for each list
|
|
224
|
+
for exp_list in lists:
|
|
225
|
+
exp_list.balance_metrics = self._compute_balance_metrics(
|
|
226
|
+
exp_list, constraints, metadata
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
return lists
|
|
230
|
+
|
|
231
|
+
def _partition_stratified(
|
|
232
|
+
self,
|
|
233
|
+
items: list[UUID],
|
|
234
|
+
n_lists: int,
|
|
235
|
+
constraints: list[ListConstraint],
|
|
236
|
+
metadata: MetadataDict,
|
|
237
|
+
) -> list[ExperimentList]:
|
|
238
|
+
"""Partition items with stratification.
|
|
239
|
+
|
|
240
|
+
Creates strata based on quantile constraints and distributes
|
|
241
|
+
items from each stratum across lists.
|
|
242
|
+
|
|
243
|
+
Parameters
|
|
244
|
+
----------
|
|
245
|
+
items : list[UUID]
|
|
246
|
+
Items to partition.
|
|
247
|
+
n_lists : int
|
|
248
|
+
Number of lists.
|
|
249
|
+
constraints : list[ListConstraint]
|
|
250
|
+
Constraints to satisfy (must include quantile constraints).
|
|
251
|
+
metadata : dict[UUID, dict[str, Any]]
|
|
252
|
+
Item metadata.
|
|
253
|
+
|
|
254
|
+
Returns
|
|
255
|
+
-------
|
|
256
|
+
list[ExperimentList]
|
|
257
|
+
Partitioned lists.
|
|
258
|
+
"""
|
|
259
|
+
# Find quantile constraints
|
|
260
|
+
quantile_constraints = [
|
|
261
|
+
c for c in constraints if isinstance(c, QuantileConstraint)
|
|
262
|
+
]
|
|
263
|
+
|
|
264
|
+
if not quantile_constraints:
|
|
265
|
+
# Fall back to balanced
|
|
266
|
+
return self._partition_balanced(items, n_lists, constraints, metadata)
|
|
267
|
+
|
|
268
|
+
# Use first quantile constraint for stratification
|
|
269
|
+
qc = quantile_constraints[0]
|
|
270
|
+
|
|
271
|
+
# Create balancer
|
|
272
|
+
balancer = QuantileBalancer(
|
|
273
|
+
n_quantiles=qc.n_quantiles, random_seed=self.random_seed
|
|
274
|
+
)
|
|
275
|
+
|
|
276
|
+
# Create value function
|
|
277
|
+
def value_func(item_id: UUID) -> float:
|
|
278
|
+
return float(
|
|
279
|
+
self._extract_property_value(
|
|
280
|
+
item_id, qc.property_expression, qc.context, metadata
|
|
281
|
+
)
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
# Balance items across lists
|
|
285
|
+
balanced_lists = balancer.balance(
|
|
286
|
+
items, value_func, n_lists, qc.items_per_quantile
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
# Convert to ExperimentList objects
|
|
290
|
+
lists: list[ExperimentList] = []
|
|
291
|
+
for i, item_ids in enumerate(balanced_lists):
|
|
292
|
+
exp_list = ExperimentList(
|
|
293
|
+
name=f"list_{i}",
|
|
294
|
+
list_number=i,
|
|
295
|
+
list_constraints=constraints,
|
|
296
|
+
)
|
|
297
|
+
for item_id in item_ids:
|
|
298
|
+
exp_list.add_item(item_id)
|
|
299
|
+
|
|
300
|
+
exp_list.balance_metrics = self._compute_balance_metrics(
|
|
301
|
+
exp_list, constraints, metadata
|
|
302
|
+
)
|
|
303
|
+
lists.append(exp_list)
|
|
304
|
+
|
|
305
|
+
return lists
|
|
306
|
+
|
|
307
|
+
def _find_best_list(
|
|
308
|
+
self,
|
|
309
|
+
item_id: UUID,
|
|
310
|
+
lists: list[ExperimentList],
|
|
311
|
+
constraints: list[ListConstraint],
|
|
312
|
+
metadata: MetadataDict,
|
|
313
|
+
) -> ExperimentList:
|
|
314
|
+
"""Find the list that best maintains balance after adding item.
|
|
315
|
+
|
|
316
|
+
Parameters
|
|
317
|
+
----------
|
|
318
|
+
item_id : UUID
|
|
319
|
+
Item to add.
|
|
320
|
+
lists : list[ExperimentList]
|
|
321
|
+
Available lists.
|
|
322
|
+
constraints : list[ListConstraint]
|
|
323
|
+
Constraints to consider.
|
|
324
|
+
metadata : MetadataDict
|
|
325
|
+
Item metadata.
|
|
326
|
+
|
|
327
|
+
Returns
|
|
328
|
+
-------
|
|
329
|
+
ExperimentList
|
|
330
|
+
Best list for this item.
|
|
331
|
+
"""
|
|
332
|
+
# Compute score for each list (violations + size as tiebreaker)
|
|
333
|
+
scores: list[tuple[int, int]] = []
|
|
334
|
+
for exp_list in lists:
|
|
335
|
+
# Temporarily add item
|
|
336
|
+
exp_list.add_item(item_id)
|
|
337
|
+
|
|
338
|
+
# Compute constraint violations
|
|
339
|
+
violations = self._count_violations(exp_list, constraints, metadata)
|
|
340
|
+
|
|
341
|
+
# Remove item
|
|
342
|
+
exp_list.remove_item(item_id)
|
|
343
|
+
|
|
344
|
+
# Use (violations, current_size) as score
|
|
345
|
+
# Prefer lists with fewer violations, then smaller lists
|
|
346
|
+
scores.append((violations, len(exp_list.item_refs)))
|
|
347
|
+
|
|
348
|
+
# Return list with lowest score
|
|
349
|
+
best_idx = int(np.argmin([s[0] * 1000 + s[1] for s in scores]))
|
|
350
|
+
return lists[best_idx]
|
|
351
|
+
|
|
352
|
+
def _count_violations(
|
|
353
|
+
self,
|
|
354
|
+
exp_list: ExperimentList,
|
|
355
|
+
constraints: list[ListConstraint],
|
|
356
|
+
metadata: MetadataDict,
|
|
357
|
+
) -> int:
|
|
358
|
+
"""Count constraint violations for a list.
|
|
359
|
+
|
|
360
|
+
Violations are weighted by constraint priority. Higher priority
|
|
361
|
+
constraints contribute more to the total violation score.
|
|
362
|
+
|
|
363
|
+
Parameters
|
|
364
|
+
----------
|
|
365
|
+
exp_list : ExperimentList
|
|
366
|
+
The list to check.
|
|
367
|
+
constraints : list[ListConstraint]
|
|
368
|
+
Constraints to check.
|
|
369
|
+
metadata : MetadataDict
|
|
370
|
+
Item metadata.
|
|
371
|
+
|
|
372
|
+
Returns
|
|
373
|
+
-------
|
|
374
|
+
int
|
|
375
|
+
Weighted violation score (sum of priorities of violated constraints).
|
|
376
|
+
"""
|
|
377
|
+
violations = 0
|
|
378
|
+
|
|
379
|
+
for constraint in constraints:
|
|
380
|
+
is_violated = False
|
|
381
|
+
|
|
382
|
+
if isinstance(constraint, UniquenessConstraint):
|
|
383
|
+
if not self._check_uniqueness(exp_list, constraint, metadata):
|
|
384
|
+
is_violated = True
|
|
385
|
+
elif isinstance(constraint, BalanceConstraint):
|
|
386
|
+
if not self._check_balance(exp_list, constraint, metadata):
|
|
387
|
+
is_violated = True
|
|
388
|
+
elif isinstance(constraint, SizeConstraint):
|
|
389
|
+
if not self._check_size(exp_list, constraint):
|
|
390
|
+
is_violated = True
|
|
391
|
+
|
|
392
|
+
# Add constraint priority if violated
|
|
393
|
+
if is_violated:
|
|
394
|
+
violations += constraint.priority
|
|
395
|
+
|
|
396
|
+
return violations
|
|
397
|
+
|
|
398
|
+
def _check_uniqueness(
|
|
399
|
+
self,
|
|
400
|
+
exp_list: ExperimentList,
|
|
401
|
+
constraint: UniquenessConstraint,
|
|
402
|
+
metadata: MetadataDict,
|
|
403
|
+
) -> bool:
|
|
404
|
+
"""Check uniqueness constraint.
|
|
405
|
+
|
|
406
|
+
Parameters
|
|
407
|
+
----------
|
|
408
|
+
exp_list : ExperimentList
|
|
409
|
+
List to check.
|
|
410
|
+
constraint : UniquenessConstraint
|
|
411
|
+
Uniqueness constraint.
|
|
412
|
+
metadata : dict[UUID, dict[str, Any]]
|
|
413
|
+
Item metadata.
|
|
414
|
+
|
|
415
|
+
Returns
|
|
416
|
+
-------
|
|
417
|
+
bool
|
|
418
|
+
True if constraint is satisfied.
|
|
419
|
+
"""
|
|
420
|
+
# Get values for property
|
|
421
|
+
values: list[Any] = []
|
|
422
|
+
for item_id in exp_list.item_refs:
|
|
423
|
+
value = self._extract_property_value(
|
|
424
|
+
item_id, constraint.property_expression, constraint.context, metadata
|
|
425
|
+
)
|
|
426
|
+
values.append(value)
|
|
427
|
+
|
|
428
|
+
# Check for duplicates
|
|
429
|
+
if constraint.allow_null:
|
|
430
|
+
values = [v for v in values if v is not None]
|
|
431
|
+
|
|
432
|
+
return bool(len(values) == len(set(values)))
|
|
433
|
+
|
|
434
|
+
def _check_balance(
|
|
435
|
+
self,
|
|
436
|
+
exp_list: ExperimentList,
|
|
437
|
+
constraint: BalanceConstraint,
|
|
438
|
+
metadata: MetadataDict,
|
|
439
|
+
) -> bool:
|
|
440
|
+
"""Check balance constraint.
|
|
441
|
+
|
|
442
|
+
Parameters
|
|
443
|
+
----------
|
|
444
|
+
exp_list : ExperimentList
|
|
445
|
+
List to check.
|
|
446
|
+
constraint : BalanceConstraint
|
|
447
|
+
Balance constraint.
|
|
448
|
+
metadata : dict[UUID, dict[str, Any]]
|
|
449
|
+
Item metadata.
|
|
450
|
+
|
|
451
|
+
Returns
|
|
452
|
+
-------
|
|
453
|
+
bool
|
|
454
|
+
True if constraint is satisfied.
|
|
455
|
+
"""
|
|
456
|
+
# Get values for property
|
|
457
|
+
values: list[Any] = []
|
|
458
|
+
for item_id in exp_list.item_refs:
|
|
459
|
+
value = self._extract_property_value(
|
|
460
|
+
item_id, constraint.property_expression, constraint.context, metadata
|
|
461
|
+
)
|
|
462
|
+
values.append(value)
|
|
463
|
+
|
|
464
|
+
# Count occurrences
|
|
465
|
+
counts = Counter(values)
|
|
466
|
+
|
|
467
|
+
# Check against target counts if specified
|
|
468
|
+
if constraint.target_counts is not None:
|
|
469
|
+
for category, target_count in constraint.target_counts.items():
|
|
470
|
+
actual_count = counts.get(category, 0)
|
|
471
|
+
deviation = abs(actual_count - target_count) / max(target_count, 1)
|
|
472
|
+
if deviation > constraint.tolerance:
|
|
473
|
+
return False
|
|
474
|
+
return True
|
|
475
|
+
|
|
476
|
+
# Otherwise check for balanced distribution
|
|
477
|
+
if len(counts) == 0:
|
|
478
|
+
return True
|
|
479
|
+
|
|
480
|
+
count_values = list(counts.values())
|
|
481
|
+
mean_count = np.mean(count_values)
|
|
482
|
+
max_deviation = max(
|
|
483
|
+
abs(c - mean_count) / max(mean_count, 1) for c in count_values
|
|
484
|
+
)
|
|
485
|
+
|
|
486
|
+
return bool(max_deviation <= constraint.tolerance)
|
|
487
|
+
|
|
488
|
+
def _check_size(self, exp_list: ExperimentList, constraint: SizeConstraint) -> bool:
|
|
489
|
+
"""Check size constraint.
|
|
490
|
+
|
|
491
|
+
Parameters
|
|
492
|
+
----------
|
|
493
|
+
exp_list : ExperimentList
|
|
494
|
+
List to check.
|
|
495
|
+
constraint : SizeConstraint
|
|
496
|
+
Size constraint.
|
|
497
|
+
|
|
498
|
+
Returns
|
|
499
|
+
-------
|
|
500
|
+
bool
|
|
501
|
+
True if constraint is satisfied.
|
|
502
|
+
"""
|
|
503
|
+
size = len(exp_list.item_refs)
|
|
504
|
+
|
|
505
|
+
if constraint.exact_size is not None:
|
|
506
|
+
return size == constraint.exact_size
|
|
507
|
+
|
|
508
|
+
if constraint.min_size is not None and size < constraint.min_size:
|
|
509
|
+
return False
|
|
510
|
+
|
|
511
|
+
if constraint.max_size is not None and size > constraint.max_size:
|
|
512
|
+
return False
|
|
513
|
+
|
|
514
|
+
return True
|
|
515
|
+
|
|
516
|
+
def _extract_property_value(
|
|
517
|
+
self,
|
|
518
|
+
item_id: UUID,
|
|
519
|
+
property_expression: str,
|
|
520
|
+
context: dict[str, ContextValue] | None,
|
|
521
|
+
metadata: MetadataDict,
|
|
522
|
+
) -> Any:
|
|
523
|
+
"""Extract property value using DSL expression.
|
|
524
|
+
|
|
525
|
+
Parameters
|
|
526
|
+
----------
|
|
527
|
+
item_id : UUID
|
|
528
|
+
Item UUID.
|
|
529
|
+
property_expression : str
|
|
530
|
+
DSL expression using dict access syntax (e.g., "item['lm_prob']",
|
|
531
|
+
"variance([item['val1'], item['val2']])"). The 'item' variable
|
|
532
|
+
refers to the metadata dict for this item.
|
|
533
|
+
context : dict[str, ContextValue] | None
|
|
534
|
+
Additional context variables for evaluation.
|
|
535
|
+
metadata : dict[UUID, dict[str, Any]]
|
|
536
|
+
Metadata dict mapping item UUIDs to their metadata.
|
|
537
|
+
|
|
538
|
+
Returns
|
|
539
|
+
-------
|
|
540
|
+
Any
|
|
541
|
+
Evaluated property value.
|
|
542
|
+
|
|
543
|
+
Raises
|
|
544
|
+
------
|
|
545
|
+
KeyError
|
|
546
|
+
If item_id not in metadata.
|
|
547
|
+
|
|
548
|
+
Notes
|
|
549
|
+
-----
|
|
550
|
+
Since ListPartitioner uses stand-off annotation (UUIDs only, not full
|
|
551
|
+
Item objects), the 'item' variable in property expressions refers to
|
|
552
|
+
the item's metadata dict, not a full Item object. Use dict access
|
|
553
|
+
syntax: item['key'] rather than item.key.
|
|
554
|
+
"""
|
|
555
|
+
if item_id not in metadata:
|
|
556
|
+
raise KeyError(f"Item {item_id} not found in metadata")
|
|
557
|
+
|
|
558
|
+
# Build evaluation context with item metadata directly
|
|
559
|
+
# The metadata dict IS the item for property expression purposes
|
|
560
|
+
eval_context: dict[str, Any] = {"item": metadata[item_id]}
|
|
561
|
+
if context:
|
|
562
|
+
eval_context.update(context)
|
|
563
|
+
|
|
564
|
+
return self.dsl_evaluator.evaluate(property_expression, eval_context)
|
|
565
|
+
|
|
566
|
+
def _compute_balance_metrics(
|
|
567
|
+
self,
|
|
568
|
+
exp_list: ExperimentList,
|
|
569
|
+
constraints: list[ListConstraint],
|
|
570
|
+
metadata: MetadataDict,
|
|
571
|
+
) -> BalanceMetrics:
|
|
572
|
+
"""Compute balance metrics for a list.
|
|
573
|
+
|
|
574
|
+
Parameters
|
|
575
|
+
----------
|
|
576
|
+
exp_list : ExperimentList
|
|
577
|
+
The list.
|
|
578
|
+
constraints : list[ListConstraint]
|
|
579
|
+
Constraints to compute metrics for.
|
|
580
|
+
metadata : dict[UUID, dict[str, Any]]
|
|
581
|
+
Item metadata.
|
|
582
|
+
|
|
583
|
+
Returns
|
|
584
|
+
-------
|
|
585
|
+
dict[str, Any]
|
|
586
|
+
Balance metrics.
|
|
587
|
+
"""
|
|
588
|
+
metrics: dict[str, Any] = {}
|
|
589
|
+
|
|
590
|
+
# Compute metrics for each constraint
|
|
591
|
+
for constraint in constraints:
|
|
592
|
+
if isinstance(constraint, QuantileConstraint):
|
|
593
|
+
metrics[f"quantile_{constraint.property_expression}"] = (
|
|
594
|
+
self._compute_quantile_distribution(exp_list, constraint, metadata)
|
|
595
|
+
)
|
|
596
|
+
elif isinstance(constraint, BalanceConstraint):
|
|
597
|
+
metrics[f"balance_{constraint.property_expression}"] = (
|
|
598
|
+
self._compute_category_distribution(exp_list, constraint, metadata)
|
|
599
|
+
)
|
|
600
|
+
|
|
601
|
+
# Overall size
|
|
602
|
+
metrics["size"] = len(exp_list.item_refs)
|
|
603
|
+
|
|
604
|
+
return metrics
|
|
605
|
+
|
|
606
|
+
def _compute_quantile_distribution(
    self,
    exp_list: ExperimentList,
    constraint: QuantileConstraint,
    metadata: MetadataDict,
) -> dict[str, float | list[float]]:
    """Describe the numeric spread of a property over one list.

    The constraint's property expression is evaluated for every item and
    converted to ``float``. The result reports mean, population standard
    deviation, min, max and the 25th/50th/75th percentiles. An empty list
    yields zeros and an empty ``quantiles`` list.

    Parameters
    ----------
    exp_list : ExperimentList
        List whose items are measured.
    constraint : QuantileConstraint
        Supplies the property expression and evaluation context.
    metadata : dict[UUID, dict[str, Any]]
        Item metadata.

    Returns
    -------
    dict[str, Any]
        Keys ``mean``, ``std``, ``min``, ``max`` (floats) and ``quantiles``
        (list of three floats).

    Raises
    ------
    KeyError
        If an item has no entry in ``metadata``.
    ValueError, TypeError
        If an evaluated value cannot be converted to ``float``.
    """
    refs = exp_list.item_refs
    if not refs:
        return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0, "quantiles": []}

    samples = np.asarray(
        [
            float(
                self._extract_property_value(
                    ref,
                    constraint.property_expression,
                    constraint.context,
                    metadata,
                )
            )
            for ref in refs
        ],
        dtype=float,
    )
    quartiles = np.percentile(samples, [25, 50, 75])

    return {
        "mean": float(samples.mean()),
        "std": float(samples.std()),
        "min": float(samples.min()),
        "max": float(samples.max()),
        "quantiles": [float(q) for q in quartiles],
    }
|
|
664
|
+
|
|
665
|
+
def _compute_category_distribution(
    self,
    exp_list: ExperimentList,
    constraint: BalanceConstraint,
    metadata: MetadataDict,
) -> dict[str, dict[str, int] | int | tuple[Any, int] | None]:
    """Tally the categorical values of a property over one list.

    Parameters
    ----------
    exp_list : ExperimentList
        List whose items are tallied.
    constraint : BalanceConstraint
        Supplies the property expression and evaluation context.
    metadata : dict[UUID, dict[str, Any]]
        Item metadata.

    Returns
    -------
    dict[str, Any]
        ``counts`` (value -> occurrences, in first-seen order),
        ``n_categories`` (number of distinct values) and ``most_common``
        (``(value, count)`` pair, or ``None`` for an empty list).

    Raises
    ------
    KeyError
        If an item has no entry in ``metadata``.
    """
    refs = exp_list.item_refs
    if not refs:
        return {"counts": {}, "n_categories": 0, "most_common": None}

    # Evaluate every item first, then tally.
    observed = [
        self._extract_property_value(
            ref, constraint.property_expression, constraint.context, metadata
        )
        for ref in refs
    ]
    tally = Counter(observed)
    top = tally.most_common(1)

    return {
        "counts": dict(tally),
        "n_categories": len(tally),
        "most_common": top[0] if top else None,
    }
|
|
703
|
+
|
|
704
|
+
# ========================================================================
|
|
705
|
+
# Batch Constraint Methods
|
|
706
|
+
# ========================================================================
|
|
707
|
+
|
|
708
|
+
def partition_with_batch_constraints(
    self,
    items: list[UUID],
    n_lists: int,
    list_constraints: list[ListConstraint] | None = None,
    batch_constraints: list[BatchConstraint] | None = None,
    strategy: str = "balanced",
    metadata: MetadataDict | None = None,
    max_iterations: int = 1000,
    tolerance: float = 0.05,
) -> list[ExperimentList]:
    """Partition items, then refine the lists toward batch-level constraints.

    The lists are first built with ``self.partition`` using the per-list
    constraints. Each refinement round then scores every batch constraint.
    If any score is below ``1.0 - tolerance``, the lowest-scoring constraint
    (scores below 1.0 only; first one wins ties) is handed to
    ``_improve_batch_constraint``, which may swap items between lists.
    Refinement stops when all constraints are within tolerance, when no
    improving swap is found, or after ``max_iterations`` rounds. In the last
    two cases the lists are returned as they are, without signalling
    failure.

    Parameters
    ----------
    items : list[UUID]
        Item UUIDs to partition.
    n_lists : int
        Number of lists to create.
    list_constraints : list[ListConstraint] | None, default=None
        Per-list constraints to satisfy.
    batch_constraints : list[BatchConstraint] | None, default=None
        Batch-level constraints to satisfy. When empty or None, the initial
        partition is returned unchanged.
    strategy : str, default="balanced"
        Initial partitioning strategy ("balanced", "random", "stratified").
    metadata : dict[UUID, dict[str, Any]] | None, default=None
        Metadata for each item UUID.
    max_iterations : int, default=1000
        Maximum number of refinement rounds.
    tolerance : float, default=0.05
        A batch constraint counts as satisfied when its score is at least
        ``1.0 - tolerance``.

    Returns
    -------
    list[ExperimentList]
        The partitioned (and possibly refined) lists.

    Examples
    --------
    >>> from bead.lists.constraints import BatchCoverageConstraint
    >>> partitioner = ListPartitioner(random_seed=42)
    >>> constraint = BatchCoverageConstraint(
    ...     property_expression="item['template_id']",
    ...     target_values=list(range(26)),
    ...     min_coverage=1.0
    ... )
    >>> lists = partitioner.partition_with_batch_constraints(  # doctest: +SKIP
    ...     items=item_uids,
    ...     n_lists=8,
    ...     batch_constraints=[constraint],
    ...     metadata=metadata_dict,
    ...     max_iterations=500
    ... )
    """
    lists = self.partition(
        items=items,
        n_lists=n_lists,
        constraints=list_constraints,
        strategy=strategy,
        metadata=metadata,
    )
    if not batch_constraints:
        return lists

    item_metadata = metadata or {}
    per_list_constraints = list_constraints or []
    threshold = 1.0 - tolerance

    for _ in range(max_iterations):
        scored = [
            (self._compute_batch_constraint_score(lists, bc, item_metadata), bc)
            for bc in batch_constraints
        ]
        if not any(score < threshold for score, _bc in scored):
            break

        # Only constraints scoring strictly below 1.0 are refinement targets.
        below_one = [pair for pair in scored if pair[0] < 1.0]
        if not below_one:
            continue
        _worst_score, worst = min(below_one, key=lambda pair: pair[0])

        made_progress = self._improve_batch_constraint(
            lists,
            worst,
            per_list_constraints,
            batch_constraints,
            item_metadata,
        )
        if not made_progress:
            break

    return lists
|
|
815
|
+
|
|
816
|
+
def _improve_batch_constraint(
    self,
    lists: list[ExperimentList],
    constraint: BatchConstraint,
    list_constraints: list[ListConstraint],
    batch_constraints: list[BatchConstraint],
    metadata: MetadataDict,
    n_attempts: int = 100,
) -> bool:
    """Attempt to improve a batch constraint through a single item swap.

    Up to ``n_attempts`` random swaps of one item between two distinct lists
    are tried. The first accepted swap is kept (``lists`` is mutated in
    place) and the method returns ``True``. Every rejected swap is reverted.
    A swap is accepted only if all of the following hold:

    * the score of ``constraint`` strictly increases;
    * neither touched list ends up with more list-constraint violations
      than it had before;
    * no other constraint in ``batch_constraints`` loses score.

    Parameters
    ----------
    lists : list[ExperimentList]
        Current lists. Modified in place when a swap is accepted.
    constraint : BatchConstraint
        Constraint to improve.
    list_constraints : list[ListConstraint]
        Per-list constraints whose violation counts must not grow.
    batch_constraints : list[BatchConstraint]
        All batch constraints. Those other than ``constraint`` must not be
        made worse by an accepted swap.
    metadata : MetadataDict
        Item metadata.
    n_attempts : int, default=100
        Number of swap attempts.

    Returns
    -------
    bool
        True if a swap was kept, False if no attempt was accepted.
    """
    # A swap needs two distinct lists.
    if len(lists) < 2:
        return False

    current_score = self._compute_batch_constraint_score(lists, constraint, metadata)

    # Baseline scores for the other batch constraints. Requiring them not to
    # drop prevents the refinement loop from trading one constraint off
    # against another and oscillating.
    other_baselines = [
        (other, self._compute_batch_constraint_score(lists, other, metadata))
        for other in batch_constraints
        if other is not constraint
    ]

    # Violation counts of unswapped lists. Rejected swaps are always
    # reverted, so a list's baseline stays valid for the whole call; it is
    # computed lazily (at most once per list) while the list is unswapped.
    baseline_violations: dict[int, int] = {}

    def violations_before(list_idx: int) -> int:
        if list_idx not in baseline_violations:
            baseline_violations[list_idx] = self._count_violations(
                lists[list_idx], list_constraints, metadata
            )
        return baseline_violations[list_idx]

    for _ in range(n_attempts):
        # integers returns a numpy integer for scalar draws; cast to int.
        list_idx_a = int(self._rng.integers(0, len(lists)))
        list_idx_b = int(self._rng.integers(0, len(lists)))
        if list_idx_a == list_idx_b:
            continue

        list_a = lists[list_idx_a]
        list_b = lists[list_idx_b]
        if len(list_a.item_refs) == 0 or len(list_b.item_refs) == 0:
            continue

        item_idx_a = int(self._rng.integers(0, len(list_a.item_refs)))
        item_idx_b = int(self._rng.integers(0, len(list_b.item_refs)))
        item_a = list_a.item_refs[item_idx_a]
        item_b = list_b.item_refs[item_idx_b]

        # Baselines must be captured before the lists are modified.
        allowed_a = violations_before(list_idx_a)
        allowed_b = violations_before(list_idx_b)

        # Perform swap.
        list_a.item_refs[item_idx_a] = item_b
        list_b.item_refs[item_idx_b] = item_a

        # Cheapest check first; the remaining checks run only when needed.
        accepted = (
            self._compute_batch_constraint_score(lists, constraint, metadata)
            > current_score
            and self._count_violations(list_a, list_constraints, metadata)
            <= allowed_a
            and self._count_violations(list_b, list_constraints, metadata)
            <= allowed_b
            and all(
                self._compute_batch_constraint_score(lists, other, metadata)
                >= baseline
                for other, baseline in other_baselines
            )
        )
        if accepted:
            return True

        # Revert swap.
        list_a.item_refs[item_idx_a] = item_a
        list_b.item_refs[item_idx_b] = item_b

    return False
|
|
902
|
+
|
|
903
|
+
def _compute_batch_constraint_score(
    self,
    lists: list[ExperimentList],
    constraint: BatchConstraint,
    metadata: MetadataDict,
) -> float:
    """Compute satisfaction score for a batch constraint.

    Dispatches on the constraint's concrete type to the matching
    ``_compute_batch_*_score`` helper.

    Parameters
    ----------
    lists : list[ExperimentList]
        All lists in the batch.
    constraint : BatchConstraint
        Batch constraint to check.
    metadata : MetadataDict
        Item metadata.

    Returns
    -------
    float
        Satisfaction score in [0, 1].

    Raises
    ------
    TypeError
        If ``constraint`` is not one of the supported batch constraint
        types. Previously an unknown type fell through to min-occurrence
        scoring and failed obscurely or scored 1.0.
    """
    if isinstance(constraint, BatchCoverageConstraint):
        return self._compute_batch_coverage_score(lists, constraint, metadata)
    if isinstance(constraint, BatchBalanceConstraint):
        return self._compute_batch_balance_score(lists, constraint, metadata)
    if isinstance(constraint, BatchDiversityConstraint):
        return self._compute_batch_diversity_score(lists, constraint, metadata)
    if isinstance(constraint, BatchMinOccurrenceConstraint):
        return self._compute_batch_min_occurrence_score(lists, constraint, metadata)
    raise TypeError(
        f"Unsupported batch constraint type: {type(constraint).__name__}"
    )
|
|
933
|
+
|
|
934
|
+
def _compute_batch_coverage_score(
    self,
    lists: list[ExperimentList],
    constraint: BatchCoverageConstraint,
    metadata: MetadataDict,
) -> float:
    """Compute coverage score across all lists.

    The score is the fraction of ``constraint.target_values`` that occurs as
    the evaluated property of at least one item anywhere in the batch.

    Parameters
    ----------
    lists : list[ExperimentList]
        All lists in the batch.
    constraint : BatchCoverageConstraint
        Coverage constraint.
    metadata : MetadataDict
        Item metadata.

    Returns
    -------
    float
        Coverage ratio (covered target values / distinct target values).
        This is 1.0 when ``target_values`` is None or empty.
    """
    targets = constraint.target_values
    # No targets means trivially covered; skip the (DSL-evaluating) scan.
    if targets is None or len(targets) == 0:
        return 1.0

    target_set: set[int | float | str | bool] = set(targets)
    covered: set[int | float | str | bool] = set()

    for exp_list in lists:
        for item_id in exp_list.item_refs:
            try:
                value = self._extract_property_value(
                    item_id,
                    constraint.property_expression,
                    constraint.context,
                    metadata,
                )
                if value in target_set:
                    covered.add(value)
            except Exception:
                # Best-effort: items that are missing from metadata, fail to
                # evaluate or yield unhashable values do not count.
                continue
            # Every target has been seen; the rest of the batch cannot change
            # the score.
            if len(covered) == len(target_set):
                return 1.0

    return float(len(covered) / len(target_set))
|
|
981
|
+
|
|
982
|
+
def _compute_batch_balance_score(
    self,
    lists: list[ExperimentList],
    constraint: BatchBalanceConstraint,
    metadata: MetadataDict,
) -> float:
    """Score how closely the batch-wide value distribution matches a target.

    The property is evaluated for every item in every list. Items whose
    evaluation fails are skipped. For each value in
    ``constraint.target_distribution``, the observed proportion is compared
    with the target proportion. The score is ``1 - largest absolute gap``,
    floored at 0.

    Parameters
    ----------
    lists : list[ExperimentList]
        All lists in the batch.
    constraint : BatchBalanceConstraint
        Balance constraint.
    metadata : MetadataDict
        Item metadata.

    Returns
    -------
    float
        Score in [0, 1]. This is 1.0 when no item could be evaluated.
    """
    tally: Counter[Any] = Counter()

    for exp_list in lists:
        for item_id in exp_list.item_refs:
            try:
                tally[
                    self._extract_property_value(
                        item_id,
                        constraint.property_expression,
                        constraint.context,
                        metadata,
                    )
                ] += 1
            except Exception:
                # Best-effort: unevaluable items are left out of the tally.
                continue

    total = sum(tally.values())
    if total == 0:
        return 1.0

    gaps = [
        abs(tally.get(value, 0) / total - target_prob)
        for value, target_prob in constraint.target_distribution.items()
    ]
    largest_gap = max([0.0, *gaps])

    return float(max(0.0, 1.0 - largest_gap))
|
|
1038
|
+
|
|
1039
|
+
def _compute_batch_diversity_score(
    self,
    lists: list[ExperimentList],
    constraint: BatchDiversityConstraint,
    metadata: MetadataDict,
) -> float:
    """Score how well each value stays confined to few lists.

    A value "violates" the constraint when it appears in more than
    ``constraint.max_lists_per_value`` distinct lists. The score is
    ``1 - violating values / distinct values``, floored at 0. Items whose
    property cannot be evaluated are skipped.

    Parameters
    ----------
    lists : list[ExperimentList]
        All lists in the batch.
    constraint : BatchDiversityConstraint
        Diversity constraint.
    metadata : MetadataDict
        Item metadata.

    Returns
    -------
    float
        Score in [0, 1]. This is 1.0 when there are no violations or no
        evaluable items.
    """
    lists_by_value: defaultdict[Any, set[int]] = defaultdict(set)

    for position, exp_list in enumerate(lists):
        for item_id in exp_list.item_refs:
            try:
                value = self._extract_property_value(
                    item_id,
                    constraint.property_expression,
                    constraint.context,
                    metadata,
                )
                lists_by_value[value].add(position)
            except Exception:
                # Best-effort: unevaluable items are ignored.
                continue

    if not lists_by_value:
        return 1.0

    cap = constraint.max_lists_per_value
    over_spread = sum(
        1 for positions in lists_by_value.values() if len(positions) > cap
    )
    score = 1.0 - over_spread / max(len(lists_by_value), 1)

    return float(max(0.0, score))
|
|
1091
|
+
|
|
1092
|
+
def _compute_batch_min_occurrence_score(
    self,
    lists: list[ExperimentList],
    constraint: BatchMinOccurrenceConstraint,
    metadata: MetadataDict,
) -> float:
    """Compute minimum occurrence score across all lists.

    The property is evaluated for every item in the batch, and each distinct
    value is counted. Items whose evaluation fails are skipped. The score is
    the count of the rarest observed value divided by
    ``constraint.min_occurrences``, clipped to [0, 1]. Only values that occur
    at least once are in the tally, so a value that never appears does not
    lower the score.

    Parameters
    ----------
    lists : list[ExperimentList]
        All lists in the batch.
    constraint : BatchMinOccurrenceConstraint
        Minimum occurrence constraint.
    metadata : MetadataDict
        Item metadata.

    Returns
    -------
    float
        Score in [0, 1] based on the minimum count ratio. This is 1.0 when
        nothing could be evaluated or when ``min_occurrences`` is not
        positive.
    """
    # Any count satisfies a non-positive minimum. A zero minimum would
    # otherwise raise ZeroDivisionError below, and a negative one would be
    # clipped to a spurious 0.0.
    if constraint.min_occurrences <= 0:
        return 1.0

    counts: Counter[Any] = Counter()

    for exp_list in lists:
        for item_id in exp_list.item_refs:
            try:
                value = self._extract_property_value(
                    item_id,
                    constraint.property_expression,
                    constraint.context,
                    metadata,
                )
                counts[value] += 1
            except Exception:
                # Best-effort: unevaluable or unhashable values are skipped.
                continue

    if not counts:
        return 1.0

    # Dividing by a positive constant preserves order, so the rarest count
    # gives the minimum ratio.
    min_ratio = min(counts.values()) / constraint.min_occurrences

    return float(min(1.0, max(0.0, min_ratio)))
|