edsl 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,16 +13,15 @@ if TYPE_CHECKING:
     from ..invigilators.invigilator_base import Invigilator
 
 from .fetch_invigilator import FetchInvigilator
+from ..coop.utils import CostConverter
 from ..caching import CacheEntry
 from ..dataset import Dataset
+from ..language_models.price_manager import PriceRetriever
 
 logger = logging.getLogger(__name__)
 
 
 class PromptCostEstimator:
-
-    DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS = 1.0
-    DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS = 1.0
     CHARS_PER_TOKEN = 4
     OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
     PIPING_MULTIPLIER = 2
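
For orientation, the constants above imply a simple estimation rule: roughly four characters per token, with expected output at 75% of input. A standalone sketch of that arithmetic follows; the condition that triggers the piping multiplier is not shown in this hunk, so the `"{{"` check is an assumption.

```python
import math

CHARS_PER_TOKEN = 4
OUTPUT_TOKENS_PER_INPUT_TOKEN = 0.75
PIPING_MULTIPLIER = 2  # assumed to fire when a prompt contains "{{ ... }}" piping

user_prompt = "Summarize {{ document }} in one sentence."
system_prompt = "You are a concise assistant."

# Piping roughly doubles the effective character count of the user prompt.
user_chars = len(user_prompt) * (PIPING_MULTIPLIER if "{{" in user_prompt else 1)
system_chars = len(system_prompt)

input_tokens = (user_chars + system_chars) // CHARS_PER_TOKEN
output_tokens = math.ceil(OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)
```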
@@ -37,7 +36,7 @@ class PromptCostEstimator:
     ):
         self.system_prompt = system_prompt
         self.user_prompt = user_prompt
-        self.price_lookup = price_lookup
+        self.price_retriever = PriceRetriever(price_lookup)
         self.inference_service = inference_service
         self.model = model
 
@@ -49,91 +48,6 @@ class PromptCostEstimator:
             return PromptCostEstimator.PIPING_MULTIPLIER
         return 1
 
-    def _get_fallback_price(self, inference_service: str) -> Dict:
-        """
-        Get fallback prices for a service.
-        - First fallback: The highest input and output prices for that service from the price lookup.
-        - Second fallback: $1.00 per million tokens (for both input and output).
-
-        Args:
-            inference_service (str): The inference service name
-
-        Returns:
-            Dict: Price information
-        """
-        PriceEntry = namedtuple("PriceEntry", ["tokens_per_usd", "price_info"])
-
-        service_prices = [
-            prices
-            for (service, _), prices in self.price_lookup.items()
-            if service == inference_service
-        ]
-
-        default_input_price_info = {
-            "one_usd_buys": 1_000_000,
-            "service_stated_token_qty": 1_000_000,
-            "service_stated_token_price": self.DEFAULT_INPUT_PRICE_PER_MILLION_TOKENS,
-        }
-        default_output_price_info = {
-            "one_usd_buys": 1_000_000,
-            "service_stated_token_qty": 1_000_000,
-            "service_stated_token_price": self.DEFAULT_OUTPUT_PRICE_PER_MILLION_TOKENS,
-        }
-
-        # Find the most expensive price entries (lowest tokens per USD)
-        input_price_info = default_input_price_info
-        output_price_info = default_output_price_info
-
-        input_prices = [
-            PriceEntry(float(p["input"]["one_usd_buys"]), p["input"])
-            for p in service_prices
-            if "input" in p
-        ]
-        if input_prices:
-            input_price_info = min(
-                input_prices, key=lambda price: price.tokens_per_usd
-            ).price_info
-
-        output_prices = [
-            PriceEntry(float(p["output"]["one_usd_buys"]), p["output"])
-            for p in service_prices
-            if "output" in p
-        ]
-        if output_prices:
-            output_price_info = min(
-                output_prices, key=lambda price: price.tokens_per_usd
-            ).price_info
-
-        return {
-            "input": input_price_info,
-            "output": output_price_info,
-        }
-
-    def get_price(self, inference_service: str, model: str) -> Dict:
-        """Get the price information for a specific service and model."""
-        key = (inference_service, model)
-        return self.price_lookup.get(key) or self._get_fallback_price(inference_service)
-
-    def get_price_per_million_tokens(
-        self,
-        relevant_prices: Dict,
-        token_type: Literal["input", "output"],
-    ) -> Dict:
-        """
-        Get the price per million tokens for a specific service, model, and token type.
-        """
-        service_price = relevant_prices[token_type]["service_stated_token_price"]
-        service_qty = relevant_prices[token_type]["service_stated_token_qty"]
-
-        if service_qty == 1_000_000:
-            price_per_million_tokens = service_price
-        elif service_qty == 1_000:
-            price_per_million_tokens = service_price * 1_000
-        else:
-            price_per_token = service_price / service_qty
-            price_per_million_tokens = round(price_per_token * 1_000_000, 10)
-        return price_per_million_tokens
-
     def __call__(self):
         user_prompt_chars = len(str(self.user_prompt)) * self.get_piping_multiplier(
             str(self.user_prompt)
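
The three helpers removed here now live on `PriceRetriever`. Assuming the move preserved their behavior, the per-million normalization works like this — a sketch mirroring the deleted helper, not the shipped class:

```python
from typing import Literal

def price_per_million_tokens(
    price_info: dict, token_type: Literal["input", "output"]
) -> float:
    """Normalize a service-stated price to USD per million tokens.

    Mirrors the removed PromptCostEstimator helper; PriceRetriever is
    assumed to implement the same conversion.
    """
    service_price = price_info[token_type]["service_stated_token_price"]
    service_qty = price_info[token_type]["service_stated_token_qty"]

    if service_qty == 1_000_000:  # already stated per million tokens
        return service_price
    if service_qty == 1_000:      # stated per thousand tokens
        return service_price * 1_000
    # Any other quantity: derive a per-token price, then scale up.
    return round(service_price / service_qty * 1_000_000, 10)

# e.g. $0.25 per 1,000 tokens -> $250 per million tokens
assert price_per_million_tokens(
    {"input": {"service_stated_token_price": 0.25, "service_stated_token_qty": 1_000}},
    "input",
) == 250.0
```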
@@ -145,13 +59,15 @@ class PromptCostEstimator:
         input_tokens = (user_prompt_chars + system_prompt_chars) // self.CHARS_PER_TOKEN
         output_tokens = math.ceil(self.OUTPUT_TOKENS_PER_INPUT_TOKEN * input_tokens)
 
-        relevant_prices = self.get_price(self.inference_service, self.model)
+        relevant_prices = self.price_retriever.get_price(
+            self.inference_service, self.model
+        )
 
-        input_price_per_million_tokens = self.get_price_per_million_tokens(
-            relevant_prices, "input"
+        input_price_per_million_tokens = (
+            self.price_retriever.get_price_per_million_tokens(relevant_prices, "input")
         )
-        output_price_per_million_tokens = self.get_price_per_million_tokens(
-            relevant_prices, "output"
+        output_price_per_million_tokens = (
+            self.price_retriever.get_price_per_million_tokens(relevant_prices, "output")
         )
 
         input_price_per_token = input_price_per_million_tokens / 1_000_000
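
The tail of `__call__` sits outside this hunk; presumably it converts per-million prices to per-token prices and multiplies through. A sketch of that arithmetic with illustrative numbers, not the shipped code:

```python
input_tokens = 1_200
output_tokens = 900
input_price_per_million_tokens = 0.50   # USD, illustrative
output_price_per_million_tokens = 1.50  # USD, illustrative

input_price_per_token = input_price_per_million_tokens / 1_000_000
output_price_per_token = output_price_per_million_tokens / 1_000_000

# 1,200 * $0.0000005 + 900 * $0.0000015 ≈ $0.00195
cost_usd = (
    input_tokens * input_price_per_token
    + output_tokens * output_price_per_token
)
```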
@@ -172,7 +88,6 @@ class PromptCostEstimator:
 
 
 class JobsPrompts:
-
     relevant_keys = [
         "user_prompt",
         "system_prompt",
@@ -255,13 +170,18 @@ class JobsPrompts:
             cost = prompt_cost["cost_usd"]
 
             # Generate cache keys for each iteration
+            files_list = prompts.get("files_list", None)
+            if files_list:
+                files_hash = "+".join([str(hash(file)) for file in files_list])
+                user_prompt_with_hashes = user_prompt + f" {files_hash}"
             cache_keys = []
+
             for iteration in range(iterations):
                 cache_key = CacheEntry.gen_key(
                     model=model,
                     parameters=invigilator.model.parameters,
                     system_prompt=system_prompt,
-                    user_prompt=user_prompt,
+                    user_prompt=user_prompt_with_hashes if files_list else user_prompt,
                     iteration=iteration,
                 )
                 cache_keys.append(cache_key)
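
The effect of the new `files_list` handling: attachments now perturb the prompt text that feeds the cache key, so identical prompts with different files stop sharing cache entries. A sketch, with plain strings standing in for edsl's hashable file objects:

```python
files_list = ["file-a", "file-b"]  # hypothetical stand-ins for the real file objects
user_prompt = "Describe the attached files."

files_hash = "+".join([str(hash(f)) for f in files_list])
user_prompt_with_hashes = user_prompt + f" {files_hash}"

# CacheEntry.gen_key then hashes user_prompt_with_hashes instead of
# user_prompt whenever files are present, so the cache key presumably
# reflects the attached files as well as the prompt text.
```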
@@ -366,20 +286,6 @@ class JobsPrompts:
             },
         )
 
-    @staticmethod
-    def usd_to_credits(usd: float) -> float:
-        """Converts USD to credits."""
-        cents = usd * 100
-        credits_per_cent = 1
-        credits = cents * credits_per_cent
-
-        # Round up to the nearest hundredth of a credit
-        minicredits = math.ceil(credits * 100)
-
-        # Convert back to credits
-        credits = round(minicredits / 100, 2)
-        return credits
-
     def estimate_job_cost_from_external_prices(
         self, price_lookup: dict, iterations: int = 1
     ) -> dict:
@@ -444,14 +350,13 @@ class JobsPrompts:
             detailed_costs.append(group)
 
         # Convert to credits
+        converter = CostConverter()
         for group in detailed_costs:
-            group["credits_hold"] = self.usd_to_credits(group["cost_usd"])
+            group["credits_hold"] = converter.usd_to_credits(group["cost_usd"])
 
         # Calculate totals
         estimated_total_cost_usd = sum(group["cost_usd"] for group in detailed_costs)
-        total_credits_hold = sum(
-            group["credits_hold"] for group in detailed_costs
-        )
+        total_credits_hold = sum(group["credits_hold"] for group in detailed_costs)
         estimated_total_input_tokens = sum(
             group["tokens"]
             for group in detailed_costs
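
The deleted `usd_to_credits` spells out the conversion `CostConverter` is now expected to perform: one credit per cent, rounded up to the nearest hundredth of a credit. A minimal sketch under that assumption:

```python
import math

def usd_to_credits(usd: float) -> float:
    """Assumed CostConverter semantics, mirroring the deleted static method."""
    credits = usd * 100                   # 1 cent == 1 credit
    return math.ceil(credits * 100) / 100  # round UP to the nearest 0.01 credit

# $0.00195 -> 0.195 credits -> held as 0.20 credits
assert usd_to_credits(0.00195) == 0.20
```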
@@ -23,15 +23,40 @@ class LogMessage:
     status: JobsStatus
 
 
+@dataclass
+class JobRunExceptionCounter:
+    exception_type: str = None
+    inference_service: str = None
+    model: str = None
+    question_name: str = None
+    exception_count: int = None
+
+
+@dataclass
+class ModelCost:
+    service: str = None
+    model: str = None
+    input_tokens: int = None
+    input_cost_usd: float = None
+    output_tokens: int = None
+    output_cost_usd: float = None
+    input_cost_credits_with_cache: int = None
+    output_cost_credits_with_cache: int = None
+
+
 @dataclass
 class JobsInfo:
     job_uuid: str = None
     progress_bar_url: str = None
     error_report_url: str = None
+    remote_inference_url: str = None
+    remote_cache_url: str = None
     results_uuid: str = None
    results_url: str = None
     completed_interviews: int = None
     failed_interviews: int = None
+    exception_summary: list[JobRunExceptionCounter] = None
+    model_costs: list[ModelCost] = None
 
     pretty_names = {
         "job_uuid": "Job UUID",
@@ -39,6 +64,8 @@ class JobsInfo:
         "error_report_url": "Exceptions Report URL",
         "results_uuid": "Results UUID",
         "results_url": "Results URL",
+        "remote_inference_url": "Remote Jobs",
+        "remote_cache_url": "Remote Cache",
     }
 
 
@@ -57,6 +84,8 @@ class JobLogger(ABC):
             "results_url",
             "completed_interviews",
             "failed_interviews",
+            "model_costs",
+            "exception_summary",
         ],
         value: str,
     ):
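
`JobLogger.add_info` now accepts the two new keys, which land in the `JobsInfo` fields added above. A hypothetical report showing the intended shape — every value below is invented for illustration:

```python
# Assumes the JobsInfo, ModelCost, and JobRunExceptionCounter dataclasses
# defined in the hunks above.
job_info = JobsInfo(
    job_uuid="1a2b3c4d-0000-0000-0000-000000000000",
    model_costs=[
        ModelCost(
            service="openai",
            model="gpt-4o",
            input_tokens=1_200,
            input_cost_usd=0.003,
            output_tokens=900,
            output_cost_usd=0.009,
        )
    ],
    exception_summary=[
        JobRunExceptionCounter(
            exception_type="Timeout",
            inference_service="openai",
            model="gpt-4o",
            question_name="q1",
            exception_count=2,
        )
    ],
)
```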
@@ -11,15 +11,7 @@ from uuid import UUID
 
 if TYPE_CHECKING:
     from .jobs import Jobs
-
-
-    @dataclass
-    class ModelInfo:
-        model_name: str
-        TPM_limit_k: float
-        RPM_limit_k: float
-        num_tasks_waiting: int
-        token_usage_info: dict
+    from ..interviews import Interview
 
 
 class StatisticsTracker:
@@ -29,16 +21,33 @@ class StatisticsTracker:
         self.completed_count = 0
         self.completed_by_model = defaultdict(int)
         self.distinct_models = distinct_models
+        self.interviews_with_exceptions = 0
         self.total_exceptions = 0
         self.unfixed_exceptions = 0
+        self.exceptions_counter = defaultdict(int)
 
     def add_completed_interview(
-        self, model: str, num_exceptions: int = 0, num_unfixed: int = 0
+        self,
+        model: str,
+        exceptions: list[dict],
+        num_exceptions: int = 0,
+        num_unfixed: int = 0,
     ):
         self.completed_count += 1
         self.completed_by_model[model] += 1
         self.total_exceptions += num_exceptions
         self.unfixed_exceptions += num_unfixed
+        if num_exceptions > 0:
+            self.interviews_with_exceptions += 1
+
+        for exception in exceptions:
+            key = (
+                exception["exception_type"],
+                exception["inference_service"],
+                exception["model"],
+                exception["question_name"],
+            )
+            self.exceptions_counter[key] += 1
 
     def get_elapsed_time(self) -> float:
         return time.time() - self.start_time
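
The tally itself is just a `defaultdict` keyed by a 4-tuple. A compressed sketch, with invented records in the shape the tracker receives:

```python
from collections import defaultdict

exceptions_counter: dict[tuple, int] = defaultdict(int)

# Invented exception records (normally produced by the interview's
# exception collection):
exceptions = [
    {"exception_type": "Timeout", "inference_service": "openai",
     "model": "gpt-4o", "question_name": "q1"},
    {"exception_type": "Timeout", "inference_service": "openai",
     "model": "gpt-4o", "question_name": "q1"},
]
for exception in exceptions:
    key = (
        exception["exception_type"],
        exception["inference_service"],
        exception["model"],
        exception["question_name"],
    )
    exceptions_counter[key] += 1

assert exceptions_counter[("Timeout", "openai", "gpt-4o", "q1")] == 2
```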
@@ -88,9 +97,7 @@ class JobsRunnerStatusBase(ABC):
         ]
         self.num_total_interviews = n * len(self.jobs)
 
-        self.distinct_models = list(
-            set(model.model for model in self.jobs.models)
-        )
+        self.distinct_models = list(set(model.model for model in self.jobs.models))
 
         self.stats_tracker = StatisticsTracker(
             total_interviews=self.num_total_interviews,
@@ -130,6 +137,7 @@ class JobsRunnerStatusBase(ABC):
         status_dict = {
             "overall_progress": {
                 "completed": self.stats_tracker.completed_count,
+                "has_exceptions": self.stats_tracker.interviews_with_exceptions,
                 "total": self.num_total_interviews,
                 "percent": (
                     (
@@ -148,16 +156,36 @@ class JobsRunnerStatusBase(ABC):
                 if self.stats_tracker.completed_count >= self.num_total_interviews
                 else "running"
             ),
+            "exceptions_counter": [
+                {
+                    "exception_type": exception_type,
+                    "inference_service": inference_service,
+                    "model": model,
+                    "question_name": question_name,
+                    "count": count,
+                }
+                for (
+                    exception_type,
+                    inference_service,
+                    model,
+                    question_name,
+                ), count in self.stats_tracker.exceptions_counter.items()
+            ],
         }
 
         model_queues = {}
         # Check if bucket collection exists and is not empty
-        if (hasattr(self.jobs, 'run_config') and
-            hasattr(self.jobs.run_config, 'environment') and
-            hasattr(self.jobs.run_config.environment, 'bucket_collection') and
-            self.jobs.run_config.environment.bucket_collection):
-
-            for model, bucket in self.jobs.run_config.environment.bucket_collection.items():
+        if (
+            hasattr(self.jobs, "run_config")
+            and hasattr(self.jobs.run_config, "environment")
+            and hasattr(self.jobs.run_config.environment, "bucket_collection")
+            and self.jobs.run_config.environment.bucket_collection
+        ):
+
+            for (
+                model,
+                bucket,
+            ) in self.jobs.run_config.environment.bucket_collection.items():
                 model_name = model.model
                 model_queues[model_name] = {
                     "language_model_name": model_name,
@@ -166,7 +194,9 @@ class JobsRunnerStatusBase(ABC):
                         "requested": bucket.requests_bucket.num_requests,
                         "tokens_returned": bucket.requests_bucket.tokens_returned,
                         "target_rate": round(bucket.requests_bucket.target_rate, 1),
-                        "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
+                        "current_rate": round(
+                            bucket.requests_bucket.get_throughput(), 1
+                        ),
                     },
                     "tokens_bucket": {
                         "completed": bucket.tokens_bucket.num_released,
@@ -179,10 +209,11 @@ class JobsRunnerStatusBase(ABC):
         status_dict["language_model_queues"] = model_queues
         return status_dict
 
-    def add_completed_interview(self, interview):
+    def add_completed_interview(self, interview: "Interview"):
         """Records a completed interview without storing the full interview data."""
         self.stats_tracker.add_completed_interview(
             model=interview.model.model,
+            exceptions=interview.exceptions.list(),
            num_exceptions=interview.exceptions.num_exceptions(),
             num_unfixed=interview.exceptions.num_unfixed_exceptions(),
         )
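
End to end, a completed interview now feeds its exception records into the tracker, which later surfaces them through `status_dict`. A sketch with mocked inputs, assuming the `StatisticsTracker` shown above; in the real runner the `exceptions` list comes from `interview.exceptions.list()`:

```python
tracker = StatisticsTracker(total_interviews=1, distinct_models=["gpt-4o"])
tracker.add_completed_interview(
    model="gpt-4o",
    exceptions=[{  # normally interview.exceptions.list()
        "exception_type": "Timeout",
        "inference_service": "openai",
        "model": "gpt-4o",
        "question_name": "q1",
    }],
    num_exceptions=1,
    num_unfixed=0,
)

assert tracker.interviews_with_exceptions == 1
# status_dict would then report, e.g.:
#   "has_exceptions": 1
#   "exceptions_counter": [{"exception_type": "Timeout",
#                           "inference_service": "openai", "model": "gpt-4o",
#                           "question_name": "q1", "count": 1}]
```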