edsl 0.1.42__py3-none-any.whl → 0.1.44__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
Files changed (42)
  1. edsl/Base.py +15 -6
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +1 -1
  4. edsl/agents/PromptConstructor.py +92 -21
  5. edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
  6. edsl/agents/prompt_helpers.py +2 -2
  7. edsl/coop/coop.py +100 -22
  8. edsl/enums.py +3 -1
  9. edsl/exceptions/coop.py +4 -0
  10. edsl/inference_services/AnthropicService.py +2 -0
  11. edsl/inference_services/AvailableModelFetcher.py +4 -1
  12. edsl/inference_services/GoogleService.py +2 -0
  13. edsl/inference_services/GrokService.py +11 -0
  14. edsl/inference_services/InferenceServiceABC.py +1 -0
  15. edsl/inference_services/OpenAIService.py +1 -0
  16. edsl/inference_services/TestService.py +1 -0
  17. edsl/inference_services/registry.py +2 -0
  18. edsl/jobs/Jobs.py +54 -35
  19. edsl/jobs/JobsChecks.py +7 -7
  20. edsl/jobs/JobsPrompts.py +57 -6
  21. edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
  22. edsl/jobs/buckets/BucketCollection.py +30 -0
  23. edsl/jobs/data_structures.py +1 -0
  24. edsl/language_models/LanguageModel.py +5 -2
  25. edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
  26. edsl/language_models/key_management/models.py +10 -4
  27. edsl/language_models/model.py +43 -11
  28. edsl/prompts/Prompt.py +124 -61
  29. edsl/questions/descriptors.py +32 -18
  30. edsl/questions/question_base_gen_mixin.py +1 -0
  31. edsl/results/DatasetExportMixin.py +35 -6
  32. edsl/results/Results.py +180 -1
  33. edsl/results/ResultsGGMixin.py +117 -60
  34. edsl/scenarios/FileStore.py +19 -8
  35. edsl/scenarios/Scenario.py +33 -0
  36. edsl/scenarios/ScenarioList.py +22 -3
  37. edsl/scenarios/ScenarioListPdfMixin.py +9 -3
  38. edsl/surveys/Survey.py +27 -6
  39. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/METADATA +3 -4
  40. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/RECORD +42 -41
  41. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/LICENSE +0 -0
  42. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py CHANGED
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
     from edsl.language_models.ModelList import ModelList
     from edsl.data.Cache import Cache
     from edsl.language_models.key_management.KeyLookup import KeyLookup
+    from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

 VisibilityType = Literal["private", "public", "unlisted"]

@@ -407,7 +408,13 @@ class Jobs(Base):
         >>> bc
         BucketCollection(...)
         """
-        return BucketCollection.from_models(self.models)
+        bc = BucketCollection.from_models(self.models)
+
+        if self.run_config.environment.key_lookup is not None:
+            bc.update_from_key_lookup(
+                self.run_config.environment.key_lookup
+            )
+        return bc

     def html(self):
         """Return the HTML representations for each scenario"""
@@ -465,22 +472,47 @@ class Jobs(Base):

         return False

+    def _start_remote_inference_job(
+        self, job_handler: Optional[JobsRemoteInferenceHandler] = None
+    ) -> Union["Results", None]:
+
+        if job_handler is None:
+            job_handler = self._create_remote_inference_handler()
+
+        job_info = job_handler.create_remote_inference_job(
+            iterations=self.run_config.parameters.n,
+            remote_inference_description=self.run_config.parameters.remote_inference_description,
+            remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
+        )
+        return job_info
+
+    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
+
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+        return JobsRemoteInferenceHandler(
+            self, verbose=self.run_config.parameters.verbose
+        )
+
     def _remote_results(
         self,
+        config: RunConfig,
     ) -> Union["Results", None]:
         from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+        from edsl.jobs.JobsRemoteInferenceHandler import RemoteJobInfo

-        jh = JobsRemoteInferenceHandler(
-            self, verbose=self.run_config.parameters.verbose
-        )
+        background = config.parameters.background
+
+        jh = self._create_remote_inference_handler()
         if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
-            job_info = jh.create_remote_inference_job(
-                iterations=self.run_config.parameters.n,
-                remote_inference_description=self.run_config.parameters.remote_inference_description,
-                remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
-            )
-            results = jh.poll_remote_inference_job(job_info)
-            return results
+            job_info: RemoteJobInfo = self._start_remote_inference_job(jh)
+            if background:
+                from edsl.results.Results import Results
+                results = Results.from_job_info(job_info)
+                return results
+            else:
+                results = jh.poll_remote_inference_job(job_info)
+                return results
         else:
             return None

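Note: together with the new `background` field on RunParameters (see edsl/jobs/data_structures.py below), this changes what a run returns when remote inference is enabled. A minimal usage sketch, assuming `Jobs.example()` is available as in the EDSL doctests:

    from edsl.jobs.Jobs import Jobs

    job = Jobs.example()
    # With background=True, _remote_results() builds a Results object from the
    # remote job info (Results.from_job_info) and returns immediately instead
    # of polling the job to completion.
    results = job.run(background=True)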
@@ -507,13 +539,6 @@ class Jobs(Base):

         assert isinstance(self.run_config.environment.cache, Cache)

-        # with RemoteCacheSync(
-        #     coop=Coop(),
-        #     cache=self.run_config.environment.cache,
-        #     output_func=self._output,
-        #     remote_cache=use_remote_cache,
-        #     remote_cache_description=self.run_config.parameters.remote_cache_description,
-        # ):
         runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
         if run_job_async:
             results = await runner.run_async(self.run_config.parameters)
@@ -521,19 +546,6 @@ class Jobs(Base):
             results = runner.run(self.run_config.parameters)
         return results

-    # def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:
-    #     self._prepare_to_run()
-    #     self._check_if_remote_keys_ok()
-
-    #     # first try to run the job remotely
-    #     results = self._remote_results()
-    #     #breakpoint()
-    #     if results is not None:
-    #         return results
-
-    #     self._check_if_local_keys_ok()
-    #     return None
-
     @property
     def num_interviews(self):
         if self.run_config.parameters.n is None:
@@ -563,7 +575,6 @@ class Jobs(Base):

         self.replace_missing_objects()

-        # try to run remotely first
         self._prepare_to_run()
         self._check_if_remote_keys_ok()

@@ -581,9 +592,9 @@ class Jobs(Base):
             self.run_config.environment.cache = Cache(immediate_write=False)

         # first try to run the job remotely
-        if results := self._remote_results():
+        if (results := self._remote_results(config)) is not None:
             return results
-
+
         self._check_if_local_keys_ok()

         if config.environment.bucket_collection is None:
@@ -591,6 +602,14 @@ class Jobs(Base):
                 self.create_bucket_collection()
             )

+        if (
+            self.run_config.environment.key_lookup is not None
+            and self.run_config.environment.bucket_collection is not None
+        ):
+            self.run_config.environment.bucket_collection.update_from_key_lookup(
+                self.run_config.environment.key_lookup
+            )
+
         return None

     @with_config
@@ -613,7 +632,7 @@ class Jobs(Base):
         :param key_lookup: A KeyLookup object to manage API keys
         """
         potentially_completed_results = self._run(config)
-
+
         if potentially_completed_results is not None:
             return potentially_completed_results
edsl/jobs/JobsChecks.py CHANGED
@@ -31,7 +31,7 @@ class JobsChecks:
         from edsl.language_models.model import Model
         from edsl.enums import service_to_api_keyname

-        for model in self.jobs.models + [Model()]:
+        for model in self.jobs.models: # + [Model()]:
             if not model.has_valid_api_key():
                 key_name = service_to_api_keyname.get(
                     model._inference_service_, "NOT FOUND"
@@ -134,22 +134,22 @@ class JobsChecks:

         edsl_auth_token = secrets.token_urlsafe(16)

-        print("You're missing some of the API keys needed to run this job:")
+        print("API keys are required to run surveys with language models. The following keys are needed to run this survey: ")
         for api_key in missing_api_keys:
             print(f" 🔑 {api_key}")
         print(
-            "\nYou can either add the missing keys to your .env file, or use remote inference."
+            "\nYou can provide your own keys or use an Expected Parrot key to access all available models."
         )
-        print("Remote inference allows you to run jobs on our server.")
+        print("Please see the documentation page to learn about options for managing keys: https://docs.expectedparrot.com/en/latest/api_keys.html")

         coop = Coop()
         coop._display_login_url(
             edsl_auth_token=edsl_auth_token,
-            link_description="\n🚀 To use remote inference, sign up at the following link:",
+            link_description="\n➡️ Click the link below to create an account and get an Expected Parrot key:\n",
         )

         print(
-            "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
+            "\nOnce you log in, your key will be stored on your computer and your survey will start running at the Expected Parrot server."
         )

         api_key = coop._poll_for_api_key(edsl_auth_token)
@@ -159,7 +159,7 @@ class JobsChecks:
             return

         path_to_env = write_api_key_to_env(api_key)
-        print("\n✨ API key retrieved and written to .env file at the following path:")
+        print("\n✨ Your key has been stored at the following path: ")
         print(f" {path_to_env}")

         # Retrieve API key so we can continue running the job
edsl/jobs/JobsPrompts.py CHANGED
@@ -1,3 +1,5 @@
+import time
+import logging
 from typing import List, TYPE_CHECKING

 from edsl.results.Dataset import Dataset
@@ -14,6 +16,7 @@ if TYPE_CHECKING:
     from edsl.jobs.FetchInvigilator import FetchInvigilator
     from edsl.data.CacheEntry import CacheEntry

+logger = logging.getLogger(__name__)

 class JobsPrompts:
     def __init__(self, jobs: "Jobs"):
@@ -22,6 +25,8 @@
         self.scenarios = jobs.scenarios
         self.survey = jobs.survey
         self._price_lookup = None
+        self._agent_lookup = {agent: idx for idx, agent in enumerate(self.agents)}
+        self._scenario_lookup = {scenario: idx for idx, scenario in enumerate(self.scenarios)}

     @property
     def price_lookup(self):
@@ -49,25 +54,53 @@
         models = []
         costs = []
         cache_keys = []
+
         for interview_index, interview in enumerate(interviews):
+            logger.info(f"Processing interview {interview_index} of {len(interviews)}")
+            interview_start = time.time()
+
+            # Fetch invigilators timing
+            invig_start = time.time()
             invigilators = [
                 FetchInvigilator(interview)(question)
                 for question in interview.survey.questions
             ]
+            invig_end = time.time()
+            logger.debug(f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s")
+
+            # Process prompts timing
+            prompts_start = time.time()
             for _, invigilator in enumerate(invigilators):
+                # Get prompts timing
+                get_prompts_start = time.time()
                 prompts = invigilator.get_prompts()
+                get_prompts_end = time.time()
+                logger.debug(f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s")
+
                 user_prompt = prompts["user_prompt"]
                 system_prompt = prompts["system_prompt"]
                 user_prompts.append(user_prompt)
                 system_prompts.append(system_prompt)
-                agent_index = self.agents.index(invigilator.agent)
+
+                # Index lookups timing
+                index_start = time.time()
+                agent_index = self._agent_lookup[invigilator.agent]
                 agent_indices.append(agent_index)
                 interview_indices.append(interview_index)
-                scenario_index = self.scenarios.index(invigilator.scenario)
+                scenario_index = self._scenario_lookup[invigilator.scenario]
                 scenario_indices.append(scenario_index)
+                index_end = time.time()
+                logger.debug(f"Time taken for index lookups: {index_end - index_start:.4f}s")
+
+                # Model and question name assignment timing
+                assign_start = time.time()
                 models.append(invigilator.model.model)
                 question_names.append(invigilator.question.question_name)
+                assign_end = time.time()
+                logger.debug(f"Time taken for assignments: {assign_end - assign_start:.4f}s")

+                # Cost estimation timing
+                cost_start = time.time()
                 prompt_cost = self.estimate_prompt_cost(
                     system_prompt=system_prompt,
                     user_prompt=user_prompt,
@@ -75,16 +108,34 @@
                     inference_service=invigilator.model._inference_service_,
                     model=invigilator.model.model,
                 )
+                cost_end = time.time()
+                logger.debug(f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s")
                 costs.append(prompt_cost["cost_usd"])

+                # Cache key generation timing
+                cache_key_gen_start = time.time()
                 cache_key = CacheEntry.gen_key(
                     model=invigilator.model.model,
                     parameters=invigilator.model.parameters,
                     system_prompt=system_prompt,
                     user_prompt=user_prompt,
-                    iteration=0, # TODO how to handle when there are multiple iterations?
+                    iteration=0,
                 )
+                cache_key_gen_end = time.time()
                 cache_keys.append(cache_key)
+                logger.debug(f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s")
+                logger.debug("-" * 50) # Separator between iterations
+
+            prompts_end = time.time()
+            logger.info(f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s")
+
+            interview_end = time.time()
+            logger.info(f"Overall time taken for interview: {interview_end - interview_start:.4f}s")
+            logger.info("Time breakdown:")
+            logger.info(f"  Invigilators: {invig_end - invig_start:.4f}s")
+            logger.info(f"  Prompts processing: {prompts_end - prompts_start:.4f}s")
+            logger.info(f"  Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s")
+
         d = Dataset(
             [
                 {"user_prompt": user_prompts},
@@ -149,10 +200,10 @@
             import warnings

             warnings.warn(
-                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
+                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $1.00 / 1M tokens; Output: $1.00 / 1M tokens"
             )
-            input_price_per_token = 0.00000015 # $0.15 / 1M tokens
-            output_price_per_token = 0.00000060 # $0.60 / 1M tokens
+            input_price_per_token = 0.000001 # $1.00 / 1M tokens
+            output_price_per_token = 0.000001 # $1.00 / 1M tokens

         # Compute the number of characters (double if the question involves piping)
         user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
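Note: both fallback prices move to $1.00 per 1M tokens, i.e. 0.000001 USD per token. A worked example of the estimate this produces (illustrative token counts, not from the diff):

    input_tokens, output_tokens = 1_500, 500
    cost_usd = input_tokens * 0.000001 + output_tokens * 0.000001
    print(f"${cost_usd:.4f}")  # $0.0020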
edsl/jobs/JobsRemoteInferenceHandler.py CHANGED
@@ -228,6 +228,40 @@ class JobsRemoteInferenceHandler:
         results.results_uuid = results_uuid
         return results

+    def _attempt_fetch_job(
+        self,
+        job_info: RemoteJobInfo,
+        remote_job_data_fetcher: Callable,
+        object_fetcher: Callable,
+    ) -> Union[None, "Results", Literal["continue"]]:
+        """Makes one attempt to fetch and process a remote job's status and results."""
+        remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
+        status = remote_job_data.get("status")
+
+        if status == "cancelled":
+            self._handle_cancelled_job(job_info)
+            return None
+
+        elif status == "failed" or status == "completed":
+            if status == "failed":
+                self._handle_failed_job(job_info, remote_job_data)
+
+            results_uuid = remote_job_data.get("results_uuid")
+            if results_uuid:
+                results = self._fetch_results_and_log(
+                    job_info=job_info,
+                    results_uuid=results_uuid,
+                    remote_job_data=remote_job_data,
+                    object_fetcher=object_fetcher,
+                )
+                return results
+            else:
+                return None
+
+        else:
+            self._sleep_for_a_bit(job_info, status)
+            return "continue"
+
     def poll_remote_inference_job(
         self,
         job_info: RemoteJobInfo,
@@ -242,31 +276,13 @@

         job_in_queue = True
         while job_in_queue:
-            remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
-            status = remote_job_data.get("status")
-
-            if status == "cancelled":
-                self._handle_cancelled_job(job_info)
-                return None
-
-            elif status == "failed" or status == "completed":
-                if status == "failed":
-                    self._handle_failed_job(job_info, remote_job_data)
-
-                results_uuid = remote_job_data.get("results_uuid")
-                if results_uuid:
-                    results = self._fetch_results_and_log(
-                        job_info=job_info,
-                        results_uuid=results_uuid,
-                        remote_job_data=remote_job_data,
-                        object_fetcher=object_fetcher,
-                    )
-                    return results
-                else:
-                    return None
-
-            else:
-                self._sleep_for_a_bit(job_info, status)
+            result = self._attempt_fetch_job(
+                job_info,
+                remote_job_data_fetcher,
+                object_fetcher
+            )
+            if result != "continue":
+                return result

     async def create_and_poll_remote_job(
         self,
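Note: the refactor extracts a single polling attempt into `_attempt_fetch_job`, which returns the string "continue" as a sentinel for "not done yet, keep polling". The loop's control flow reduces to this shape (an illustrative sketch, not code from the package):

    def poll(attempt):
        # attempt() returns a final value (Results or None) to stop,
        # or the sentinel "continue" to sleep and retry.
        while True:
            result = attempt()
            if result != "continue":
                return result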
edsl/jobs/buckets/BucketCollection.py CHANGED
@@ -96,6 +96,36 @@ class BucketCollection(UserDict):
         else:
             self[model] = self.services_to_buckets[self.models_to_services[model.model]]

+    def update_from_key_lookup(self, key_lookup: "KeyLookup") -> None:
+        """Updates the bucket collection rates based on model RPM/TPM from KeyLookup"""
+
+        for model_name, service in self.models_to_services.items():
+            if service in key_lookup and not self.infinity_buckets:
+
+                if key_lookup[service].rpm is not None:
+                    new_rps = key_lookup[service].rpm / 60.0
+                    new_requests_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="requests",
+                        capacity=new_rps,
+                        refill_rate=new_rps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].requests_bucket = (
+                        new_requests_bucket
+                    )
+
+                if key_lookup[service].tpm is not None:
+                    new_tps = key_lookup[service].tpm / 60.0
+                    new_tokens_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="tokens",
+                        capacity=new_tps,
+                        refill_rate=new_tps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].tokens_bucket = new_tokens_bucket
+
     def visualize(self) -> dict:
         """Visualize the token and request buckets for each model."""
         plots = {}
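Note: `update_from_key_lookup` converts per-minute limits from the KeyLookup into the per-second rates used by TokenBucket. The arithmetic, with illustrative numbers:

    rpm, tpm = 480, 90_000  # example per-minute limits for one service
    rps = rpm / 60.0        # 8.0: requests-bucket capacity and refill rate
    tps = tpm / 60.0        # 1500.0: tokens-bucket capacity and refill rate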
edsl/jobs/data_structures.py CHANGED
@@ -32,6 +32,7 @@ class RunParameters(Base):
     remote_inference_results_visibility: Optional[VisibilityType] = "unlisted"
     skip_retry: bool = False
     raise_validation_errors: bool = False
+    background: bool = False
     disable_remote_cache: bool = False
     disable_remote_inference: bool = False
     job_uuid: Optional[str] = None
edsl/language_models/LanguageModel.py CHANGED
@@ -518,7 +518,11 @@ class LanguageModel(
         """
         from edsl.language_models.model import get_model_class

-        model_class = get_model_class(data["model"])
+        # breakpoint()
+
+        model_class = get_model_class(
+            data["model"], service_name=data.get("inference_service", None)
+        )
         return model_class(**data)

     def __repr__(self) -> str:
@@ -574,7 +578,6 @@ class LanguageModel(
         return Model(skip_api_key_check=True)

     def from_cache(self, cache: "Cache") -> LanguageModel:
-
         from copy import deepcopy
         from types import MethodType
         from edsl import Cache
edsl/language_models/key_management/KeyLookupBuilder.py CHANGED
@@ -61,7 +61,14 @@ class KeyLookupBuilder:
     DEFAULT_RPM = int(CONFIG.get("EDSL_SERVICE_RPM_BASELINE"))
     DEFAULT_TPM = int(CONFIG.get("EDSL_SERVICE_TPM_BASELINE"))

-    def __init__(self, fetch_order: Optional[tuple[str]] = None):
+    def __init__(
+        self,
+        fetch_order: Optional[tuple[str]] = None,
+        coop: Optional["Coop"] = None,
+    ):
+        from edsl.coop import Coop
+
+        # Fetch order goes from lowest priority to highest priority
         if fetch_order is None:
             self.fetch_order = ("config", "env")
         else:
@@ -70,6 +77,11 @@
         if not isinstance(self.fetch_order, tuple):
             raise ValueError("fetch_order must be a tuple")

+        if coop is None:
+            self.coop = Coop()
+        else:
+            self.coop = coop
+
         self.limit_data = {}
         self.key_data = {}
         self.id_data = {}
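Note: the new optional `coop` argument is a dependency-injection hook: callers can pass a preconfigured or stubbed client instead of letting the builder construct its own Coop. A hedged sketch, where `my_coop` is a hypothetical Coop instance:

    builder = KeyLookupBuilder(fetch_order=("config", "env"), coop=my_coop)
    # _coop_key_value_pairs() below now reads rate-limit config vars
    # from self.coop rather than constructing a fresh Coop.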
@@ -131,7 +143,8 @@
             service=service,
             rpm=self.DEFAULT_RPM,
             tpm=self.DEFAULT_TPM,
-            source="default",
+            rpm_source="default",
+            tpm_source="default",
         )

         if limit_entry.rpm is None:
@@ -145,7 +158,8 @@
             tpm=int(limit_entry.tpm),
             api_id=api_id,
             token_source=api_key_entry.source,
-            limit_source=limit_entry.source,
+            rpm_source=limit_entry.rpm_source,
+            tpm_source=limit_entry.tpm_source,
             id_source=id_source,
         )

@@ -156,10 +170,7 @@
         return dict(list(os.environ.items()))

     def _coop_key_value_pairs(self):
-        from edsl.coop import Coop
-
-        c = Coop()
-        return dict(list(c.fetch_rate_limit_config_vars().items()))
+        return dict(list(self.coop.fetch_rate_limit_config_vars().items()))

     def _config_key_value_pairs(self):
         from edsl.config import CONFIG
@@ -169,7 +180,7 @@
     @staticmethod
     def extract_service(key: str) -> str:
         """Extract the service and limit type from the key"""
-        limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_")
+        limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_", 1)
         return service_raw.lower(), limit_type.lower()

     def get_key_value_pairs(self) -> dict:
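Note: the `split("_", 1)` fix matters for service names that themselves contain underscores. A quick illustration (the key below is hypothetical):

    key = "EDSL_SERVICE_RPM_DEEP_INFRA"
    limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_", 1)
    print(service_raw.lower(), limit_type.lower())  # deep_infra rpm
    # The old split("_") yielded three fields here and raised
    # "ValueError: too many values to unpack".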
@@ -187,17 +198,17 @@
                 d[k] = (v, source)
         return d

-    def _entry_type(self, key, value) -> str:
+    def _entry_type(self, key: str) -> str:
         """Determine the type of entry from a key.

         >>> builder = KeyLookupBuilder()
-        >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI", "60")
+        >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI")
         'limit'
-        >>> builder._entry_type("OPENAI_API_KEY", "sk-1234")
+        >>> builder._entry_type("OPENAI_API_KEY")
         'api_key'
-        >>> builder._entry_type("AWS_ACCESS_KEY_ID", "AKIA1234")
+        >>> builder._entry_type("AWS_ACCESS_KEY_ID")
         'api_id'
-        >>> builder._entry_type("UNKNOWN_KEY", "value")
+        >>> builder._entry_type("UNKNOWN_KEY")
         'unknown'
         """
         if key.startswith("EDSL_SERVICE_"):
@@ -243,11 +254,13 @@
         service, limit_type = self.extract_service(key)
         if service in self.limit_data:
             setattr(self.limit_data[service], limit_type.lower(), value)
+            setattr(self.limit_data[service], f"{limit_type}_source", source)
         else:
             new_limit_entry = LimitEntry(
-                service=service, rpm=None, tpm=None, source=source
+                service=service, rpm=None, tpm=None, rpm_source=None, tpm_source=None
             )
             setattr(new_limit_entry, limit_type.lower(), value)
+            setattr(new_limit_entry, f"{limit_type}_source", source)
             self.limit_data[service] = new_limit_entry

     def _add_api_key(self, key: str, value: str, source: str) -> None:
@@ -265,13 +278,27 @@
         else:
             self.key_data[service].append(new_entry)

-    def process_key_value_pairs(self) -> None:
-        """Process all key-value pairs from the configured sources."""
-        for key, value_pair in self.get_key_value_pairs().items():
+    def update_from_dict(self, d: dict) -> None:
+        """
+        Update data from a dictionary of key-value pairs.
+        Each key is a key name, and each value is a tuple of (value, source).
+
+        >>> builder = KeyLookupBuilder()
+        >>> builder.update_from_dict({"OPENAI_API_KEY": ("sk-1234", "custodial_keys")})
+        >>> 'sk-1234' == builder.key_data["openai"][-1].value
+        True
+        >>> 'custodial_keys' == builder.key_data["openai"][-1].source
+        True
+        """
+        for key, value_pair in d.items():
             value, source = value_pair
-            if (entry_type := self._entry_type(key, value)) == "limit":
+            if self._entry_type(key) == "limit":
                 self._add_limit(key, value, source)
-            elif entry_type == "api_key":
+            elif self._entry_type(key) == "api_key":
                 self._add_api_key(key, value, source)
-            elif entry_type == "api_id":
+            elif self._entry_type(key) == "api_id":
                 self._add_id(key, value, source)
+
+    def process_key_value_pairs(self) -> None:
+        """Process all key-value pairs from the configured sources."""
+        self.update_from_dict(self.get_key_value_pairs())
edsl/language_models/key_management/models.py CHANGED
@@ -40,18 +40,23 @@ class LimitEntry:
     60
     >>> limit.tpm
     100000
-    >>> limit.source
+    >>> limit.rpm_source
     'config'
+    >>> limit.tpm_source
+    'env'
     """

     service: str
     rpm: int
     tpm: int
-    source: Optional[str] = None
+    rpm_source: Optional[str] = None
+    tpm_source: Optional[str] = None

     @classmethod
     def example(cls):
-        return LimitEntry(service="openai", rpm=60, tpm=100000, source="config")
+        return LimitEntry(
+            service="openai", rpm=60, tpm=100000, rpm_source="config", tpm_source="env"
+        )


 @dataclass
@@ -108,7 +113,8 @@ class LanguageModelInput:
     tpm: int
     api_id: Optional[str] = None
     token_source: Optional[str] = None
-    limit_source: Optional[str] = None
+    rpm_source: Optional[str] = None
+    tpm_source: Optional[str] = None
     id_source: Optional[str] = None

     def to_dict(self):