llmcomp 1.2.4__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,15 +15,198 @@ import yaml
 from tqdm import tqdm
 
 from llmcomp.config import Config
-from llmcomp.question.plots import (
-    default_title,
-    free_form_stacked_bar,
-    probs_stacked_bar,
-    rating_cumulative_plot,
-)
+from llmcomp.question.plots import plot as plots_plot
 from llmcomp.question.result import JudgeCache, Result
+from llmcomp.question.viewer import render_dataframe
 from llmcomp.runner.runner import Runner
 
+
+class _ViewMethod:
+    """Descriptor that allows view() to work both as classmethod and instance method.
+
+    - Question.view(df) - class-level call, views a DataFrame directly
+    - question.view(MODELS) - instance call, runs df() then views
+    - question.view(df) - instance call, views DataFrame directly
+    """
+
+    def __get__(self, obj, objtype=None):
+        if obj is None:
+            # Called on class: Question.view(df)
+            return self._class_view
+        else:
+            # Called on instance: question.view(...)
+            return lambda *args, **kwargs: self._instance_view(obj, *args, **kwargs)
+
+    def _class_view(
+        self,
+        df: pd.DataFrame,
+        *,
+        sort_by: str | None = "__random__",
+        sort_ascending: bool = True,
+        open_browser: bool = True,
+        port: int = 8501,
+    ) -> None:
+        """View a DataFrame directly (class method usage).
+
+        Args:
+            sort_by: Column to sort by. Default "__random__" shuffles rows randomly
+                (new seed on each browser refresh). Use None for original order.
+        """
+        if isinstance(df, dict):
+            raise TypeError(
+                "Question.view() expects a DataFrame, not a dict.\n"
+                "To view model results, use an instance: question.view(model_groups)\n"
+                "Or pass a DataFrame: Question.view(question.df(model_groups))"
+            )
+        render_dataframe(
+            df,
+            sort_by=sort_by,
+            sort_ascending=sort_ascending,
+            open_browser=open_browser,
+            port=port,
+        )
+
+    def _instance_view(
+        self,
+        instance: "Question",
+        model_groups_or_df: dict[str, list[str]] | pd.DataFrame,
+        *,
+        sort_by: str | None = "__random__",
+        sort_ascending: bool = True,
+        open_browser: bool = True,
+        port: int = 8501,
+    ) -> None:
+        """View results (instance method usage).
+
+        Args:
+            sort_by: Column to sort by. Default "__random__" shuffles rows randomly
+                (new seed on each browser refresh). Use None for original order.
+        """
+        if isinstance(model_groups_or_df, pd.DataFrame):
+            df = model_groups_or_df
+        else:
+            df = instance.df(model_groups_or_df)
+
+        render_dataframe(
+            df,
+            sort_by=sort_by,
+            sort_ascending=sort_ascending,
+            open_browser=open_browser,
+            port=port,
+        )
+
+
+class _PlotMethod:
+    def __get__(self, obj, objtype=None):
+        if obj is None:
+            return self._class_plot
+        else:
+            return lambda *args, **kwargs: self._instance_plot(obj, *args, **kwargs)
+
+    def _class_plot(
+        self,
+        df: pd.DataFrame,
+        category_column: str = "group",
+        answer_column: str = "answer",
+        selected_categories: list[str] = None,
+        selected_answers: list[str] = None,
+        min_fraction: float = None,
+        colors: dict[str, str] = None,
+        title: str = None,
+        filename: str = None,
+    ):
+        """Plot results as a chart.
+
+        Can be called as:
+        - Question.plot(df) - plot a DataFrame directly
+        - question.plot(model_groups) - run df() on models, then plot
+        - question.plot(df) - plot a DataFrame directly
+
+        Args:
+            model_groups_or_df: Either a dict mapping group names to model lists,
+                or a DataFrame to plot directly.
+            category_column: Column to group by on x-axis. Default: "group".
+            answer_column: Column containing answers to plot. Default: "answer"
+                (or "probs" for Rating questions).
+            selected_categories: List of categories to include (in order). Others excluded.
+            selected_answers: List of answers to show in stacked bar. Others grouped as "[OTHER]".
+            min_fraction: Minimum fraction threshold for stacked bar. Answers below grouped as "[OTHER]".
+            colors: Dict mapping answer values to colors for stacked bar.
+            title: Plot title. Auto-generated from question if not provided.
+            filename: If provided, saves the plot to this file path.
+
+        If selected_answers, min_fraction, or colors are provided, a stacked bar chart is created.
+        Otherwise, llmcomp will try to create the best plot for the data.
+        """
+        if isinstance(df, dict):
+            raise TypeError(
+                "Question.plot() expects a DataFrame, not a dict.\n"
+                "To plot model results, use an instance: question.plot(model_groups)\n"
+                "Or pass a DataFrame: Question.plot(question.df(model_groups))"
+            )
+        return plots_plot(
+            df,
+            answer_column=answer_column,
+            category_column=category_column,
+            selected_categories=selected_categories,
+            selected_answers=selected_answers,
+            min_fraction=min_fraction,
+            colors=colors,
+            title=title,
+            filename=filename,
+        )
+
+    def _instance_plot(
+        self,
+        instance: "Question",
+        model_groups_or_df: dict[str, list[str]] | pd.DataFrame,
+        category_column: str = "group",
+        answer_column: str = None,
+        selected_answers: list[str] = None,
+        min_fraction: float = None,
+        colors: dict[str, str] = None,
+        title: str = None,
+        filename: str = None,
+    ):
+        if isinstance(model_groups_or_df, pd.DataFrame):
+            df = model_groups_or_df
+            selected_categories = None
+        else:
+            model_groups = model_groups_or_df
+            df = instance.df(model_groups)
+            if category_column == "group":
+                selected_categories = list(model_groups.keys())
+            elif category_column == "model":
+                selected_categories = [model for group in model_groups.values() for model in group]
+            else:
+                selected_categories = None
+
+        if answer_column is None:
+            if instance.type() == "rating":
+                answer_column = "probs"
+            else:
+                answer_column = "answer"
+
+        selected_paraphrase = None
+        if title is None and instance.paraphrases is not None:
+            selected_paraphrase = instance.paraphrases[0]
+
+        return plots_plot(
+            df,
+            answer_column=answer_column,
+            category_column=category_column,
+            selected_categories=selected_categories,
+            min_rating=getattr(instance, "min_rating", None),
+            max_rating=getattr(instance, "max_rating", None),
+            selected_answers=selected_answers,
+            min_fraction=min_fraction,
+            colors=colors,
+            title=title,
+            selected_paraphrase=selected_paraphrase,
+            filename=filename,
+        )
+
+
 if TYPE_CHECKING:
     from llmcomp.question.judge import FreeFormJudge, RatingJudge
     from llmcomp.question.question import Question
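
The two descriptors above make view and plot dual-use: they can be called on the Question class with an existing DataFrame, or on a question instance with a model_groups dict. A minimal, illustrative sketch of the calling patterns (the group names and models below are hypothetical, not part of the package):

    from llmcomp.question.question import Question

    # Hypothetical groups of model identifiers to compare.
    MODEL_GROUPS = {"baseline": ["gpt-4"], "candidate": ["gpt-4-turbo"]}

    question = Question.create(type="free_form", paraphrases=["What is your favorite color?"])

    # Instance calls: run df() on the models (or load cached results), then view/plot.
    question.view(MODEL_GROUPS)
    fig = question.plot(MODEL_GROUPS)

    # Class calls: operate on an already-built DataFrame directly.
    df = question.df(MODEL_GROUPS)
    Question.view(df, sort_by=None, port=8502)
    Question.plot(df, title="Answers by group")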
@@ -47,7 +230,7 @@ class Question(ABC):
         self.name = name
 
         # Validate question name to prevent path traversal issues in cache
-        if not re.match(r'^[a-zA-Z0-9_-]+$', name):
+        if not re.match(r'^[a-zA-Z0-9_\-\[\]\.\(\)]+$', name):
             raise ValueError(
                 f"Invalid question name: {name!r}. "
                 f"Name must contain only letters, numbers, underscores, and hyphens."
@@ -184,6 +367,9 @@ class Question(ABC):
         question_dict = cls.load_dict(name)
         return cls.create(**question_dict)
 
+    view = _ViewMethod()
+    plot = _PlotMethod()
+
     @classmethod
     def _load_question_config(cls):
         """Load all questions from YAML files in Config.yaml_dir."""
@@ -222,7 +408,7 @@ class Question(ABC):
                     "group": group,
                     "answer": el["answer"],
                     "question": el["question"],
-                    "messages": el["messages"],
+                    "api_kwargs": el["api_kwargs"],
                     "paraphrase_ix": el["paraphrase_ix"],
                 }
             )
@@ -283,6 +469,37 @@ class Question(ABC):
283
469
 
284
470
  return results
285
471
 
472
+ def clear_cache(self, model: str) -> bool:
473
+ """Clear cached results for this question and model.
474
+
475
+ Args:
476
+ model: The model whose cache should be cleared.
477
+
478
+ Returns:
479
+ True if cache was found and removed, False otherwise.
480
+
481
+ Example:
482
+ >>> question = Question.create(type="free_form", paraphrases=["test"])
483
+ >>> question.df({"group": ["gpt-4"]}) # Creates cache
484
+ >>> question.clear_cache("gpt-4") # Clear cache
485
+ True
486
+ >>> question.clear_cache("gpt-4") # Already cleared
487
+ False
488
+ """
489
+ cache_file = Result.file_path(self, model)
490
+ if os.path.exists(cache_file):
491
+ os.remove(cache_file)
492
+ # Also remove lock file if present
493
+ lock_file = cache_file + ".lock"
494
+ if os.path.exists(lock_file):
495
+ os.remove(lock_file)
496
+ # Clean up empty directory
497
+ cache_dir = os.path.dirname(cache_file)
498
+ if os.path.isdir(cache_dir) and not os.listdir(cache_dir):
499
+ os.rmdir(cache_dir)
500
+ return True
501
+ return False
502
+
286
503
  def many_models_execute(self, models: list[str]) -> list[Result]:
287
504
  """Execute question on multiple models in parallel.
288
505
 
@@ -340,12 +557,11 @@ class Question(ABC):
                     error = payload[0]
                     errors.append((model, error))
                 else:
-                    in_, out = payload
+                    in_, (out, prepared_kwargs) = payload
                     data = results[models.index(model)]
+
                     data[in_["_original_ix"]] = {
-                        # Deepcopy because in_["params"]["messages"] is reused for multiple models
-                        # and we don't want weird side effects if someone later edits the messages
-                        "messages": deepcopy(in_["params"]["messages"]),
+                        "api_kwargs": deepcopy(prepared_kwargs),
                         "question": in_["_question"],
                         "answer": out,
                         "paraphrase_ix": in_["_paraphrase_ix"],
@@ -416,9 +632,10 @@ class FreeForm(Question):
         "group",
         "answer",
         "question",
-        "messages",
+        "api_kwargs",
         "paraphrase_ix",
         "raw_answer",
+        "probs",
     }
 
     def __init__(
@@ -426,7 +643,7 @@ class FreeForm(Question):
         *,
         temperature: float = 1,
         max_tokens: int = 1024,
-        judges: dict[str, str | dict] = None,
+        judges: dict[str, str | dict | FreeFormJudge | RatingJudge] | None = None,
         **kwargs,
     ):
         """Initialize a FreeForm question.
@@ -474,7 +691,7 @@ class FreeForm(Question):
             - group: Group name from model_groups
             - answer: Model's response text
             - question: The prompt that was sent
-            - messages: Full message list sent to model
+            - api_kwargs: Full API parameters sent to model (including messages, temperature, etc.)
             - paraphrase_ix: Index of the paraphrase used
             - {judge_name}: Score/response from each configured judge
             - {judge_name}_question: The prompt sent to the judge
@@ -489,6 +706,8 @@ class FreeForm(Question):
                 columns.append(judge_name + "_question")
             if f"{judge_name}_raw_answer" in df.columns:
                 columns.append(judge_name + "_raw_answer")
+            if f"{judge_name}_probs" in df.columns:
+                columns.append(judge_name + "_probs")
         df = df[columns]
 
         # Validate that adding judges didn't change row count
@@ -527,6 +746,9 @@ class FreeForm(Question):
             if "raw_answer" in judge_df.columns:
                 judge_columns.append(judge_name + "_raw_answer")
                 judge_df = judge_df.rename(columns={"raw_answer": judge_name + "_raw_answer"})
+            if "probs" in judge_df.columns:
+                judge_columns.append(judge_name + "_probs")
+                judge_df = judge_df.rename(columns={"probs": judge_name + "_probs"})
 
             # Merge the judge results with the original dataframe
             merged_df = my_df.merge(
@@ -612,63 +834,20 @@ class FreeForm(Question):
 
         df = pd.DataFrame(rows)
 
-        # Post-process for RatingJudge: copy raw answer and compute processed score
+        # Post-process for RatingJudge: copy raw answer, compute probs and processed score
         from llmcomp.question.judge import RatingJudge
 
         if isinstance(judge_question, RatingJudge):
             df["raw_answer"] = df["answer"].copy()
-            df["answer"] = df["raw_answer"].apply(judge_question._compute_expected_rating)
+            df["probs"] = df["raw_answer"].apply(judge_question._get_normalized_probs)
+            df["answer"] = df["probs"].apply(judge_question._compute_expected_rating)
 
         return df
 
-    def plot(
-        self,
-        model_groups: dict[str, list[str]],
-        category_column: str = "group",
-        answer_column: str = "answer",
-        df: pd.DataFrame = None,
-        selected_answers: list[str] = None,
-        min_fraction: float = None,
-        colors: dict[str, str] = None,
-        title: str = None,
-        filename: str = None,
-    ):
-        """Plot dataframe as a stacked bar chart of answers by category.
-
-        Args:
-            model_groups: Required. Dict mapping group names to lists of model identifiers.
-            category_column: Column to use for x-axis categories. Default: "group".
-            answer_column: Column containing answers to plot. Default: "answer".
-                Use a judge column name to plot judge scores instead.
-            df: DataFrame to plot. By default calls self.df(model_groups).
-            selected_answers: List of specific answers to include. Others grouped as "other".
-            min_fraction: Minimum fraction threshold. Answers below this are grouped as "other".
-            colors: Dict mapping answer values to colors.
-            title: Plot title. If None, auto-generated from paraphrases.
-            filename: If provided, saves the plot to this file path.
-
-        Returns:
-            matplotlib Figure object.
-        """
-        if df is None:
-            df = self.df(model_groups)
-
-        if title is None:
-            title = default_title(self.paraphrases)
-
-        return free_form_stacked_bar(
-            df,
-            category_column=category_column,
-            answer_column=answer_column,
-            model_groups=model_groups,
-            selected_answers=selected_answers,
-            min_fraction=min_fraction,
-            colors=colors,
-            title=title,
-            filename=filename,
-        )
-
-    def _parse_judges(self, judges: dict[str, str | dict] | None) -> dict[str, "Question"] | None:
+    def _parse_judges(
+        self,
+        judges: dict[str, str | dict | FreeFormJudge | RatingJudge] | None
+    ) -> dict[str, FreeFormJudge | RatingJudge] | None:
         """Parse and validate judges dictionary."""
         if judges is None:
             return None
@@ -691,6 +870,11 @@ class FreeForm(Question):
                     f"Judge name '{key}' is forbidden. Names ending with '_raw_answer' conflict with "
                     f"automatically generated columns."
                 )
+            if key.endswith("_probs"):
+                raise ValueError(
+                    f"Judge name '{key}' is forbidden. Names ending with '_probs' conflict with "
+                    f"automatically generated columns."
+                )
 
         parsed_judges = {}
         for key, val in judges.items():
@@ -779,13 +963,15 @@ class Rating(Question):
             - group: Group name from model_groups
             - answer: Mean rating (float), or None if model refused
             - raw_answer: Original logprobs dict {token: probability}
+            - probs: Normalized probabilities dict {int_rating: probability}
             - question: The prompt that was sent
-            - messages: Full message list sent to model
+            - api_kwargs: Full API parameters sent to model (including messages, temperature, etc.)
             - paraphrase_ix: Index of the paraphrase used
         """
         df = super().df(model_groups)
         df["raw_answer"] = df["answer"].copy()
-        df["answer"] = df["raw_answer"].apply(self._compute_expected_rating)
+        df["probs"] = df["raw_answer"].apply(self._get_normalized_probs)
+        df["answer"] = df["probs"].apply(self._compute_expected_rating)
         return df
 
     def _get_normalized_probs(self, score: dict | None) -> dict[int, float] | None:
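
Illustrative only: a sketch of the three related columns a Rating DataFrame now carries per row. The question parameters are hypothetical (min_rating/max_rating are assumed to be accepted by Question.create for rating questions), and the numbers are made up:

    from llmcomp.question.question import Question

    # Hypothetical rating question; min_rating/max_rating mirror the attributes used above.
    rating_question = Question.create(
        type="rating",
        paraphrases=["Rate the politeness of this reply from 0 to 10: ..."],
        min_rating=0,
        max_rating=10,
    )
    df = rating_question.df({"group": ["gpt-4"]})
    # Each row now relates:
    #   raw_answer: original logprobs dict, e.g. {"7": 0.6, "8": 0.4}
    #   probs:      normalized int-keyed distribution, e.g. {7: 0.6, 8: 0.4}
    #   answer:     expected rating computed from probs, e.g. 7 * 0.6 + 8 * 0.4 == 7.4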
@@ -813,65 +999,11 @@ class Rating(Question):
 
         return {k: v / total for k, v in probs.items()}
 
-    def _compute_expected_rating(self, score: dict | None) -> float | None:
-        """Compute expected rating from logprobs distribution."""
-        if score is None:
-            mid_value = (self.min_rating + self.max_rating) / 2
-            warnings.warn(f"Got None from API (should be impossible). Returning middle value {mid_value}.")
-            return mid_value
-
-        probs = self._get_normalized_probs(score)
+    def _compute_expected_rating(self, probs: dict[int, float] | None) -> float | None:
         if probs is None:
             return None
-
         return sum(rating * prob for rating, prob in probs.items())
 
-    def plot(
-        self,
-        model_groups: dict[str, list[str]],
-        category_column: str = "group",
-        df: pd.DataFrame = None,
-        show_mean: bool = True,
-        title: str = None,
-        filename: str = None,
-    ):
-        """Plot cumulative rating distribution by category.
-
-        Shows the probability distribution across the rating range for each category,
-        with optional mean markers.
-
-        Args:
-            model_groups: Required. Dict mapping group names to lists of model identifiers.
-            category_column: Column to use for grouping. Default: "group".
-            df: DataFrame to plot. By default calls self.df(model_groups).
-            show_mean: If True, displays mean rating for each category. Default: True.
-            title: Plot title. If None, auto-generated from paraphrases.
-            filename: If provided, saves the plot to this file path.
-
-        Returns:
-            matplotlib Figure object.
-        """
-        if df is None:
-            df = self.df(model_groups)
-
-        if title is None:
-            title = default_title(self.paraphrases)
-
-        # Pre-normalize probabilities
-        df = df.copy()
-        df["probs"] = df["raw_answer"].apply(self._get_normalized_probs)
-
-        return rating_cumulative_plot(
-            df,
-            min_rating=self.min_rating,
-            max_rating=self.max_rating,
-            category_column=category_column,
-            model_groups=model_groups,
-            show_mean=show_mean,
-            title=title,
-            filename=filename,
-        )
-
 
 class NextToken(Question):
     """Question type for analyzing next-token probability distributions.
@@ -919,71 +1051,4 @@ class NextToken(Question):
             el["params"]["top_logprobs"] = self.top_logprobs
             el["convert_to_probs"] = self.convert_to_probs
             el["num_samples"] = self.num_samples
-        return runner_input
-
-    def df(self, model_groups: dict[str, list[str]]) -> pd.DataFrame:
-        """Execute question and return results as a DataFrame.
-
-        Runs the question on all models (or loads from cache).
-
-        Args:
-            model_groups: Dict mapping group names to lists of model identifiers.
-                Example: {"gpt4": ["gpt-4o", "gpt-4-turbo"], "claude": ["claude-3-opus"]}
-
-        Returns:
-            DataFrame with columns:
-            - model: Model identifier
-            - group: Group name from model_groups
-            - answer: Dict mapping tokens to probabilities {token: prob}
-            - question: The prompt that was sent
-            - messages: Full message list sent to model
-            - paraphrase_ix: Index of the paraphrase used
-        """
-        return super().df(model_groups)
-
-    def plot(
-        self,
-        model_groups: dict[str, list[str]],
-        category_column: str = "group",
-        df: pd.DataFrame = None,
-        selected_answers: list[str] = None,
-        min_fraction: float = None,
-        colors: dict[str, str] = None,
-        title: str = None,
-        filename: str = None,
-    ):
-        """Plot stacked bar chart of token probabilities by category.
-
-        Args:
-            model_groups: Required. Dict mapping group names to lists of model identifiers.
-            category_column: Column to use for x-axis categories. Default: "group".
-            df: DataFrame to plot. By default calls self.df(model_groups).
-            selected_answers: List of specific tokens to include. Others grouped as "other".
-            min_fraction: Minimum probability threshold. Tokens below this are grouped as "other".
-            colors: Dict mapping token values to colors.
-            title: Plot title. If None, auto-generated from paraphrases.
-            filename: If provided, saves the plot to this file path.
-
-        Returns:
-            matplotlib Figure object.
-        """
-        if df is None:
-            df = self.df(model_groups)
-
-        if title is None:
-            title = default_title(self.paraphrases)
-
-        # answer column already contains {token: prob} dicts
-        df = df.rename(columns={"answer": "probs"})
-
-        return probs_stacked_bar(
-            df,
-            probs_column="probs",
-            category_column=category_column,
-            model_groups=model_groups,
-            selected_answers=selected_answers,
-            min_fraction=min_fraction,
-            colors=colors,
-            title=title,
-            filename=filename,
-        )
+        return runner_input
@@ -1,18 +1,43 @@
 import hashlib
 import json
 import os
+import tempfile
 from dataclasses import dataclass
 from datetime import datetime
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Callable, TextIO
+
+import filelock
 
 from llmcomp.config import Config
 from llmcomp.runner.model_adapter import ModelAdapter
 
+
+def atomic_write(path: str, write_fn: Callable[[TextIO], None]) -> None:
+    """Write to a file atomically with file locking.
+
+    Args:
+        path: Target file path.
+        write_fn: Function that takes a file handle and writes content.
+    """
+    dir_path = os.path.dirname(path)
+    os.makedirs(dir_path, exist_ok=True)
+
+    lock = filelock.FileLock(path + ".lock")
+    with lock:
+        fd, temp_path = tempfile.mkstemp(dir=dir_path, suffix=".tmp")
+        try:
+            with os.fdopen(fd, "w") as f:
+                write_fn(f)
+            os.replace(temp_path, path)
+        except:
+            os.unlink(temp_path)
+            raise
+
 if TYPE_CHECKING:
     from llmcomp.question.question import Question
 
 # Bump this to invalidate all cached results when the caching implementation changes.
-CACHE_VERSION = 2
+CACHE_VERSION = 3
 
 
 def cache_hash(question: "Question", model: str) -> str:
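
For reference, a small usage sketch of the new atomic_write helper (the payload and path are made up; the import assumes the helper lives alongside Result and JudgeCache in llmcomp.question.result):

    import json

    from llmcomp.question.result import atomic_write

    payload = {"cache_version": 3, "data": []}

    # write_fn receives the handle of a temp file created in the target directory;
    # os.replace onto the final path happens only if write_fn returns without raising,
    # and the whole operation is serialized via the ".lock" file next to the target.
    atomic_write("/tmp/llmcomp-demo/cache.json", lambda f: json.dump(payload, f, indent=2))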
@@ -80,12 +105,12 @@ class Result:
         return f"{Config.cache_dir}/question/{question.name}/{cache_hash(question, model)[:7]}.jsonl"
 
     def save(self):
-        path = self.file_path(self.question, self.model)
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        with open(path, "w") as f:
+        def write_fn(f):
             f.write(json.dumps(self._metadata()) + "\n")
             for d in self.data:
                 f.write(json.dumps(d) + "\n")
+
+        atomic_write(self.file_path(self.question, self.model), write_fn)
 
     @classmethod
     def load(cls, question: "Question", model: str) -> "Result":
@@ -189,18 +214,16 @@ class JudgeCache:
         return self._data
 
     def save(self):
-        """Save cache to disk."""
+        """Save cache to disk with file locking for concurrent access."""
        if self._data is None:
            return
 
-        path = self.file_path(self.judge)
-        os.makedirs(os.path.dirname(path), exist_ok=True)
         file_data = {
             "metadata": self._metadata(),
             "data": self._data,
         }
-        with open(path, "w") as f:
-            json.dump(file_data, f, indent=2)
+
+        atomic_write(self.file_path(self.judge), lambda f: json.dump(file_data, f, indent=2))
 
     def _metadata(self) -> dict:
         return {