openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
gabriel/tasks/rank.py
ADDED
@@ -0,0 +1,2041 @@
"""
rank.py
~~~~~~~~

This module implements a simplified yet fully featured ranking engine for
evaluating pairs of passages on a set of attributes. It draws heavy
inspiration from the existing ``elo.py`` implementation found in the
GABRIEL distribution but removes support for the classic Elo rating
system and focuses solely on the Bradley–Terry (BT) style approach.

Key improvements and changes relative to ``elo.py`` include:

* A streamlined configuration dataclass (`RankConfig`) that exposes the
  parameters most relevant to the BT method. Irrelevant options
  (e.g. ``rating_method``, ``k_factor``) have been removed, and
  parameter names have been harmonised with the high‑level API
  described in the calling code. ``file_name`` is now treated as a
  stem; if an extension is provided it will be stripped automatically.

* Support for the new rankings prompt (``rankings_prompt.jinja2``)
  which allows the large language model to return one of four
  outcomes for each attribute: ``"circle"``, ``"square"``, ``"draw"``
  or ``"insufficient signal"``. ``draw`` and ``insufficient signal``
  are both treated as a tie and contribute equally to both items when
  fitting the BT model.

* A cleaned up asynchronous ``run`` method that accepts a pandas
  ``DataFrame`` and the name of the column containing the text to be
  ranked. Each row receives a stable identifier derived from a hash of its
  contents; no external ``id_col`` argument is required. The method
  produces a DataFrame with one row per input passage, a numeric
  rating for each attribute, along with z‑scores and standard errors,
  and writes the results to disk under ``save_dir``.

The core ranking logic remains largely unchanged from ``elo.py``
because the underlying mathematics of the BT model and the pairing
strategies continue to work well. However, comments have been added
throughout the code to clarify intent and to highlight areas where
further experimentation (e.g. alternative information gain metrics) can
be incorporated.
"""

from __future__ import annotations

import os
from pathlib import Path
import random
import hashlib
import math
import copy
from dataclasses import dataclass, field, fields
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

import numpy as np
import pandas as pd

# Import helper utilities from the gabriel package. These modules are
# expected to be available in the runtime environment. Should you wish
# to run this module outside of the GABRIEL distribution, you may need
# to adjust these imports accordingly.
from gabriel.core.prompt_template import PromptTemplate, resolve_template
from gabriel.utils.openai_utils import get_all_responses
from gabriel.utils import (
    safest_json,
    load_image_inputs,
    load_audio_inputs,
    load_pdf_inputs,
    warn_if_modality_mismatch,
)
from gabriel.utils.logging import announce_prompt_rendering
from .rate import Rate, RateConfig
from ._attribute_utils import load_persisted_attributes


@dataclass
class RankConfig:
    """User‑visible configuration for :class:`Rank`.

    Only a minimal set of parameters are exposed to keep the API
    straightforward. Additional hyperparameters for the underlying
    Bradley–Terry model and pairing heuristics are fixed at sensible
    values and should not generally need to be changed. See the
    surrounding documentation for more details.

    Parameters
    ----------
    attributes:
        Mapping from attribute names to definitions. A list of
        attribute names is also accepted; definitions will be set to
        empty strings.
    n_rounds:
        Number of rounds of pairwise comparisons to perform.
    matches_per_round:
        Number of matches per item per round.
    power_matching:
        Whether to use an information‑theoretic pairing heuristic. The
        resulting rankings always include per‑attribute z‑scores alongside the
        raw Bradley–Terry estimates (``"<attribute>_raw"``) and their
        standard errors (``"<attribute>_se"``).
    learning_rate:
        Pseudo‑count used by the BT model to regularise the win/loss
        matrix. A larger value makes updates more conservative.
    model:
        Name of the language model to call via ``get_all_responses``.
    n_parallels:
        Number of parallel API calls to issue.
    use_dummy:
        Whether to use a dummy model for testing purposes.
    save_dir:
        Directory into which result files should be saved.
    file_name:
        Stem for the output CSV files. If an extension is present it
        will be removed.
    additional_instructions:
        Extra, user‑supplied instructions passed to the prompt.
    recursive:
        When ``True`` run ranking in multiple stages, pruning the pool
        of candidates between stages according to ``recursive_fraction``
        and ``recursive_min_remaining``.
    recursive_fraction, recursive_min_remaining,
    recursive_final_round_multiplier:
        Parameters controlling how many items are kept between stages
        and how many rounds are executed in the final stage when
        ``recursive`` is enabled.
    recursive_cut_attr, recursive_cut_side:
        Select which attribute and direction are used when choosing
        which items survive to the next stage.
    recursive_rate_first_round:
        If ``True`` perform a :class:`Rate` sweep before the first
        recursive stage and seed subsequent rounds with those scores.
        This is enabled by default so the initial culling uses grounded
        single-pass ratings; set to ``False`` to skip.
    recursive_rewrite_func, recursive_rewrite_text_col:
        Optional hook to rewrite surviving passages between stages and
        the column where rewritten text should be stored.
    recursive_keep_stage_columns, recursive_add_stage_suffix:
        Control whether intermediate stage outputs are merged into the
        final results and whether their columns receive stage prefixes.
    max_timeout:
        Optional upper bound for individual API calls when retrieving
        ranking judgements. ``None`` (default) lets the timeout be
        derived dynamically from observed latencies in
        :func:`gabriel.utils.openai_utils.get_all_responses`.
    initial_rating_pass:
        Enables a one-off :class:`Rate` pass before standard ranking
        rounds. The centred scores from that pass seed the initial
        Bradley–Terry ratings which helps pairing focus on refinement.
        Enabled by default; set ``initial_rating_pass=False`` if you
        want to start directly with pairwise comparisons.
    rate_kwargs:
        Optional dictionary of overrides forwarded to the rating task
        whenever it is invoked (either as a seed or during recursion).
    primer_scores, primer_scale, primer_center:
        Optional manual primers to seed the Bradley–Terry rating state.
        Scores are centred per attribute when ``primer_center`` is
        ``True`` and scaled by ``primer_scale``.
    """

    attributes: Union[Dict[str, str], List[str]]
    n_rounds: int = 5
    matches_per_round: int = 5
    power_matching: bool = True
    learning_rate: float = 0.1
    model: str = "gpt-5-mini"
    n_parallels: int = 650
    use_dummy: bool = False
    save_dir: str = os.path.expanduser("~/Documents/runs")
    file_name: str = "rankings"
    additional_instructions: Optional[str] = None
    circle_first: Optional[bool] = None
    modality: str = "text"
    n_attributes_per_run: int = 8
    reasoning_effort: Optional[str] = None
    reasoning_summary: Optional[str] = None
    max_timeout: Optional[float] = None
    # Recursive execution controls
    recursive: bool = False
    recursive_fraction: float = 1.0 / 3.0
    recursive_min_remaining: int = 30
    recursive_final_round_multiplier: int = 3
    recursive_cut_attr: Optional[str] = None
    recursive_cut_side: str = "top"
    recursive_rate_first_round: bool = True
    recursive_rewrite_func: Optional[Callable[[str, str, int], str]] = None
    recursive_rewrite_text_col: str = "text"
    recursive_keep_stage_columns: bool = True
    recursive_add_stage_suffix: bool = True
    # Optional single pass rating seed controls
    initial_rating_pass: bool = True
    rate_kwargs: Dict[str, Any] = field(default_factory=dict)
    # Optional manual primers to seed ratings (applies to both recursive and
    # non-recursive runs). Mapping from identifier -> {attribute: score}.
    # Scores are centred by attribute and scaled by ``primer_scale`` before
    # being injected into the Bradley–Terry state.
    primer_scores: Optional[Dict[str, Dict[str, float]]] = None
    primer_scale: float = 1.0
    primer_center: bool = True

    def __post_init__(self) -> None:
        if self.additional_instructions is not None:
            cleaned = str(self.additional_instructions).strip()
            self.additional_instructions = cleaned or None

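# ---------------------------------------------------------------------------
# Editorial illustration (not part of the packaged module): a minimal sketch
# of how the configuration above is typically wired together, based only on
# the ``Rank``/``RankConfig`` API visible in this file. The attribute name,
# passages, and file name below are hypothetical examples.
#
#     import asyncio
#     import pandas as pd
#     from gabriel.tasks.rank import Rank, RankConfig
#
#     cfg = RankConfig(
#         attributes={"clarity": "How clearly the passage communicates its point."},
#         n_rounds=3,
#         matches_per_round=4,
#         file_name="clarity_rankings",
#     )
#     df = pd.DataFrame({"text": ["passage one", "passage two", "passage three"]})
#     ranked = asyncio.run(Rank(cfg).run(df, column_name="text"))
# ---------------------------------------------------------------------------

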
class Rank:
    """Rank items by comparing passages pairwise on multiple attributes.

    An instance of :class:`Rank` orchestrates the iterative process
    of sampling pairs, calling a language model to adjudicate which
    passage better exhibits each attribute, and then fitting a
    Bradley–Terry model to those outcomes. Standard errors and
    z‑scores are computed for every attribute. Results are persisted to disk
    after the final round.
    """

    def __init__(
        self,
        cfg: RankConfig,
        template: Optional[PromptTemplate] = None,
        template_path: Optional[str] = None,
    ) -> None:
        """Instantiate a ranking engine.

        Parameters
        ----------
        cfg:
            User‑provided configuration.
        template:
            Optional :class:`gabriel.core.prompt_template.PromptTemplate` to
            render the comparison prompts. If not supplied, the built‑in
            ``rankings_prompt.jinja2`` template is used.
        template_path:
            Path to a custom prompt template on disk. The template is
            validated to ensure it expects the same variables as the
            built‑in template.
        """
        expanded = Path(os.path.expandvars(os.path.expanduser(cfg.save_dir)))
        expanded.mkdir(parents=True, exist_ok=True)
        cfg.save_dir = str(expanded)
        self.cfg = cfg
        self.template = resolve_template(
            template=template,
            template_path=template_path,
            reference_filename="rankings_prompt.jinja2",
        )
        # random state; a seed is intentionally omitted from the public
        # configuration to discourage brittle behaviour. If
        # reproducibility is required, modify this line to pass a
        # specific seed.
        self.rng = random.Random()
        # place holders for multiway rankings and aggregated standard errors
        self.history_multi: Dict[str, List[List[str]]] = {}
        self._last_se_agg: Optional[Dict[str, float]] = None

        # internal constants for the pairing and BT algorithms. These
        # values are deliberately not exposed through the public API as
        # they seldom need tuning and adjusting them can complicate
        # reproducibility. Should you need to experiment with these
        # hyperparameters, modify the values below.
        self._EXPLORE_FRAC = 0.2  # fraction of random pairings per round
        self._CANDIDATE_NEIGHBORS = 20  # neighbourhood size for info gain pairing
        self._HIGH_SE_FRAC = 0.25  # fraction of high‑uncertainty items
        self._MAX_ITER = 1000  # maximum iterations for BT optimisation
        self._TOL = 1e-6  # convergence tolerance for BT
        # A small ridge term stabilises the inversion of the Fisher information
        # matrix when computing standard errors. The previous value (1e‑9)
        # occasionally led to extremely large uncertainties for items with
        # limited or contradictory comparisons. Increasing this value
        # regularises the covariance estimate and prevents unreasonably
        # large standard errors. If you observe inflated SE values,
        # consider increasing this further (e.g. to 1e‑4).
        self._SE_RIDGE = 1e-5
        # The maximum number of candidate pairs to consider per pairing round.
        # When the number of items becomes very large (e.g. tens of thousands),
        # evaluating all possible pairs is intractable. We therefore cap the
        # total number of candidate pairs by limiting the neighbourhood size
        # used when constructing candidate pairs. The default of 200k ensures
        # that information gain pairing remains tractable even with very
        # large data sets: for example, with 10 000 items and a cap of
        # 200 000, each item will only consider approximately 20 neighbours.
        self._MAX_CANDIDATE_PAIRS_PER_ROUND = 200_000

    # ------------------------------------------------------------------
    def _apply_primer(
        self,
        ratings: Dict[str, Dict[str, float]],
        primer: Optional[Dict[str, Dict[str, float]]],
        attr_keys: List[str],
    ) -> None:
        """Inject user-provided primer scores into the rating state.

        Primers are centred per-attribute if ``primer_center`` is True and
        scaled by ``primer_scale``. Missing attributes are ignored.
        """

        if not primer:
            return

        # normalise per attribute
        attr_to_vals: Dict[str, List[float]] = {a: [] for a in attr_keys}
        for ident, amap in primer.items():
            if ident not in ratings:
                continue
            for attr in attr_keys:
                if attr in amap and amap[attr] is not None:
                    try:
                        attr_to_vals[attr].append(float(amap[attr]))
                    except Exception:
                        continue

        attr_offset: Dict[str, float] = {a: 0.0 for a in attr_keys}
        if self.cfg.primer_center:
            for attr, vals in attr_to_vals.items():
                if vals:
                    attr_offset[attr] = float(np.mean(vals))

        scale = self.cfg.primer_scale or 1.0
        for ident, amap in primer.items():
            if ident not in ratings:
                continue
            for attr in attr_keys:
                if attr not in amap or amap[attr] is None:
                    continue
                try:
                    val = (float(amap[attr]) - attr_offset[attr]) * scale
                    ratings[ident][attr] = val
                except Exception:
                    continue

    # ------------------------------------------------------------------
    # Public API for adding multiway rankings
    # ------------------------------------------------------------------
    def add_multiway_ranking(self, attr: str, ranking: List[str]) -> None:
        """Record a multiway ranking for a given attribute.

        Multiway rankings are stored but not used by the current BT
        implementation. They are retained for potential future
        extensions where a Plackett–Luce model could be incorporated.
        """
        if attr not in self.history_multi:
            self.history_multi[attr] = []
        self.history_multi[attr].append(ranking)

    def _attributes_as_dict(self) -> Dict[str, str]:
        if isinstance(self.cfg.attributes, dict):
            return dict(self.cfg.attributes)
        return {attr: "" for attr in self.cfg.attributes}

    def _split_rate_kwargs(self, overrides: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        merged: Dict[str, Any] = {}
        if self.cfg.rate_kwargs:
            merged.update(self.cfg.rate_kwargs)
        if overrides:
            merged.update(overrides)
        config_fields = {f.name for f in fields(RateConfig)}
        cfg_kwargs: Dict[str, Any] = {}
        run_kwargs: Dict[str, Any] = {}
        for key, value in merged.items():
            if key in config_fields:
                cfg_kwargs[key] = value
            else:
                run_kwargs[key] = value
        return cfg_kwargs, run_kwargs

    async def _run_rate_pass(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        save_dir: str,
        file_name: str,
        reset_files: bool,
        rate_kwargs: Optional[Dict[str, Any]] = None,
        runtime_kwargs: Optional[Dict[str, Any]] = None,
    ) -> pd.DataFrame:
        cfg_overrides, run_kwargs = self._split_rate_kwargs(rate_kwargs)
        rate_cfg = RateConfig(
            attributes=self._attributes_as_dict(),
            save_dir=save_dir,
            file_name=file_name,
            model=self.cfg.model,
            n_parallels=self.cfg.n_parallels,
            n_runs=1,
            use_dummy=self.cfg.use_dummy,
            additional_instructions=self.cfg.additional_instructions or "",
            modality=self.cfg.modality,
            n_attributes_per_run=self.cfg.n_attributes_per_run,
            reasoning_effort=self.cfg.reasoning_effort,
            reasoning_summary=self.cfg.reasoning_summary,
            max_timeout=self.cfg.max_timeout,
        )
        for key, value in cfg_overrides.items():
            setattr(rate_cfg, key, value)
        combined_kwargs = dict(run_kwargs)
        if runtime_kwargs:
            combined_kwargs.update(runtime_kwargs)
        combined_kwargs.setdefault("web_search", self.cfg.modality == "web")
        rate_task = Rate(rate_cfg)
        return await rate_task.run(
            df,
            column_name,
            reset_files=reset_files,
            **combined_kwargs,
        )

    def _seed_ratings_from_rate(
        self,
        rate_df: pd.DataFrame,
        *,
        id_column: Optional[str],
        text_column: str,
        item_ids: Sequence[str],
        attr_keys: Sequence[str],
    ) -> Dict[str, Dict[str, float]]:
        if rate_df.empty:
            return {}
        attr_cols = [attr for attr in attr_keys if attr in rate_df.columns]
        if not attr_cols:
            return {}
        if id_column and id_column in rate_df.columns:
            key_series = rate_df[id_column].astype(str)
        elif text_column in rate_df.columns:
            key_series = rate_df[text_column].astype(str).map(
                lambda x: hashlib.sha1(x.encode()).hexdigest()[:8]
            )
        else:
            return {}
        stage_df = pd.DataFrame({"_id": key_series})
        for attr in attr_cols:
            stage_df[attr] = pd.to_numeric(rate_df[attr], errors="coerce")
        grouped = stage_df.groupby("_id")[attr_cols].mean()
        seeds: Dict[str, Dict[str, float]] = {}
        for attr in attr_cols:
            series = grouped[attr].dropna()
            if series.empty:
                continue
            mean_val = float(series.mean())
            centred = series - mean_val
            for item_id, value in centred.items():
                seeds.setdefault(item_id, {})[attr] = float(value)
        # Only retain seeds for items that will appear in the ranking loop
        return {item_id: seeds[item_id] for item_id in item_ids if item_id in seeds}

    # ------------------------------------------------------------------
    # BT / PL fitting utilities
    # ------------------------------------------------------------------
    def _fit_bt(
        self,
        item_ids: List[str],
        outcomes: List[Tuple[str, str]],
        pseudo: float,
        max_iter: int,
        tol: float,
        return_info: bool = False,
    ) -> Union[Dict[str, float], Tuple[Dict[str, float], np.ndarray, np.ndarray]]:
        """Fit a Bradley–Terry model given pairwise outcomes.

        Parameters
        ----------
        item_ids:
            List of unique item identifiers.
        outcomes:
            List of tuples ``(winner, loser)`` representing outcomes of
            pairwise matches. Ties can be represented by including
            both ``(a, b)`` and ``(b, a)`` in the list; each entry
            contributes a single increment to the win matrix.
        pseudo:
            Pseudo count added to both win and total match counts. Acts
            as a smoothing prior.
        max_iter, tol:
            Control convergence of the iterative fixed‑point updates.
        return_info:
            If ``True`` return the intermediate matrices ``n_ij`` and
            ``p_ij`` for downstream standard error computation.

        Returns
        -------
        scores : dict
            Mapping from item identifier to estimated log‑skill.
        (scores, n_ij, p_ij) : tuple
            When ``return_info`` is ``True``, also return the total
            match counts and predicted win probabilities for each pair.
        """
        n = len(item_ids)
        idx = {item: i for i, item in enumerate(item_ids)}
        # win matrix; wins[i,j] counts how many times i beat j
        wins = np.zeros((n, n), dtype=float)
        for w, l in outcomes:
            if w in idx and l in idx:
                wins[idx[w], idx[l]] += 1.0
        # total matches between each pair
        n_ij = wins + wins.T
        # total wins for each item
        w_i = wins.sum(axis=1)
        # add pseudo counts
        n_ij += pseudo
        w_i += pseudo
        # initialise skill parameters uniformly
        p = np.ones(n, dtype=float)
        for _ in range(max_iter):
            # denominator for each player in the fixed point update
            denom = (n_ij / (p[:, None] + p[None, :])).sum(axis=1)
            p_new = w_i / denom
            if np.max(np.abs(p_new - p)) < tol:
                p = p_new
                break
            p = p_new
        # convert to log space and centre at zero mean
        s = np.log(p)
        s -= s.mean()
        if not return_info:
            return {item: float(val) for item, val in zip(item_ids, s)}
        # predicted win probabilities between each pair
        exp_s = np.exp(s)
        p_ij = exp_s[:, None] / (exp_s[:, None] + exp_s[None, :])
        return {item: float(val) for item, val in zip(item_ids, s)}, n_ij, p_ij

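    # ----------------------------------------------------------------------
    # Editorial illustration (not part of the packaged module): ``_fit_bt``
    # iterates the standard Bradley–Terry MM update
    #     p_i <- w_i / sum_j [ n_ij / (p_i + p_j) ]
    # where ``w_i`` is item i's (pseudo-count smoothed) win total and ``n_ij``
    # the smoothed match count between i and j. A self-contained numpy sketch
    # of the same update on hypothetical data (A beat B and C, B beat C):
    #
    #     import numpy as np
    #     wins = np.array([[0.0, 1.0, 1.0],
    #                      [0.0, 0.0, 1.0],
    #                      [0.0, 0.0, 0.0]])
    #     n_ij = wins + wins.T + 0.1      # pseudo-count smoothing
    #     w_i = wins.sum(axis=1) + 0.1
    #     p = np.ones(3)
    #     for _ in range(1000):
    #         p = w_i / (n_ij / (p[:, None] + p[None, :])).sum(axis=1)
    #     s = np.log(p) - np.log(p).mean()  # centred log-skills, s[0] > s[1] > s[2]
    # ----------------------------------------------------------------------
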
    def _bt_standard_errors(
        self,
        s: np.ndarray,
        n_ij: np.ndarray,
        p_ij: np.ndarray,
        ridge: float,
    ) -> np.ndarray:
        """Estimate standard errors for BT skill parameters.

        The observed Fisher information for the Bradley–Terry model is given by
        ``I = diag(q 1) - q`` where ``q = n_ij * p_ij * (1 - p_ij)`` encodes the
        uncertainty contributed by each pairwise comparison (Ford, 1957). The
        estimates satisfy a sum-to-zero constraint, so the Fisher information is
        rank deficient with the all-ones vector in its null space. Instead of
        selecting an arbitrary reference item (which previously produced
        inflated standard errors for that reference when it received few
        comparisons), we project the matrix onto the constrained subspace and
        take its Moore–Penrose pseudoinverse. A small ridge term stabilises the
        inversion for sparse comparison graphs. The standard error for item ``i``
        is the square root of the ``i``-th diagonal entry of the resulting
        covariance matrix.

        Parameters
        ----------
        s : np.ndarray
            Array of estimated log-skills for each item.
        n_ij : np.ndarray
            Matrix of total match counts between items (wins + losses).
        p_ij : np.ndarray
            Matrix of predicted win probabilities between items.
        ridge : float
            Small constant added to the diagonal of the projected Fisher
            information matrix for numerical stability.

        Returns
        -------
        np.ndarray
            Array of standard errors corresponding to each element of ``s``.
        """

        n = len(s)
        if n == 0:
            return np.array([], dtype=float)
        if n == 1:
            return np.zeros(1, dtype=float)

        q_ij = n_ij * p_ij * (1 - p_ij)
        diag = q_ij.sum(axis=1)
        I = np.diag(diag) - q_ij
        I = np.nan_to_num(I)
        ones = np.ones((n, 1))
        proj = np.eye(n) - ones @ ones.T / n
        I_proj = proj @ I @ proj
        I_proj[np.diag_indices(n)] += ridge
        try:
            cov = np.linalg.pinv(I_proj, rcond=1e-12)
        except np.linalg.LinAlgError:
            cov = np.linalg.pinv(np.nan_to_num(I_proj), rcond=1e-12)
        cov = proj @ cov @ proj
        cov = 0.5 * (cov + cov.T)
        se = np.sqrt(np.clip(np.diag(cov), 0.0, None))
        return np.nan_to_num(se)

    def _fit_pl(
        self,
        item_ids: List[str],
        rankings: List[List[str]],
        pseudo: float,
        max_iter: int,
        tol: float,
    ) -> Dict[str, float]:
        """Fit a Plackett–Luce model for multiway rankings.

        When every ranking is of length two this reduces to the BT
        model and defers to :meth:`_fit_bt`. If no rankings are
        provided a zero‑centred score is returned for each item. See
        Hunter (2004) for details on the fitting procedure.
        """
        if not rankings:
            return {i: 0.0 for i in item_ids}
        # if all rankings are of length 2, delegate to BT
        if all(len(r) == 2 for r in rankings):
            outcomes = [(r[0], r[1]) for r in rankings]
            return self._fit_bt(
                item_ids, outcomes, pseudo, max_iter, tol, return_info=False
            )
        n = len(item_ids)
        idx = {item: i for i, item in enumerate(item_ids)}
        w_i = np.zeros(n, dtype=float)
        rankings_idx = []
        for r in rankings:
            r_idx = [idx[x] for x in r if x in idx]
            if len(r_idx) < 2:
                continue
            rankings_idx.append(r_idx)
            for i_ in r_idx:
                w_i[i_] += 1.0
        if len(rankings_idx) == 0:
            return {i: 0.0 for i in item_ids}
        w_i += pseudo
        p = np.ones(n, dtype=float)
        for _ in range(max_iter):
            denom = np.zeros(n, dtype=float)
            for r_idx in rankings_idx:
                remaining = np.array(r_idx, dtype=int)
                sum_p = p[remaining].sum()
                for i_ in r_idx:
                    denom[i_] += 1.0 / sum_p
                    sum_p -= p[i_]
            denom[denom == 0] = 1e-12
            p_new = w_i / denom
            if np.max(np.abs(p_new - p)) < tol:
                p = p_new
                break
            p = p_new
        s = np.log(p)
        s -= s.mean()
        return {item: float(val) for item, val in zip(item_ids, s)}

    # ------------------------------------------------------------------
    # Pairing strategies
    # ------------------------------------------------------------------
    def _pairs_random(
        self, item_ids: List[str], texts_by_id: Dict[str, str], mpr: int
    ) -> List[Tuple[Tuple[str, str], Tuple[str, str]]]:
        """Return a set of random, unique pairs for the given items."""
        pairs_set: Set[Tuple[str, str]] = set()
        for a in item_ids:
            others = [x for x in item_ids if x != a]
            if not others:
                continue
            k = min(mpr, len(others))
            opponents = self.rng.sample(others, k)
            for b in opponents:
                pairs_set.add(tuple(sorted((a, b))))
        return [((a, texts_by_id[a]), (b, texts_by_id[b])) for a, b in pairs_set]

    def _pairs_adjacent(
        self,
        item_ids: List[str],
        texts_by_id: Dict[str, str],
        current_ratings: Dict[str, float],
        mpr: int,
    ) -> List[Tuple[Tuple[str, str], Tuple[str, str]]]:
        """Pair each item with its nearest neighbours in rating space."""
        pairs_set: Set[Tuple[str, str]] = set()
        sorted_ids = sorted(item_ids, key=lambda i: current_ratings[i])
        n = len(sorted_ids)
        for i, a in enumerate(sorted_ids):
            for off in range(1, mpr + 1):
                b = sorted_ids[(i + off) % n]
                if a == b:
                    continue
                pairs_set.add(tuple(sorted((a, b))))
        # small amount of random exploration to avoid pathological pairings
        n_random_targets = int(self._EXPLORE_FRAC * n * mpr)
        for _ in range(n_random_targets):
            if n < 2:
                break
            a, b = self.rng.sample(item_ids, 2)
            pairs_set.add(tuple(sorted((a, b))))
        return [((a, texts_by_id[a]), (b, texts_by_id[b])) for a, b in pairs_set]

    def _pairs_info_gain(
        self,
        item_ids: List[str],
        texts_by_id: Dict[str, str],
        current_ratings: Dict[str, float],
        se_agg: Dict[str, float],
        mpr: int,
    ) -> List[Tuple[Tuple[str, str], Tuple[str, str]]]:
        """Select pairs by maximising expected information gain while ensuring
        that every item participates in the prescribed number of matches.

        This implementation differs from the original heuristics by
        considering a bounded set of candidate pairs that scales with the
        number of items. Each pair is assigned a score based on the
        expected reduction in uncertainty (estimated from the current
        ratings and aggregated standard errors). Pairs with larger
        scores are chosen first, subject to the constraint that each
        item is matched exactly ``mpr`` times. If some items remain
        unmatched after exhausting the scored pairs, additional pairs
        are filled in randomly to satisfy the per‑item quota.
        """
        n = len(item_ids)
        if n < 2:
            return []
        max_pairs = max(1, self._MAX_CANDIDATE_PAIRS_PER_ROUND)
        desired_neighbors = max_pairs // max(1, n)
        candidate_neighbors = max(
            mpr, min(self._CANDIDATE_NEIGHBORS, desired_neighbors)
        )

        def logistic_clip(x: float) -> float:
            if x > 50:
                return 1.0
            if x < -50:
                return 0.0
            return 1.0 / (1.0 + np.exp(-x))

        ids_sorted = sorted(item_ids, key=lambda i: current_ratings[i])
        idx_of = {i_id: k for k, i_id in enumerate(ids_sorted)}
        num_high_se = max(1, int(self._HIGH_SE_FRAC * n))
        high_se_ids = sorted(item_ids, key=lambda i: se_agg.get(i, 1.0), reverse=True)[
            :num_high_se
        ]
        candidate_pairs_set: Set[Tuple[str, str]] = set()
        for i_id in item_ids:
            pos = idx_of[i_id]
            lower = max(0, pos - candidate_neighbors)
            upper = min(n, pos + candidate_neighbors + 1)
            for j in ids_sorted[lower:upper]:
                if i_id == j:
                    continue
                candidate_pairs_set.add(tuple(sorted((i_id, j))))
        for hs in high_se_ids:
            others = [x for x in item_ids if x != hs]
            k = min(candidate_neighbors, len(others))
            samp = self.rng.sample(others, k)
            for j in samp:
                candidate_pairs_set.add(tuple(sorted((hs, j))))
        remaining_capacity = max_pairs - len(candidate_pairs_set)
        n_random_targets = int(self._EXPLORE_FRAC * n * mpr)
        if remaining_capacity > 0:
            n_random_targets = min(n_random_targets, remaining_capacity)
        for _ in range(n_random_targets):
            if n < 2:
                break
            a, b = self.rng.sample(item_ids, 2)
            candidate_pairs_set.add(tuple(sorted((a, b))))
        partners_count = {i: 0 for i in item_ids}
        for a, b in candidate_pairs_set:
            partners_count[a] += 1
            partners_count[b] += 1
        for i_id in item_ids:
            while partners_count[i_id] < mpr:
                potential = [x for x in item_ids if x != i_id]
                if not potential:
                    break
                j = self.rng.choice(potential)
                pair = tuple(sorted((i_id, j)))
                if pair not in candidate_pairs_set:
                    candidate_pairs_set.add(pair)
                    partners_count[i_id] += 1
                    partners_count[j] += 1
                else:
                    partners_count[i_id] += 1
                    partners_count[j] += 1
        scored_pairs: List[Tuple[float, str, str]] = []
        for a, b in candidate_pairs_set:
            diff = current_ratings[a] - current_ratings[b]
            p = logistic_clip(diff)
            outcome_var = p * (1 - p)
            var_a = se_agg.get(a, 1.0) ** 2
            var_b = se_agg.get(b, 1.0) ** 2
            param_unc = var_a + var_b
            # Encourage comparisons between similarly‑rated items (high information
            # gain) while still prioritising uncertain pairs. The closeness term
            # dampens pairings with large rating gaps to tease out subtle ordering
            # differences.
            closeness = 1.0 / (1.0 + abs(diff))
            score = outcome_var * param_unc * closeness
            scored_pairs.append((score, a, b))
        scored_pairs.sort(key=lambda x: x[0], reverse=True)
        needed: Dict[str, int] = {i: mpr for i in item_ids}
        pairs_selected: List[Tuple[str, str]] = []
        pairs_seen: Set[Tuple[str, str]] = set()
        for score, a, b in scored_pairs:
            if needed[a] > 0 and needed[b] > 0:
                tup = (a, b) if a < b else (b, a)
                if tup in pairs_seen:
                    continue
                pairs_selected.append(tup)
                pairs_seen.add(tup)
                needed[a] -= 1
                needed[b] -= 1
        while any(cnt > 0 for cnt in needed.values()):
            ids_needing = [i for i, cnt in needed.items() if cnt > 0]
            if not ids_needing:
                break
            # Choose an item that still needs matches
            a = self.rng.choice(ids_needing)
            # Try to pair it with any other item (not just those needing matches) to avoid self‑pairs
            potential = [x for x in item_ids if x != a]
            if not potential:
                # Degenerate case: only one item exists; cannot form a valid pair
                break
            b = self.rng.choice(potential)
            tup = (a, b) if a < b else (b, a)
            pairs_selected.append(tup)
            needed[a] -= 1
            needed[b] -= 1
        return [((a, texts_by_id[a]), (b, texts_by_id[b])) for a, b in pairs_selected]

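    # ----------------------------------------------------------------------
    # Editorial illustration (not part of the packaged module): the pair score
    # used above multiplies the predicted outcome variance p * (1 - p), with
    # p = sigmoid(rating_a - rating_b), by the parameter uncertainty
    # se_a**2 + se_b**2 and by a closeness term 1 / (1 + |rating_a - rating_b|).
    # A hypothetical standalone version of the same rule:
    #
    #     import math
    #
    #     def pair_score(s_a, s_b, se_a, se_b):
    #         diff = max(-50.0, min(50.0, s_a - s_b))
    #         p = 1.0 / (1.0 + math.exp(-diff))
    #         return p * (1 - p) * (se_a ** 2 + se_b ** 2) / (1.0 + abs(s_a - s_b))
    #
    #     # Evenly matched, uncertain items score highest:
    #     pair_score(0.0, 0.1, 1.0, 1.0) > pair_score(2.0, -2.0, 1.0, 1.0)  # True
    # ----------------------------------------------------------------------
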
812
|
+
def _generate_pairs(
|
|
813
|
+
self,
|
|
814
|
+
item_ids: List[str],
|
|
815
|
+
texts_by_id: Dict[str, str],
|
|
816
|
+
current_ratings: Optional[Dict[str, float]],
|
|
817
|
+
se_agg: Optional[Dict[str, float]],
|
|
818
|
+
) -> List[Tuple[Tuple[str, str], Tuple[str, str]]]:
|
|
819
|
+
"""Dispatch to the appropriate pairing strategy."""
|
|
820
|
+
mpr = max(1, self.cfg.matches_per_round)
|
|
821
|
+
# Always use information gain pairing to guarantee exact match counts
|
|
822
|
+
if current_ratings is None:
|
|
823
|
+
current_ratings = {i: 0.0 for i in item_ids}
|
|
824
|
+
if se_agg is None or len(se_agg) != len(item_ids):
|
|
825
|
+
se_full = {i: 1.0 for i in item_ids}
|
|
826
|
+
else:
|
|
827
|
+
se_full = se_agg
|
|
828
|
+
return self._pairs_info_gain(
|
|
829
|
+
item_ids, texts_by_id, current_ratings, se_full, mpr
|
|
830
|
+
)
|
|
831
|
+
|
|
832
|
+
async def _catch_up_existing_rounds(
|
|
833
|
+
self,
|
|
834
|
+
new_ids: List[str],
|
|
835
|
+
round_indices: List[int],
|
|
836
|
+
item_ids: List[str],
|
|
837
|
+
texts_by_id: Dict[str, str],
|
|
838
|
+
images_by_id: Dict[str, List[str]],
|
|
839
|
+
audio_by_id: Dict[str, List[Dict[str, str]]],
|
|
840
|
+
attr_batches: List[List[str]],
|
|
841
|
+
attr_keys: List[str],
|
|
842
|
+
history_pairs: Dict[str, List[Tuple[str, str]]],
|
|
843
|
+
ratings: Dict[str, Dict[str, float]],
|
|
844
|
+
se_store: Dict[str, Dict[str, float]],
|
|
845
|
+
base_name: str,
|
|
846
|
+
df_proc: pd.DataFrame,
|
|
847
|
+
_write_checkpoint: Callable[[], None],
|
|
848
|
+
current_ratings: Optional[Dict[str, float]],
|
|
849
|
+
se_agg_local: Optional[Dict[str, float]],
|
|
850
|
+
reset_files: bool,
|
|
851
|
+
**kwargs: Any,
|
|
852
|
+
) -> None:
|
|
853
|
+
if not new_ids:
|
|
854
|
+
return
|
|
855
|
+
for rnd in round_indices:
|
|
856
|
+
round_path = os.path.join(self.cfg.save_dir, f"{base_name}_round{rnd}.csv")
|
|
857
|
+
if not os.path.exists(round_path):
|
|
858
|
+
continue
|
|
859
|
+
try:
|
|
860
|
+
df_round = pd.read_csv(round_path)
|
|
861
|
+
except Exception:
|
|
862
|
+
continue
|
|
863
|
+
counts: Dict[str, int] = {}
|
|
864
|
+
if {"IdA", "IdB"}.issubset(df_round.columns):
|
|
865
|
+
for a, b in zip(df_round["IdA"], df_round["IdB"]):
|
|
866
|
+
counts[str(a)] = counts.get(str(a), 0) + 1
|
|
867
|
+
counts[str(b)] = counts.get(str(b), 0) + 1
|
|
868
|
+
else:
|
|
869
|
+
for ident in df_round.get("Identifier", []):
|
|
870
|
+
parts = str(ident).split("|")
|
|
871
|
+
if len(parts) != 5:
|
|
872
|
+
continue
|
|
873
|
+
_, _, _, id_a, id_b = parts
|
|
874
|
+
counts[id_a] = counts.get(id_a, 0) + 1
|
|
875
|
+
counts[id_b] = counts.get(id_b, 0) + 1
|
|
876
|
+
pairs_needed: List[Tuple[str, str]] = []
|
|
877
|
+
for nid in new_ids:
|
|
878
|
+
needed = self.cfg.matches_per_round - counts.get(nid, 0)
|
|
879
|
+
if needed <= 0:
|
|
880
|
+
continue
|
|
881
|
+
opponents = [i for i in item_ids if i != nid]
|
|
882
|
+
self.rng.shuffle(opponents)
|
|
883
|
+
for opp in opponents[:needed]:
|
|
884
|
+
pairs_needed.append((nid, opp))
|
|
885
|
+
if not pairs_needed:
|
|
886
|
+
continue
|
|
887
|
+
announce_prompt_rendering(
|
|
888
|
+
"Rank:catchup", len(attr_batches) * len(pairs_needed)
|
|
889
|
+
)
|
|
890
|
+
prompts: List[str] = []
|
|
891
|
+
ids: List[str] = []
|
|
892
|
+
pair_images: Dict[str, List[str]] = {}
|
|
893
|
+
pair_audio: Dict[str, List[Dict[str, str]]] = {}
|
|
894
|
+
pair_pdfs: Dict[str, List[Dict[str, str]]] = {}
|
|
895
|
+
pair_pdfs: Dict[str, List[Dict[str, str]]] = {}
|
|
896
|
+
meta_map: Dict[str, Tuple[int, int, str, str]] = {}
|
|
897
|
+
id_to_circle_first: Dict[str, bool] = {}
|
|
898
|
+
for batch_idx, batch in enumerate(attr_batches):
|
|
899
|
+
attr_def_map = (
|
|
900
|
+
{a: self.cfg.attributes[a] for a in batch}
|
|
901
|
+
if isinstance(self.cfg.attributes, dict)
|
|
902
|
+
else {a: "" for a in batch}
|
|
903
|
+
)
|
|
904
|
+
for pair_idx, (id_a, id_b) in enumerate(pairs_needed):
|
|
905
|
+
raw_ident = f"catchup|{rnd}|{batch_idx}|{pair_idx}|{id_a}|{id_b}"
|
|
906
|
+
sha8 = hashlib.sha1(raw_ident.encode()).hexdigest()[:8]
|
|
907
|
+
circle_first_flag = (
|
|
908
|
+
self.cfg.circle_first
|
|
909
|
+
if self.cfg.circle_first is not None
|
|
910
|
+
else self.rng.random() < 0.5
|
|
911
|
+
)
|
|
912
|
+
id_to_circle_first[sha8] = circle_first_flag
|
|
913
|
+
prompts.append(
|
|
914
|
+
self.template.render(
|
|
915
|
+
entry_circle=texts_by_id[id_a],
|
|
916
|
+
entry_square=texts_by_id[id_b],
|
|
917
|
+
attributes=attr_def_map,
|
|
918
|
+
additional_instructions=self.cfg.additional_instructions or "",
|
|
919
|
+
modality=self.cfg.modality,
|
|
920
|
+
circle_first=circle_first_flag,
|
|
921
|
+
)
|
|
922
|
+
)
|
|
923
|
+
ids.append(sha8)
|
|
924
|
+
meta_map[sha8] = (batch_idx, pair_idx, id_a, id_b)
|
|
925
|
+
if images_by_id:
|
|
926
|
+
imgs = []
|
|
927
|
+
ia = images_by_id.get(id_a, [])
|
|
928
|
+
ib = images_by_id.get(id_b, [])
|
|
929
|
+
if circle_first_flag:
|
|
930
|
+
if ia:
|
|
931
|
+
imgs.extend(ia)
|
|
932
|
+
if ib:
|
|
933
|
+
imgs.extend(ib)
|
|
934
|
+
else:
|
|
935
|
+
if ib:
|
|
936
|
+
imgs.extend(ib)
|
|
937
|
+
if ia:
|
|
938
|
+
imgs.extend(ia)
|
|
939
|
+
if imgs:
|
|
940
|
+
pair_images[sha8] = imgs
|
|
941
|
+
if audio_by_id:
|
|
942
|
+
auds = []
|
|
943
|
+
aa = audio_by_id.get(id_a, [])
|
|
944
|
+
ab = audio_by_id.get(id_b, [])
|
|
945
|
+
if circle_first_flag:
|
|
946
|
+
if aa:
|
|
947
|
+
auds.extend(aa)
|
|
948
|
+
if ab:
|
|
949
|
+
auds.extend(ab)
|
|
950
|
+
else:
|
|
951
|
+
if ab:
|
|
952
|
+
auds.extend(ab)
|
|
953
|
+
if aa:
|
|
954
|
+
auds.extend(aa)
|
|
955
|
+
if auds:
|
|
956
|
+
pair_audio[sha8] = auds
|
|
957
|
+
if pdfs_by_id:
|
|
958
|
+
pdfs: List[Dict[str, str]] = []
|
|
959
|
+
pa = pdfs_by_id.get(id_a, [])
|
|
960
|
+
pb = pdfs_by_id.get(id_b, [])
|
|
961
|
+
if circle_first_flag:
|
|
962
|
+
if pa:
|
|
963
|
+
pdfs.extend(pa)
|
|
964
|
+
if pb:
|
|
965
|
+
pdfs.extend(pb)
|
|
966
|
+
else:
|
|
967
|
+
if pb:
|
|
968
|
+
pdfs.extend(pb)
|
|
969
|
+
if pa:
|
|
970
|
+
pdfs.extend(pa)
|
|
971
|
+
if pdfs:
|
|
972
|
+
pair_pdfs[sha8] = pdfs
|
|
973
|
+
if not prompts:
|
|
974
|
+
continue
|
|
975
|
+
resp_df = await get_all_responses(
|
|
976
|
+
prompts=prompts,
|
|
977
|
+
identifiers=ids,
|
|
978
|
+
prompt_images=pair_images or None,
|
|
979
|
+
prompt_audio=pair_audio or None,
|
|
980
|
+
prompt_pdfs=pair_pdfs or None,
|
|
981
|
+
n_parallels=self.cfg.n_parallels,
|
|
982
|
+
model=self.cfg.model,
|
|
983
|
+
json_mode=self.cfg.modality != "audio",
|
|
984
|
+
save_path=round_path,
|
|
985
|
+
reset_files=reset_files,
|
|
986
|
+
use_dummy=self.cfg.use_dummy,
|
|
987
|
+
max_timeout=self.cfg.max_timeout,
|
|
988
|
+
max_retries=1,
|
|
989
|
+
reasoning_effort=self.cfg.reasoning_effort,
|
|
990
|
+
reasoning_summary=self.cfg.reasoning_summary,
|
|
991
|
+
**kwargs,
|
|
992
|
+
)
|
|
993
|
+
resp_df["Batch"] = resp_df.Identifier.map(
|
|
994
|
+
lambda x: meta_map.get(str(x), (np.nan, np.nan, "", ""))[0]
|
|
995
|
+
)
|
|
996
|
+
resp_df["Pair"] = resp_df.Identifier.map(
|
|
997
|
+
lambda x: meta_map.get(str(x), (np.nan, np.nan, "", ""))[1]
|
|
998
|
+
)
|
|
999
|
+
resp_df["IdA"] = resp_df.Identifier.map(
|
|
1000
|
+
lambda x: meta_map.get(str(x), (np.nan, np.nan, "", ""))[2]
|
|
1001
|
+
)
|
|
1002
|
+
resp_df["IdB"] = resp_df.Identifier.map(
|
|
1003
|
+
lambda x: meta_map.get(str(x), (np.nan, np.nan, "", ""))[3]
|
|
1004
|
+
)
|
|
1005
|
+
resp_df.to_csv(round_path, index=False)
|
|
1006
|
+
|
|
1007
|
+
async def _coerce_dict(raw: Any) -> Dict[str, Any]:
|
|
1008
|
+
obj = await safest_json(raw)
|
|
1009
|
+
if isinstance(obj, dict):
|
|
1010
|
+
return obj
|
|
1011
|
+
if isinstance(obj, str):
|
|
1012
|
+
obj2 = await safest_json(obj)
|
|
1013
|
+
if isinstance(obj2, dict):
|
|
1014
|
+
return obj2
|
|
1015
|
+
if isinstance(obj, list) and obj:
|
|
1016
|
+
inner = await safest_json(obj[0])
|
|
1017
|
+
if isinstance(inner, dict):
|
|
1018
|
+
return inner
|
|
1019
|
+
return {}
|
|
1020
|
+
|
|
1021
|
+
for ident, resp in zip(resp_df.Identifier, resp_df.Response):
|
|
1022
|
+
meta = meta_map.get(str(ident))
|
|
1023
|
+
if not meta:
|
|
1024
|
+
continue
|
|
1025
|
+
batch_idx, _, id_a, id_b = meta
|
|
1026
|
+
safe_obj = await _coerce_dict(resp)
|
|
1027
|
+
if not safe_obj:
|
|
1028
|
+
continue
|
|
1029
|
+
batch = attr_batches[batch_idx]
|
|
1030
|
+
batch_attr_map = {str(k).strip().lower(): k for k in batch}
|
|
1031
|
+
for attr_raw, winner_raw in safe_obj.items():
|
|
1032
|
+
attr_key_l = str(attr_raw).strip().lower()
|
|
1033
|
+
if attr_key_l not in batch_attr_map:
|
|
1034
|
+
continue
|
|
1035
|
+
real_attr = batch_attr_map[attr_key_l]
|
|
1036
|
+
val = winner_raw
|
|
1037
|
+
if isinstance(val, dict) and "winner" in val:
|
|
1038
|
+
val = val.get("winner")
|
|
1039
|
+
if isinstance(val, str):
|
|
1040
|
+
v = val.strip().lower()
|
|
1041
|
+
else:
|
|
1042
|
+
v = ""
|
|
1043
|
+
if v.startswith(("cir", "c", "left", "text a")):
|
|
1044
|
+
history_pairs[real_attr].append((id_a, id_b))
|
|
1045
|
+
elif v.startswith(("squ", "b", "right", "text b")):
|
|
1046
|
+
history_pairs[real_attr].append((id_b, id_a))
|
|
1047
|
+
elif v.startswith("draw") or v.startswith("insufficient"):
|
|
1048
|
+
history_pairs[real_attr].append((id_a, id_b))
|
|
1049
|
+
history_pairs[real_attr].append((id_b, id_a))
|
|
1050
|
+
else:
|
|
1051
|
+
continue
|
|
1052
|
+
se_agg_next: Dict[str, float] = {i: 0.0 for i in item_ids}
|
|
1053
|
+
se_agg_counts: Dict[str, int] = {i: 0 for i in item_ids}
|
|
1054
|
+
for attr in attr_keys:
|
|
1055
|
+
outcomes = history_pairs[attr]
|
|
1056
|
+
if len(outcomes) == 0:
|
|
1057
|
+
continue
|
|
1058
|
+
bt_scores, n_ij, p_ij = self._fit_bt(
|
|
1059
|
+
item_ids=item_ids,
|
|
1060
|
+
outcomes=outcomes,
|
|
1061
|
+
pseudo=self.cfg.learning_rate,
|
|
1062
|
+
max_iter=self._MAX_ITER,
|
|
1063
|
+
tol=self._TOL,
|
|
1064
|
+
return_info=True,
|
|
1065
|
+
)
|
|
1066
|
+
for i in item_ids:
|
|
1067
|
+
ratings[i][attr] = bt_scores[i]
|
|
1068
|
+
s_vec = np.array([bt_scores[i] for i in item_ids])
|
|
1069
|
+
se_vec = self._bt_standard_errors(
|
|
1070
|
+
s=s_vec,
|
|
1071
|
+
n_ij=n_ij,
|
|
1072
|
+
p_ij=p_ij,
|
|
1073
|
+
ridge=self._SE_RIDGE,
|
|
1074
|
+
)
|
|
1075
|
+
for i, se_val in zip(item_ids, se_vec):
|
|
1076
|
+
se_store[attr][i] = float(se_val)
|
|
1077
|
+
se_agg_next[i] += float(se_val)
|
|
1078
|
+
se_agg_counts[i] += 1
|
|
1079
|
+
for i in item_ids:
|
|
1080
|
+
if se_agg_counts[i] > 0:
|
|
1081
|
+
se_agg_next[i] /= se_agg_counts[i]
|
|
1082
|
+
else:
|
|
1083
|
+
se_agg_next[i] = 1.0
|
|
1084
|
+
self._last_se_agg = se_agg_next
|
|
1085
|
+
for attr in attr_keys:
|
|
1086
|
+
vals = [ratings[i][attr] for i in item_ids]
|
|
1087
|
+
mean_val = float(np.mean(vals))
|
|
1088
|
+
for i in item_ids:
|
|
1089
|
+
ratings[i][attr] -= mean_val
|
|
1090
|
+
_write_checkpoint()
|
|
1091
|
+
|
|
1092
|
+
async def _run_recursive(
|
|
1093
|
+
self,
|
|
1094
|
+
df: pd.DataFrame,
|
|
1095
|
+
text_column: str,
|
|
1096
|
+
*,
|
|
1097
|
+
id_column: Optional[str],
|
|
1098
|
+
reset_files: bool,
|
|
1099
|
+
**kwargs: Any,
|
|
1100
|
+
) -> pd.DataFrame:
|
|
1101
|
+
attr_dict = self._attributes_as_dict()
|
|
1102
|
+
attr_list = list(attr_dict.keys())
|
|
1103
|
+
if not attr_list:
|
|
1104
|
+
raise ValueError("No attributes provided for ranking")
|
|
1105
|
+
cut_attr = self.cfg.recursive_cut_attr or attr_list[0]
|
|
1106
|
+
if cut_attr not in attr_list:
|
|
1107
|
+
raise ValueError(
|
|
1108
|
+
f"recursive_cut_attr '{self.cfg.recursive_cut_attr}' not present in attributes"
|
|
1109
|
+
)
|
|
1110
|
+
cut_side = (self.cfg.recursive_cut_side or "top").lower()
|
|
1111
|
+
if cut_side not in {"top", "bottom"}:
|
|
1112
|
+
raise ValueError("recursive_cut_side must be 'top' or 'bottom'")
|
|
1113
|
+
|
|
1114
|
+
work_df = df.reset_index(drop=True).copy()
|
|
1115
|
+
if id_column is not None:
|
|
1116
|
+
if id_column not in work_df.columns:
|
|
1117
|
+
raise ValueError(f"id_column '{id_column}' not found in DataFrame")
|
|
1118
|
+
work_df["identifier"] = work_df[id_column].astype(str)
|
|
1119
|
+
else:
|
|
1120
|
+
work_df["identifier"] = work_df[text_column].astype(str).map(
|
|
1121
|
+
lambda x: hashlib.sha1(x.encode()).hexdigest()[:8]
|
|
1122
|
+
)
|
|
1123
|
+
if text_column != "text":
|
|
1124
|
+
work_df = work_df.rename(columns={text_column: "text"})
|
|
1125
|
+
rewrite_col = self.cfg.recursive_rewrite_text_col or "text"
|
|
1126
|
+
if rewrite_col not in work_df.columns:
|
|
1127
|
+
work_df[rewrite_col] = work_df["text"]
|
|
1128
|
+
work_df["identifier"] = work_df["identifier"].astype(str)
|
|
1129
|
+
|
|
1130
|
+
original_cols = list(df.columns)
|
|
1131
|
+
original_df = df.reset_index(drop=True).copy()
|
|
1132
|
+
original_df["identifier"] = work_df["identifier"]
|
|
1133
|
+
latest_text: Dict[str, str] = {
|
|
1134
|
+
ident: txt for ident, txt in zip(work_df["identifier"], work_df["text"])
|
|
1135
|
+
}
|
|
1136
|
+
|
|
1137
|
+
base_folder = os.path.join(
|
|
1138
|
+
self.cfg.save_dir, f"{self.cfg.file_name}_recursive"
|
|
1139
|
+
)
|
|
1140
|
+
os.makedirs(base_folder, exist_ok=True)
|
|
1141
|
+
|
|
1142
|
+
def _compute_stage_zscores(
|
|
1143
|
+
stage_df: pd.DataFrame,
|
|
1144
|
+
) -> Tuple[Dict[str, Dict[str, float]], Dict[str, float]]:
|
|
1145
|
+
zscores: Dict[str, Dict[str, float]] = {attr: {} for attr in attr_list}
|
|
1146
|
+
scales: Dict[str, float] = {attr: 1.0 for attr in attr_list}
|
|
1147
|
+
for attr in attr_list:
|
|
1148
|
+
raw_col = f"{attr}_raw"
|
|
1149
|
+
source_col = raw_col if raw_col in stage_df.columns else attr
|
|
1150
|
+
if source_col not in stage_df.columns:
|
|
1151
|
+
continue
|
|
1152
|
+
series = pd.to_numeric(stage_df[source_col], errors="coerce")
|
|
1153
|
+
mean = series.mean()
|
|
1154
|
+
std = series.std(ddof=0)
|
|
1155
|
+
if std == 0 or np.isnan(std):
|
|
1156
|
+
normed = pd.Series([0.0] * len(series), index=stage_df.index)
|
|
1157
|
+
scales[attr] = 1.0
|
|
1158
|
+
else:
|
|
1159
|
+
normed = (series - mean) / std
|
|
1160
|
+
scales[attr] = float(std) if raw_col in stage_df.columns else 1.0
|
|
1161
|
+
for ident, val in zip(stage_df["identifier"], normed):
|
|
1162
|
+
zscores[attr][str(ident)] = float(val)
|
|
1163
|
+
return zscores, scales
|
|
1164
|
+
|
|
1165
|
+
def _select_next_ids(active_ids: Sequence[str], stage_zs: Dict[str, Dict[str, float]]) -> List[str]:
|
|
1166
|
+
n = len(active_ids)
|
|
1167
|
+
if n <= self.cfg.recursive_min_remaining:
|
|
1168
|
+
return list(active_ids)
|
|
1169
|
+
keep_n = max(
|
|
1170
|
+
int(math.ceil(n * self.cfg.recursive_fraction)),
|
|
1171
|
+
self.cfg.recursive_min_remaining,
|
|
1172
|
+
)
|
|
1173
|
+
scores = {i: stage_zs.get(cut_attr, {}).get(i, 0.0) for i in active_ids}
|
|
1174
|
+
ascending = cut_side == "bottom"
|
|
1175
|
+
ranked = sorted(active_ids, key=lambda x: scores.get(x, 0.0), reverse=not ascending)
|
|
1176
|
+
return ranked[:keep_n]
|
|
1177
|
+
|
|
1178
|
+
def _maybe_rewrite_texts(
|
|
1179
|
+
df_local: pd.DataFrame,
|
|
1180
|
+
ids_to_keep: Sequence[str],
|
|
1181
|
+
stage_idx: int,
|
|
1182
|
+
) -> pd.DataFrame:
|
|
1183
|
+
if self.cfg.recursive_rewrite_func is None:
|
|
1184
|
+
return df_local
|
|
1185
|
+
mask = df_local["identifier"].isin(ids_to_keep)
|
|
1186
|
+
rewritten: List[str] = []
|
|
1187
|
+
for _, row in df_local[mask].iterrows():
|
|
1188
|
+
new_text = self.cfg.recursive_rewrite_func(
|
|
1189
|
+
row[self.cfg.recursive_rewrite_text_col],
|
|
1190
|
+
row["identifier"],
|
|
1191
|
+
stage_idx,
|
|
1192
|
+
)
|
|
1193
|
+
rewritten.append(new_text)
|
|
1194
|
+
latest_text[str(row["identifier"])] = new_text
|
|
1195
|
+
df_local.loc[mask, self.cfg.recursive_rewrite_text_col] = rewritten
|
|
1196
|
+
if (
|
|
1197
|
+
self.cfg.recursive_rewrite_text_col != "text"
|
|
1198
|
+
and "text" in df_local.columns
|
|
1199
|
+
):
|
|
1200
|
+
df_local.loc[mask, "text"] = df_local.loc[
|
|
1201
|
+
mask, self.cfg.recursive_rewrite_text_col
|
|
1202
|
+
]
|
|
1203
|
+
return df_local
|
|
1204
|
+
|
|
        stage_idx = 0
        final_stage_idx: Optional[int] = None
        final_stage_df: Optional[pd.DataFrame] = None
        stage_z_history: Dict[int, Dict[str, Dict[str, float]]] = {}
        exit_stage: Dict[str, Optional[int]] = {ident: None for ident in work_df["identifier"]}
        current_ids = list(work_df["identifier"])
        stage_primer = self.cfg.primer_scores or None

        while current_ids:
            stage_idx += 1
            n_current = len(current_ids)
            is_final_stage = False
            if n_current <= self.cfg.recursive_min_remaining:
                is_final_stage = True
            else:
                next_keep = max(
                    int(math.ceil(n_current * self.cfg.recursive_fraction)),
                    self.cfg.recursive_min_remaining,
                )
                if next_keep <= self.cfg.recursive_min_remaining:
                    is_final_stage = True

            stage_rounds = self.cfg.n_rounds
            if is_final_stage:
                final_multiplier = self.cfg.recursive_final_round_multiplier or 3
                stage_rounds = max(1, stage_rounds * final_multiplier)

            stage_folder = os.path.join(base_folder, f"stage{stage_idx}")
            os.makedirs(stage_folder, exist_ok=True)
            stage_cfg = copy.deepcopy(self.cfg)
            stage_cfg.recursive = False
            stage_cfg.recursive_rate_first_round = False
            stage_cfg.save_dir = stage_folder
            stage_cfg.n_rounds = stage_rounds
            stage_cfg.file_name = self.cfg.file_name
            stage_cfg.rate_kwargs = dict(self.cfg.rate_kwargs)
            stage_cfg.initial_rating_pass = False
            stage_cfg.primer_scores = stage_primer
            stage_cfg.primer_center = False

            stage_df_in = work_df[work_df["identifier"].isin(current_ids)].copy()

            if stage_idx == 1 and self.cfg.recursive_rate_first_round:
                print(
                    "[Rank] Recursive stage 1: running Rate for initial culling "
                    "(disable with recursive_rate_first_round=False)."
                )
                stage_df_out = await self._run_rate_pass(
                    stage_df_in,
                    column_name="text",
                    save_dir=stage_folder,
                    file_name=f"stage{stage_idx}_ratings.csv",
                    reset_files=reset_files,
                    runtime_kwargs=kwargs,
                )
                stage_df_out["identifier"] = stage_df_in["identifier"].values
            else:
                stage_ranker = Rank(stage_cfg, template=self.template)
                stage_df_out = await stage_ranker.run(
                    stage_df_in,
                    column_name="text",
                    id_column="identifier",
                    reset_files=reset_files,
                    **kwargs,
                )

            stage_zs, stage_scales = _compute_stage_zscores(stage_df_out)
            stage_z_history[stage_idx] = stage_zs

            if is_final_stage:
                for ident in current_ids:
                    exit_stage[ident] = stage_idx
                final_stage_df = stage_df_out
                final_stage_idx = stage_idx
                break

            next_ids = _select_next_ids(current_ids, stage_zs)
            removed = set(current_ids) - set(next_ids)
            for ident in removed:
                exit_stage[ident] = stage_idx
            stage_primer = {
                ident: {
                    attr: stage_zs.get(attr, {}).get(ident, 0.0) * stage_scales.get(attr, 1.0)
                    for attr in attr_list
                }
                for ident in next_ids
            }
            work_df = _maybe_rewrite_texts(work_df, next_ids, stage_idx)
            current_ids = next_ids

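        # The primer built above multiplies each z-score by the stage's standard
        # deviation, which (when a ``_raw`` column was available) undoes the
        # per-stage standardisation so that the next stage is seeded on the raw
        # Bradley-Terry scale rather than the normalised one.
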
        if final_stage_df is None:
            final_stage_df = work_df[work_df["identifier"].isin(current_ids)].copy()
            if stage_idx:
                final_stage_idx = stage_idx

        # Build final output
        stage_cols: Dict[str, List[Optional[float]]] = {}
        final_attr_cols: Dict[str, List[Optional[float]]] = {a: [] for a in attr_list}
        exit_col: List[Optional[int]] = []

        # build a consolidated map of stage-wise z-scores per identifier
        stage_order = sorted(stage_z_history.keys())
        id_list = list(original_df["identifier"])
        for ident in id_list:
            ident_stage = exit_stage.get(ident)
            exit_col.append(ident_stage)
            final_attr_vals: Dict[str, Optional[float]] = {a: None for a in attr_list}
            for stage in stage_order:
                zs = stage_z_history.get(stage, {})
                for attr in attr_list:
                    col_name = f"stage{stage}_{attr}"
                    stage_cols.setdefault(col_name, []).append(zs.get(attr, {}).get(ident))
                    if final_stage_idx is not None and stage == final_stage_idx:
                        final_attr_vals[attr] = zs.get(attr, {}).get(ident)
            for attr in attr_list:
                final_attr_cols[attr].append(final_attr_vals[attr])

        ordered_df = original_df.copy()
        ordered_df[text_column] = ordered_df["identifier"].map(latest_text)
        ordered_df["exit_stage"] = exit_col
        for attr, vals in final_attr_cols.items():
            ordered_df[attr] = vals
        for col, vals in stage_cols.items():
            ordered_df[col] = vals

        # Compute overall ranking: later stages outrank earlier; within a stage, sort by cut_attr z-score
        cut_scores = {i: ordered_df.loc[idx, cut_attr] if cut_attr in ordered_df else None for idx, i in enumerate(id_list)}

        def _rank_key(idx: int) -> Tuple[int, float]:
            ident = id_list[idx]
            stage_num = ordered_df.loc[idx, "exit_stage"] or 0
            score = cut_scores.get(ident)
            if score is None or np.isnan(score):
                score = -np.inf if cut_side == "top" else np.inf
            if cut_side == "bottom":
                score = -score
            return (stage_num, score)

        order_indices = sorted(range(len(id_list)), key=_rank_key, reverse=True)
        ordered_df = ordered_df.iloc[order_indices].reset_index(drop=True)
        ordered_df.insert(0, "overall_rank", range(1, len(ordered_df) + 1))

        final_columns: List[str] = []
        if text_column in original_cols:
            for col in original_cols:
                final_columns.append(col)
                if col == text_column:
                    final_columns.append("overall_rank")
        else:
            final_columns = ["overall_rank"] + [c for c in original_cols]
        for attr in attr_list:
            final_columns.append(attr)
        final_columns.append("exit_stage")
        final_columns.extend(sorted(stage_cols.keys()))
        final_columns = [c for c in final_columns if c in ordered_df.columns and c != "identifier"]
        ordered_df = ordered_df[final_columns]

        final_path = os.path.join(base_folder, "recursive_final.csv")
        ordered_df.to_csv(final_path, index=False)
        return ordered_df

    # ------------------------------------------------------------------
    # Main ranking loop
    # ------------------------------------------------------------------
    async def run(
        self,
        df: pd.DataFrame,
        column_name: str,
        *,
        id_column: Optional[str] = None,
        reset_files: bool = False,
        n_runs: Optional[int] = None,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Execute the ranking procedure.

        Parameters
        ----------
        df:
            Input DataFrame containing the passages to be ranked.
        column_name:
            Name of the column in ``df`` that holds the text for each
            passage.
        id_column:
            Optional name of a column that contains stable identifiers
            for each row. When provided, these identifiers are used to
            track passages across rounds instead of hashing the text
            itself. Supplying ``id_column`` is recommended when texts
            may be rewritten between stages (e.g., during recursive
            runs).
        reset_files:
            If ``True``, ignore any previously saved results and
            recompute the rankings. Otherwise, if the final output
            file already exists on disk it will be loaded and returned
            immediately.
        n_runs:
            Deprecated/ignored parameter provided for compatibility
            with :class:`Rate`. When supplied, a message is printed
            noting that ``n_rounds`` controls the number of iterations
            and that ``n_runs`` has no effect.
        **kwargs:
            Additional keyword arguments forwarded to
            :func:`get_all_responses`. When ``initial_rating_pass`` is
            enabled these arguments are also forwarded to the rating
            stage. Useful for passing through authentication tokens or
            tracing settings.

        Returns
        -------
        pandas.DataFrame
            A DataFrame with one row per input passage. For each
            attribute the DataFrame contains a ``"<attribute>"`` column
            holding the z-score, a ``"<attribute>_raw"`` column with the
            centred Bradley-Terry estimate, and a ``"<attribute>_se"``
            column with the standard error. The DataFrame is also written
            to ``save_dir``.
        """
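        # Illustrative usage (a sketch; the attribute and column names below are
        # invented examples, not guarantees of this module's API):
        #
        #     ranker = Rank(cfg)
        #     result = await ranker.run(df, column_name="text", id_column="doc_id")
        #     result[["doc_id", "clarity", "clarity_raw", "clarity_se"]]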
        base_name = os.path.splitext(self.cfg.file_name)[0]
        self.cfg.attributes = load_persisted_attributes(
            save_dir=self.cfg.save_dir,
            incoming=self.cfg.attributes,
            reset_files=reset_files,
            task_name="Rank",
            item_name="attributes",
            legacy_filename=f"{base_name}_attrs.json",
        )

        kwargs.setdefault("web_search", self.cfg.modality == "web")
        if self.cfg.recursive:
            return await self._run_recursive(
                df,
                column_name,
                id_column=id_column,
                reset_files=reset_files,
                **kwargs,
            )

        # prepare file paths
        final_path = os.path.join(self.cfg.save_dir, f"{base_name}_final.csv")
        if n_runs is not None:
            print(
                "Parameter 'n_runs' is ignored. Use 'n_rounds' to control the number of iterations. "
                f"Current n_rounds={self.cfg.n_rounds}."
            )

        df_proc = df.reset_index(drop=True).copy()
        warn_if_modality_mismatch(df_proc[column_name].tolist(), self.cfg.modality, column_name=column_name)
        if id_column is not None:
            if id_column not in df_proc.columns:
                raise ValueError(f"id_column '{id_column}' not found in DataFrame")
            df_proc["_id"] = df_proc[id_column].astype(str)
        else:
            # assign a stable identifier per row using a SHA-1 hash
            df_proc["_id"] = (
                df_proc[column_name]
                .astype(str)
                .map(lambda x: hashlib.sha1(x.encode()).hexdigest()[:8])
            )
        # Determine how many rounds have already been processed when
        # ``reset_files`` is False. We look for files named
        # ``<base_name>_round<k>.csv`` to infer progress. If a final
        # checkpoint exists for the last round, reuse it; otherwise we
        # resume from the next incomplete round. When ``reset_files``
        # is ``True``, all progress is ignored and the computation
        # restarts from round 0.
        start_round = 0
        existing_rounds: List[int] = []
        if not reset_files:
            try:
                for fname in os.listdir(self.cfg.save_dir):
                    if fname.startswith(f"{base_name}_round") and fname.endswith(
                        ".csv"
                    ):
                        try:
                            idx_str = fname[
                                len(base_name) + 6 : -4
                            ]  # len("_round") == 6
                            rnd_idx = int(idx_str)
                            existing_rounds.append(rnd_idx)
                        except Exception:
                            continue
            except Exception:
                existing_rounds = []
            if existing_rounds:
                last_completed = max(existing_rounds)
                if os.path.exists(final_path):
                    try:
                        final_df = pd.read_csv(final_path)
                        identifier_col = (
                            id_column
                            if id_column and id_column in final_df.columns
                            else column_name
                        )
                        if identifier_col not in final_df.columns:
                            raise ValueError(
                                "Existing ranking output is missing identifier column "
                                f"'{identifier_col}'."
                            )
                        if id_column:
                            final_ids = set(final_df[identifier_col].astype(str))
                        else:
                            final_ids = set(
                                final_df[identifier_col]
                                .astype(str)
                                .map(lambda x: hashlib.sha1(x.encode()).hexdigest()[:8])
                            )
                        if last_completed >= self.cfg.n_rounds - 1 and set(df_proc["_id"]) <= final_ids:
                            return final_df
                    except Exception:
                        pass
                start_round = last_completed + 1
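        # Illustrative example of the resume logic above (file name invented):
        # with file_name="rank.csv" the per-round checkpoints are
        # rank_round0.csv, rank_round1.csv, ...; if rounds 0-2 exist on disk and
        # n_rounds=5, the run resumes at start_round=3 after replaying the saved
        # rounds further below.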
        # extract contents and build lookup
        if self.cfg.modality in {"image", "audio", "pdf"}:
            texts = list(zip(df_proc["_id"], ["" for _ in df_proc[column_name]]))
        else:
            texts = list(zip(df_proc["_id"], df_proc[column_name].astype(str)))
        texts_by_id = {i: t for i, t in texts}
        item_ids = [i for i, _ in texts]

        images_by_id: Dict[str, List[str]] = {}
        audio_by_id: Dict[str, List[Dict[str, str]]] = {}
        pdfs_by_id: Dict[str, List[Dict[str, str]]] = {}
        if self.cfg.modality == "image":
            for rid, imgs in zip(df_proc["_id"], df_proc[column_name]):
                encoded = load_image_inputs(imgs)
                if encoded:
                    images_by_id[rid] = encoded
        elif self.cfg.modality == "audio":
            for rid, auds in zip(df_proc["_id"], df_proc[column_name]):
                encoded = load_audio_inputs(auds)
                if encoded:
                    audio_by_id[rid] = encoded
        elif self.cfg.modality == "pdf":
            for rid, pdfs in zip(df_proc["_id"], df_proc[column_name]):
                encoded = load_pdf_inputs(pdfs)
                if encoded:
                    pdfs_by_id[rid] = encoded
        # derive list of attributes
        if isinstance(self.cfg.attributes, dict):
            attr_keys = list(self.cfg.attributes.keys())
        else:
            attr_keys = list(self.cfg.attributes)
        # initialise ratings for each item/attribute
        ratings: Dict[str, Dict[str, float]] = {
            i: {a: 0.0 for a in attr_keys} for i in item_ids
        }
        rate_seed: Dict[str, Dict[str, float]] = {}
        if self.cfg.primer_scores:
            self._apply_primer(ratings, self.cfg.primer_scores, attr_keys)
        if self.cfg.initial_rating_pass and attr_keys:
            print(
                "[Rank] Running initial rating pass to seed pairwise comparisons "
                "(disable with initial_rating_pass=False)."
            )
            rate_dir = os.path.join(self.cfg.save_dir, f"{base_name}_initial_rate")
            os.makedirs(rate_dir, exist_ok=True)
            rate_df = await self._run_rate_pass(
                df_proc,
                column_name,
                save_dir=rate_dir,
                file_name=f"{base_name}_initial_rate.csv",
                reset_files=reset_files,
                runtime_kwargs=kwargs,
            )
            rate_seed = self._seed_ratings_from_rate(
                rate_df,
                id_column=id_column,
                text_column=column_name,
                item_ids=item_ids,
                attr_keys=attr_keys,
            )
            if rate_seed:
                print(
                    "[Rank] Initial rating pass complete. Seeding tournament with "
                    "centred ratings from the rate stage."
                )
                for item_id, attr_map in rate_seed.items():
                    for attr, val in attr_map.items():
                        ratings[item_id][attr] = val
        has_seed_ratings = bool(rate_seed)
        # maintain a history of pairwise outcomes for each attribute
        history_pairs: Dict[str, List[Tuple[str, str]]] = {a: [] for a in attr_keys}
        # store per-attribute standard errors across items
        se_store: Dict[str, Dict[str, float]] = {
            a: {i: np.nan for i in item_ids} for a in attr_keys
        }
        # Define attribute batches once to reuse across replay and new rounds
        attr_count = len(attr_keys)
        if attr_count > self.cfg.n_attributes_per_run:
            batches = (
                attr_count + self.cfg.n_attributes_per_run - 1
            ) // self.cfg.n_attributes_per_run
            print(
                f"[Rank] {attr_count} attributes provided. n_attributes_per_run={self.cfg.n_attributes_per_run}. "
                f"Splitting into {batches} prompt batches. Increase n_attributes_per_run if you want all attributes "
                "to be processed in the same prompt."
            )
        attr_batches: List[List[str]] = [
            attr_keys[i : i + self.cfg.n_attributes_per_run]
            for i in range(0, len(attr_keys), self.cfg.n_attributes_per_run)
        ]

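        # Illustrative example of the batching above (numbers invented): with 7
        # attributes and n_attributes_per_run=3, attr_batches holds three
        # batches of sizes 3, 3 and 1, and each pair of passages is prompted
        # once per batch in every round.
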
        # Helper function to write the current results to the final CSV. This
        # builds the output DataFrame from the current ``df_proc`` and
        # ``ratings``/``se_store``/``zscores`` and writes it to
        # ``final_path``.
        def _write_checkpoint() -> None:
            # Compute z-scores for each attribute so that we can expose centred
            # BT scores alongside their normalised variants.
            zscores_local: Dict[str, Dict[str, float]] = {}
            for attr in attr_keys:
                vals = np.array([ratings[i][attr] for i in item_ids])
                mean = vals.mean()
                std = vals.std(ddof=0)
                if std == 0:
                    zscores_local[attr] = {i: 0.0 for i in item_ids}
                else:
                    zscores_local[attr] = {
                        i: float((ratings[i][attr] - mean) / std) for i in item_ids
                    }
            # Merge computed results back into the original DataFrame copy.
            for attr in attr_keys:
                raw_col = f"{attr}_raw"
                # ratings
                val_map = {i: ratings[i][attr] for i in item_ids}
                df_proc[raw_col] = df_proc["_id"].map(val_map)
                # standard errors
                se_map = {i: se_store[attr].get(i, np.nan) for i in item_ids}
                df_proc[f"{attr}_se"] = df_proc["_id"].map(se_map)
                # z-scores
                z_map = zscores_local.get(attr, {i: np.nan for i in item_ids})
                df_proc[attr] = df_proc["_id"].map(z_map)

            # Reorder columns: original user columns first (excluding the internal ``_id``),
            # then for each attribute the z-score column followed by raw scores and
            # standard errors.
            original_cols = [
                c for c in df.columns
            ]  # preserve the order provided by the user
            new_cols: List[str] = []
            for attr in attr_keys:
                new_cols.append(attr)
                new_cols.append(f"{attr}_raw")
                new_cols.append(f"{attr}_se")
            final_cols = original_cols + new_cols
            final_cols = [c for c in final_cols if c in df_proc.columns]
            df_out_local = df_proc[final_cols].copy()
            # Write the final results to disk in CSV format. Using CSV avoids
            # Excel row limits and unnecessary overhead.
            df_out_local.to_csv(final_path, index=False)

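        # Illustrative column layout written by _write_checkpoint (attribute
        # name invented): for an attribute "clarity" the checkpoint holds
        # "clarity" (z-score), "clarity_raw" (centred BT score) and
        # "clarity_se" (standard error), appended after the user's own columns.
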
        # If there are completed rounds and we're resuming, replay them to
        # reconstruct the ratings and uncertainties. After each replayed
        # round we write a checkpoint to ``final_path``.
        if start_round > 0:
            for replay_rnd in range(start_round):
                round_path = os.path.join(
                    self.cfg.save_dir, f"{base_name}_round{replay_rnd}.csv"
                )
                if not os.path.exists(round_path):
                    break
                try:
                    # Load existing responses for this round
                    df_round = pd.read_csv(round_path)
                    df_round["Response"] = df_round["Response"].apply(
                        lambda x: None if pd.isna(x) else x
                    )
                except Exception:
                    continue

                # Parse each response to build history_pairs
                async def _coerce_dict_replay(raw: Any) -> Dict[str, Any]:
                    obj = await safest_json(raw)
                    if isinstance(obj, dict):
                        return obj
                    if isinstance(obj, str):
                        obj2 = await safest_json(obj)
                        if isinstance(obj2, dict):
                            return obj2
                    if isinstance(obj, list) and obj:
                        inner = await safest_json(obj[0])
                        if isinstance(inner, dict):
                            return inner
                    return {}

                if {"Batch", "IdA", "IdB"}.issubset(df_round.columns):
                    for batch_idx_raw, id_a, id_b, resp_raw in zip(
                        df_round["Batch"],
                        df_round["IdA"],
                        df_round["IdB"],
                        df_round["Response"],
                    ):
                        batch_idx = int(batch_idx_raw)
                        batch = attr_batches[batch_idx]
                        batch_attr_map = {str(k).strip().lower(): k for k in batch}
                        safe_obj = await _coerce_dict_replay(resp_raw)
                        if not safe_obj:
                            continue
                        for attr_raw, winner_raw in safe_obj.items():
                            attr_key_l = str(attr_raw).strip().lower()
                            if attr_key_l not in batch_attr_map:
                                continue
                            real_attr = batch_attr_map[attr_key_l]
                            val = winner_raw
                            if isinstance(val, dict) and "winner" in val:
                                val = val.get("winner")
                            if isinstance(val, str):
                                v = val.strip().lower()
                            else:
                                v = ""
                            if v.startswith(("cir", "c", "left", "text a")):
                                history_pairs[real_attr].append((id_a, id_b))
                            elif v.startswith(("squ", "b", "right", "text b")):
                                history_pairs[real_attr].append((id_b, id_a))
                            elif v.startswith("draw") or v.startswith("insufficient"):
                                history_pairs[real_attr].append((id_a, id_b))
                                history_pairs[real_attr].append((id_b, id_a))
                            else:
                                continue
                else:
                    for ident, resp_raw in zip(
                        df_round["Identifier"], df_round["Response"]
                    ):
                        parts = str(ident).split("|")
                        if len(parts) != 5:
                            continue
                        _, batch_idx_str, _, id_a, id_b = parts
                        batch_idx = int(batch_idx_str)
                        batch = attr_batches[batch_idx]
                        batch_attr_map = {str(k).strip().lower(): k for k in batch}
                        safe_obj = await _coerce_dict_replay(resp_raw)
                        if not safe_obj:
                            continue
                        for attr_raw, winner_raw in safe_obj.items():
                            attr_key_l = str(attr_raw).strip().lower()
                            if attr_key_l not in batch_attr_map:
                                continue
                            real_attr = batch_attr_map[attr_key_l]
                            val = winner_raw
                            if isinstance(val, dict) and "winner" in val:
                                val = val.get("winner")
                            if isinstance(val, str):
                                v = val.strip().lower()
                            else:
                                v = ""
                            if v.startswith(("cir", "c", "left", "text a")):
                                history_pairs[real_attr].append((id_a, id_b))
                            elif v.startswith(("squ", "b", "right", "text b")):
                                history_pairs[real_attr].append((id_b, id_a))
                            elif v.startswith("draw") or v.startswith("insufficient"):
                                history_pairs[real_attr].append((id_a, id_b))
                                history_pairs[real_attr].append((id_b, id_a))
                            else:
                                continue
                # After parsing all pairs for this round, update ratings
                se_agg_next: Dict[str, float] = {i: 0.0 for i in item_ids}
                se_agg_counts: Dict[str, int] = {i: 0 for i in item_ids}
                for attr in attr_keys:
                    outcomes = history_pairs[attr]
                    if len(outcomes) == 0:
                        continue
                    bt_scores, n_ij, p_ij = self._fit_bt(
                        item_ids=item_ids,
                        outcomes=outcomes,
                        pseudo=self.cfg.learning_rate,
                        max_iter=self._MAX_ITER,
                        tol=self._TOL,
                        return_info=True,
                    )
                    for i in item_ids:
                        ratings[i][attr] = bt_scores[i]
                    s_vec = np.array([bt_scores[i] for i in item_ids])
                    se_vec = self._bt_standard_errors(
                        s=s_vec,
                        n_ij=n_ij,
                        p_ij=p_ij,
                        ridge=self._SE_RIDGE,
                    )
                    for i, se_val in zip(item_ids, se_vec):
                        se_store[attr][i] = float(se_val)
                        se_agg_next[i] += float(se_val)
                        se_agg_counts[i] += 1
                for i in item_ids:
                    if se_agg_counts[i] > 0:
                        se_agg_next[i] /= se_agg_counts[i]
                    else:
                        se_agg_next[i] = 1.0
                self._last_se_agg = se_agg_next
                # Centre ratings to zero mean for each attribute
                for attr in attr_keys:
                    vals = [ratings[i][attr] for i in item_ids]
                    mean_val = float(np.mean(vals))
                    for i in item_ids:
                        ratings[i][attr] -= mean_val
                # Write checkpoint after this replayed round
                _write_checkpoint()

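        # Reminder of the model being refit above: under Bradley-Terry, the
        # probability that item i beats item j is exp(s_i) / (exp(s_i) + exp(s_j)),
        # where s_i is the item's latent score. Draw/"insufficient" verdicts are
        # recorded as one win in each direction, which pulls the two scores
        # towards each other.
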
        # Determine if any new items were added and need to catch up on existing rounds
        seen_ids: Set[str] = set()
        for pair_list in history_pairs.values():
            for a, b in pair_list:
                seen_ids.add(a)
                seen_ids.add(b)
        new_ids = [i for i in item_ids if i not in seen_ids]
        await self._catch_up_existing_rounds(
            new_ids=new_ids,
            round_indices=list(range(start_round)),
            item_ids=item_ids,
            texts_by_id=texts_by_id,
            images_by_id=images_by_id,
            audio_by_id=audio_by_id,
            attr_batches=attr_batches,
            attr_keys=attr_keys,
            history_pairs=history_pairs,
            ratings=ratings,
            se_store=se_store,
            base_name=base_name,
            df_proc=df_proc,
            _write_checkpoint=_write_checkpoint,
            current_ratings=None,
            se_agg_local=self._last_se_agg,
            reset_files=reset_files,
            **kwargs,
        )

        # Now proceed with new rounds starting from ``start_round``
        for rnd in range(start_round, self.cfg.n_rounds):
            # aggregate current ratings across attributes for pairing
            current_agg = {
                i: float(np.mean(list(ratings[i].values()))) for i in item_ids
            }
            se_agg_local = self._last_se_agg
            use_current = rnd > 0 or start_round > 0 or has_seed_ratings
            se_source = se_agg_local if (rnd > 0 or start_round > 0 or se_agg_local is not None) else None
            pairs = self._generate_pairs(
                item_ids=item_ids,
                texts_by_id=texts_by_id,
                current_ratings=current_agg if use_current else None,
                se_agg=se_source,
            )
            if not pairs:
                break
            announce_prompt_rendering(
                "Rank", len(attr_batches) * len(pairs)
            )
            prompts: List[str] = []
            ids: List[str] = []
            pair_images: Dict[str, List[str]] = {}
            pair_audio: Dict[str, List[Dict[str, str]]] = {}
            meta_map: Dict[str, Tuple[int, int, str, str]] = {}
            id_to_circle_first: Dict[str, bool] = {}
            for batch_idx, batch in enumerate(attr_batches):
                attr_def_map = (
                    {a: self.cfg.attributes[a] for a in batch}
                    if isinstance(self.cfg.attributes, dict)
                    else {a: "" for a in batch}
                )
                for pair_idx, ((id_a, t_a), (id_b, t_b)) in enumerate(pairs):
                    raw_ident = f"{rnd}|{batch_idx}|{pair_idx}|{id_a}|{id_b}"
                    sha8 = hashlib.sha1(raw_ident.encode()).hexdigest()[:8]
                    circle_first_flag = (
                        self.cfg.circle_first
                        if self.cfg.circle_first is not None
                        else self.rng.random() < 0.5
                    )
                    id_to_circle_first[sha8] = circle_first_flag
                    prompts.append(
                        self.template.render(
                            entry_circle=t_a,
                            entry_square=t_b,
                            attributes=attr_def_map,
                            additional_instructions=self.cfg.additional_instructions
                            or "",
                            modality=self.cfg.modality,
                            circle_first=circle_first_flag,
                        )
                    )
                    ids.append(sha8)
                    meta_map[sha8] = (batch_idx, pair_idx, id_a, id_b)
                    if images_by_id:
                        imgs = []
                        ia = images_by_id.get(id_a, [])
                        ib = images_by_id.get(id_b, [])
                        if circle_first_flag:
                            if ia:
                                imgs.extend(ia)
                            if ib:
                                imgs.extend(ib)
                        else:
                            if ib:
                                imgs.extend(ib)
                            if ia:
                                imgs.extend(ia)
                        if imgs:
                            pair_images[sha8] = imgs
                    if audio_by_id:
                        auds = []
                        aa = audio_by_id.get(id_a, [])
                        ab = audio_by_id.get(id_b, [])
                        if circle_first_flag:
                            if aa:
                                auds.extend(aa)
                            if ab:
                                auds.extend(ab)
                        else:
                            if ab:
                                auds.extend(ab)
                            if aa:
                                auds.extend(aa)
                        if auds:
                            pair_audio[sha8] = auds
            # obtain responses from the language model for this round
            round_path = os.path.join(self.cfg.save_dir, f"{base_name}_round{rnd}.csv")
            resp_df = await get_all_responses(
                prompts=prompts,
                identifiers=ids,
                prompt_images=pair_images or None,
                prompt_audio=pair_audio or None,
                n_parallels=self.cfg.n_parallels,
                model=self.cfg.model,
                json_mode=self.cfg.modality != "audio",
                save_path=round_path,
                reset_files=reset_files,
                use_dummy=self.cfg.use_dummy,
                max_timeout=self.cfg.max_timeout,
                max_retries=1,
                reasoning_effort=self.cfg.reasoning_effort,
                reasoning_summary=self.cfg.reasoning_summary,
                **kwargs,
            )
            # attach metadata columns and overwrite the round CSV
            resp_df["Batch"] = resp_df.Identifier.map(
                lambda x: meta_map.get(str(x), (np.nan, np.nan, "", ""))[0]
            )
            resp_df["Pair"] = resp_df.Identifier.map(
                lambda x: meta_map.get(str(x), (np.nan, np.nan, "", ""))[1]
            )
            resp_df["IdA"] = resp_df.Identifier.map(
                lambda x: meta_map.get(str(x), (np.nan, np.nan, "", ""))[2]
            )
            resp_df["IdB"] = resp_df.Identifier.map(
                lambda x: meta_map.get(str(x), (np.nan, np.nan, "", ""))[3]
            )
            resp_df.to_csv(round_path, index=False)

            # parse each response
            # reuse the _coerce_dict function defined in the original implementation
            async def _coerce_dict(raw: Any) -> Dict[str, Any]:
                obj = await safest_json(raw)
                if isinstance(obj, dict):
                    return obj
                if isinstance(obj, str):
                    obj2 = await safest_json(obj)
                    if isinstance(obj2, dict):
                        return obj2
                if isinstance(obj, list) and obj:
                    inner = await safest_json(obj[0])
                    if isinstance(inner, dict):
                        return inner
                return {}

            for ident, resp in zip(resp_df.Identifier, resp_df.Response):
                meta = meta_map.get(str(ident))
                if not meta:
                    continue
                batch_idx, _, id_a, id_b = meta
                safe_obj = await _coerce_dict(resp)
                if not safe_obj:
                    continue
                batch = attr_batches[batch_idx]
                batch_attr_map = {str(k).strip().lower(): k for k in batch}
                for attr_raw, winner_raw in safe_obj.items():
                    attr_key_l = str(attr_raw).strip().lower()
                    if attr_key_l not in batch_attr_map:
                        continue
                    real_attr = batch_attr_map[attr_key_l]
                    val = winner_raw
                    if isinstance(val, dict) and "winner" in val:
                        val = val.get("winner")
                    if isinstance(val, str):
                        v = val.strip().lower()
                    else:
                        v = ""
                    if v.startswith(("cir", "c", "left", "text a")):
                        history_pairs[real_attr].append((id_a, id_b))
                    elif v.startswith(("squ", "b", "right", "text b")):
                        history_pairs[real_attr].append((id_b, id_a))
                    elif v.startswith("draw") or v.startswith("insufficient"):
                        history_pairs[real_attr].append((id_a, id_b))
                        history_pairs[real_attr].append((id_b, id_a))
                    else:
                        continue
            # update ratings using the BT model for this round
            se_agg_next: Dict[str, float] = {i: 0.0 for i in item_ids}
            se_agg_counts: Dict[str, int] = {i: 0 for i in item_ids}
            for attr in attr_keys:
                outcomes = history_pairs[attr]
                if len(outcomes) == 0:
                    continue
                bt_scores, n_ij, p_ij = self._fit_bt(
                    item_ids=item_ids,
                    outcomes=outcomes,
                    pseudo=self.cfg.learning_rate,
                    max_iter=self._MAX_ITER,
                    tol=self._TOL,
                    return_info=True,
                )
                for i in item_ids:
                    ratings[i][attr] = bt_scores[i]
                s_vec = np.array([bt_scores[i] for i in item_ids])
                se_vec = self._bt_standard_errors(
                    s=s_vec,
                    n_ij=n_ij,
                    p_ij=p_ij,
                    ridge=self._SE_RIDGE,
                )
                for i, se_val in zip(item_ids, se_vec):
                    se_store[attr][i] = float(se_val)
                    se_agg_next[i] += float(se_val)
                    se_agg_counts[i] += 1
            for i in item_ids:
                if se_agg_counts[i] > 0:
                    se_agg_next[i] /= se_agg_counts[i]
                else:
                    se_agg_next[i] = 1.0
            self._last_se_agg = se_agg_next
            # Centre ratings to zero mean for each attribute
            for attr in attr_keys:
                vals = [ratings[i][attr] for i in item_ids]
                mean_val = float(np.mean(vals))
                for i in item_ids:
                    ratings[i][attr] -= mean_val
            # Write checkpoint after this new round
            _write_checkpoint()
        # After processing all rounds, return the final DataFrame
        # The checkpoint has already been written in the final iteration
        return pd.read_csv(final_path)