edsl 0.1.42__py3-none-any.whl → 0.1.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +15 -6
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +1 -1
- edsl/agents/PromptConstructor.py +92 -21
- edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
- edsl/agents/prompt_helpers.py +2 -2
- edsl/coop/coop.py +100 -22
- edsl/enums.py +3 -1
- edsl/exceptions/coop.py +4 -0
- edsl/inference_services/AnthropicService.py +2 -0
- edsl/inference_services/AvailableModelFetcher.py +4 -1
- edsl/inference_services/GoogleService.py +2 -0
- edsl/inference_services/GrokService.py +11 -0
- edsl/inference_services/InferenceServiceABC.py +1 -0
- edsl/inference_services/OpenAIService.py +1 -0
- edsl/inference_services/TestService.py +1 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +54 -35
- edsl/jobs/JobsChecks.py +7 -7
- edsl/jobs/JobsPrompts.py +57 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
- edsl/jobs/buckets/BucketCollection.py +30 -0
- edsl/jobs/data_structures.py +1 -0
- edsl/language_models/LanguageModel.py +5 -2
- edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
- edsl/language_models/key_management/models.py +10 -4
- edsl/language_models/model.py +43 -11
- edsl/prompts/Prompt.py +124 -61
- edsl/questions/descriptors.py +32 -18
- edsl/questions/question_base_gen_mixin.py +1 -0
- edsl/results/DatasetExportMixin.py +35 -6
- edsl/results/Results.py +180 -1
- edsl/results/ResultsGGMixin.py +117 -60
- edsl/scenarios/FileStore.py +19 -8
- edsl/scenarios/Scenario.py +33 -0
- edsl/scenarios/ScenarioList.py +22 -3
- edsl/scenarios/ScenarioListPdfMixin.py +9 -3
- edsl/surveys/Survey.py +27 -6
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/METADATA +3 -4
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/RECORD +42 -41
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/LICENSE +0 -0
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/WHEEL +0 -0
edsl/jobs/Jobs.py
CHANGED
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
     from edsl.language_models.ModelList import ModelList
     from edsl.data.Cache import Cache
     from edsl.language_models.key_management.KeyLookup import KeyLookup
+    from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

 VisibilityType = Literal["private", "public", "unlisted"]

@@ -407,7 +408,13 @@ class Jobs(Base):
        >>> bc
        BucketCollection(...)
        """
-
+        bc = BucketCollection.from_models(self.models)
+
+        if self.run_config.environment.key_lookup is not None:
+            bc.update_from_key_lookup(
+                self.run_config.environment.key_lookup
+            )
+        return bc

    def html(self):
        """Return the HTML representations for each scenario"""
@@ -465,22 +472,47 @@ class Jobs(Base):

        return False

+    def _start_remote_inference_job(
+        self, job_handler: Optional[JobsRemoteInferenceHandler] = None
+    ) -> Union["Results", None]:
+
+        if job_handler is None:
+            job_handler = self._create_remote_inference_handler()
+
+        job_info = job_handler.create_remote_inference_job(
+            iterations=self.run_config.parameters.n,
+            remote_inference_description=self.run_config.parameters.remote_inference_description,
+            remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
+        )
+        return job_info
+
+    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
+
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+        return JobsRemoteInferenceHandler(
+            self, verbose=self.run_config.parameters.verbose
+        )
+
    def _remote_results(
        self,
+        config: RunConfig,
    ) -> Union["Results", None]:
        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+        from edsl.jobs.JobsRemoteInferenceHandler import RemoteJobInfo

-
-
-        )
+        background = config.parameters.background
+
+        jh = self._create_remote_inference_handler()
        if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
-            job_info =
-
-
-
-
-
+            job_info: RemoteJobInfo = self._start_remote_inference_job(jh)
+            if background:
+                from edsl.results.Results import Results
+                results = Results.from_job_info(job_info)
+                return results
+            else:
+                results = jh.poll_remote_inference_job(job_info)
+                return results
        else:
            return None

@@ -507,13 +539,6 @@ class Jobs(Base):

        assert isinstance(self.run_config.environment.cache, Cache)

-        # with RemoteCacheSync(
-        #     coop=Coop(),
-        #     cache=self.run_config.environment.cache,
-        #     output_func=self._output,
-        #     remote_cache=use_remote_cache,
-        #     remote_cache_description=self.run_config.parameters.remote_cache_description,
-        # ):
        runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
        if run_job_async:
            results = await runner.run_async(self.run_config.parameters)
@@ -521,19 +546,6 @@ class Jobs(Base):
            results = runner.run(self.run_config.parameters)
        return results

-    # def _setup_and_check(self) -> Tuple[RunConfig, Optional[Results]]:
-    #     self._prepare_to_run()
-    #     self._check_if_remote_keys_ok()
-
-    #     # first try to run the job remotely
-    #     results = self._remote_results()
-    #     #breakpoint()
-    #     if results is not None:
-    #         return results
-
-    #     self._check_if_local_keys_ok()
-    #     return None
-
    @property
    def num_interviews(self):
        if self.run_config.parameters.n is None:
@@ -563,7 +575,6 @@ class Jobs(Base):

        self.replace_missing_objects()

-        # try to run remotely first
        self._prepare_to_run()
        self._check_if_remote_keys_ok()

@@ -581,9 +592,9 @@ class Jobs(Base):
            self.run_config.environment.cache = Cache(immediate_write=False)

        # first try to run the job remotely
-        if results := self._remote_results():
+        if (results := self._remote_results(config)) is not None:
            return results
-
+
        self._check_if_local_keys_ok()

        if config.environment.bucket_collection is None:
@@ -591,6 +602,14 @@ class Jobs(Base):
                self.create_bucket_collection()
            )

+        if (
+            self.run_config.environment.key_lookup is not None
+            and self.run_config.environment.bucket_collection is not None
+        ):
+            self.run_config.environment.bucket_collection.update_from_key_lookup(
+                self.run_config.environment.key_lookup
+            )
+
        return None

    @with_config
@@ -613,7 +632,7 @@ class Jobs(Base):
        :param key_lookup: A KeyLookup object to manage API keys
        """
        potentially_completed_results = self._run(config)
-
+
        if potentially_completed_results is not None:
            return potentially_completed_results

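Taken together, the Jobs.py changes add a non-blocking remote-run path: _remote_results() now reads config.parameters.background, and when it is set the job is submitted once and immediately wrapped via Results.from_job_info() instead of being polled to completion with poll_remote_inference_job(). A minimal usage sketch, assuming the new RunParameters.background field is exposed as a run() keyword the way the other run parameters (e.g. disable_remote_inference) are:

    # Hypothetical sketch; `background=True` as a run() keyword is an assumption
    # inferred from the new RunParameters.background field (see data_structures.py below).
    from edsl import Survey
    from edsl.jobs.Jobs import Jobs

    job = Jobs(survey=Survey.example())

    # Non-blocking: submit the remote job and get back a Results handle
    # constructed from the RemoteJobInfo, without waiting for completion.
    handle = job.run(background=True)

    # Blocking (default): poll until the server reports "completed",
    # "failed", or "cancelled".
    results = job.run()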
edsl/jobs/JobsChecks.py
CHANGED
@@ -31,7 +31,7 @@ class JobsChecks:
        from edsl.language_models.model import Model
        from edsl.enums import service_to_api_keyname

-        for model in self.jobs.models + [Model()]:
+        for model in self.jobs.models:  # + [Model()]:
            if not model.has_valid_api_key():
                key_name = service_to_api_keyname.get(
                    model._inference_service_, "NOT FOUND"
@@ -134,22 +134,22 @@ class JobsChecks:

        edsl_auth_token = secrets.token_urlsafe(16)

-        print("
+        print("API keys are required to run surveys with language models. The following keys are needed to run this survey: ")
        for api_key in missing_api_keys:
            print(f" 🔑 {api_key}")
        print(
-            "\nYou can
+            "\nYou can provide your own keys or use an Expected Parrot key to access all available models."
        )
-        print("
+        print("Please see the documentation page to learn about options for managing keys: https://docs.expectedparrot.com/en/latest/api_keys.html")

        coop = Coop()
        coop._display_login_url(
            edsl_auth_token=edsl_auth_token,
-            link_description="\n
+            link_description="\n➡️ Click the link below to create an account and get an Expected Parrot key:\n",
        )

        print(
-            "\nOnce you log in,
+            "\nOnce you log in, your key will be stored on your computer and your survey will start running at the Expected Parrot server."
        )

        api_key = coop._poll_for_api_key(edsl_auth_token)
@@ -159,7 +159,7 @@ class JobsChecks:
            return

        path_to_env = write_api_key_to_env(api_key)
-        print("\n✨
+        print("\n✨ Your key has been stored at the following path: ")
        print(f" {path_to_env}")

        # Retrieve API key so we can continue running the job
edsl/jobs/JobsPrompts.py
CHANGED
@@ -1,3 +1,5 @@
+import time
+import logging
 from typing import List, TYPE_CHECKING

 from edsl.results.Dataset import Dataset
@@ -14,6 +16,7 @@ if TYPE_CHECKING:
    from edsl.jobs.FetchInvigilator import FetchInvigilator
    from edsl.data.CacheEntry import CacheEntry

+logger = logging.getLogger(__name__)

 class JobsPrompts:
    def __init__(self, jobs: "Jobs"):
@@ -22,6 +25,8 @@ class JobsPrompts:
        self.scenarios = jobs.scenarios
        self.survey = jobs.survey
        self._price_lookup = None
+        self._agent_lookup = {agent: idx for idx, agent in enumerate(self.agents)}
+        self._scenario_lookup = {scenario: idx for idx, scenario in enumerate(self.scenarios)}

    @property
    def price_lookup(self):
@@ -49,25 +54,53 @@ class JobsPrompts:
        models = []
        costs = []
        cache_keys = []
+
        for interview_index, interview in enumerate(interviews):
+            logger.info(f"Processing interview {interview_index} of {len(interviews)}")
+            interview_start = time.time()
+
+            # Fetch invigilators timing
+            invig_start = time.time()
            invigilators = [
                FetchInvigilator(interview)(question)
                for question in interview.survey.questions
            ]
+            invig_end = time.time()
+            logger.debug(f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s")
+
+            # Process prompts timing
+            prompts_start = time.time()
            for _, invigilator in enumerate(invigilators):
+                # Get prompts timing
+                get_prompts_start = time.time()
                prompts = invigilator.get_prompts()
+                get_prompts_end = time.time()
+                logger.debug(f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s")
+
                user_prompt = prompts["user_prompt"]
                system_prompt = prompts["system_prompt"]
                user_prompts.append(user_prompt)
                system_prompts.append(system_prompt)
-
+
+                # Index lookups timing
+                index_start = time.time()
+                agent_index = self._agent_lookup[invigilator.agent]
                agent_indices.append(agent_index)
                interview_indices.append(interview_index)
-                scenario_index = self.
+                scenario_index = self._scenario_lookup[invigilator.scenario]
                scenario_indices.append(scenario_index)
+                index_end = time.time()
+                logger.debug(f"Time taken for index lookups: {index_end - index_start:.4f}s")
+
+                # Model and question name assignment timing
+                assign_start = time.time()
                models.append(invigilator.model.model)
                question_names.append(invigilator.question.question_name)
+                assign_end = time.time()
+                logger.debug(f"Time taken for assignments: {assign_end - assign_start:.4f}s")

+                # Cost estimation timing
+                cost_start = time.time()
                prompt_cost = self.estimate_prompt_cost(
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
@@ -75,16 +108,34 @@ class JobsPrompts:
                    inference_service=invigilator.model._inference_service_,
                    model=invigilator.model.model,
                )
+                cost_end = time.time()
+                logger.debug(f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s")
                costs.append(prompt_cost["cost_usd"])

+                # Cache key generation timing
+                cache_key_gen_start = time.time()
                cache_key = CacheEntry.gen_key(
                    model=invigilator.model.model,
                    parameters=invigilator.model.parameters,
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
-                    iteration=0,
+                    iteration=0,
                )
+                cache_key_gen_end = time.time()
                cache_keys.append(cache_key)
+                logger.debug(f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s")
+                logger.debug("-" * 50)  # Separator between iterations
+
+            prompts_end = time.time()
+            logger.info(f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s")
+
+            interview_end = time.time()
+            logger.info(f"Overall time taken for interview: {interview_end - interview_start:.4f}s")
+            logger.info("Time breakdown:")
+            logger.info(f"  Invigilators: {invig_end - invig_start:.4f}s")
+            logger.info(f"  Prompts processing: {prompts_end - prompts_start:.4f}s")
+            logger.info(f"  Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s")
+
        d = Dataset(
            [
                {"user_prompt": user_prompts},
@@ -149,10 +200,10 @@ class JobsPrompts:
        import warnings

        warnings.warn(
-            "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $
+            "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $1.00 / 1M tokens; Output: $1.00 / 1M tokens"
        )
-        input_price_per_token = 0.
-        output_price_per_token = 0.
+        input_price_per_token = 0.000001  # $1.00 / 1M tokens
+        output_price_per_token = 0.000001  # $1.00 / 1M tokens

        # Compute the number of characters (double if the question involves piping)
        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
edsl/jobs/JobsRemoteInferenceHandler.py
CHANGED
@@ -228,6 +228,40 @@ class JobsRemoteInferenceHandler:
        results.results_uuid = results_uuid
        return results

+    def _attempt_fetch_job(
+        self,
+        job_info: RemoteJobInfo,
+        remote_job_data_fetcher: Callable,
+        object_fetcher: Callable,
+    ) -> Union[None, "Results", Literal["continue"]]:
+        """Makes one attempt to fetch and process a remote job's status and results."""
+        remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
+        status = remote_job_data.get("status")
+
+        if status == "cancelled":
+            self._handle_cancelled_job(job_info)
+            return None
+
+        elif status == "failed" or status == "completed":
+            if status == "failed":
+                self._handle_failed_job(job_info, remote_job_data)
+
+            results_uuid = remote_job_data.get("results_uuid")
+            if results_uuid:
+                results = self._fetch_results_and_log(
+                    job_info=job_info,
+                    results_uuid=results_uuid,
+                    remote_job_data=remote_job_data,
+                    object_fetcher=object_fetcher,
+                )
+                return results
+            else:
+                return None
+
+        else:
+            self._sleep_for_a_bit(job_info, status)
+            return "continue"
+
    def poll_remote_inference_job(
        self,
        job_info: RemoteJobInfo,
@@ -242,31 +276,13 @@ class JobsRemoteInferenceHandler:

        job_in_queue = True
        while job_in_queue:
-
-
-
-
-
-
-
-            elif status == "failed" or status == "completed":
-                if status == "failed":
-                    self._handle_failed_job(job_info, remote_job_data)
-
-                results_uuid = remote_job_data.get("results_uuid")
-                if results_uuid:
-                    results = self._fetch_results_and_log(
-                        job_info=job_info,
-                        results_uuid=results_uuid,
-                        remote_job_data=remote_job_data,
-                        object_fetcher=object_fetcher,
-                    )
-                    return results
-                else:
-                    return None
-
-            else:
-                self._sleep_for_a_bit(job_info, status)
+            result = self._attempt_fetch_job(
+                job_info,
+                remote_job_data_fetcher,
+                object_fetcher
+            )
+            if result != "continue":
+                return result

    async def create_and_poll_remote_job(
        self,
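The polling loop body is factored into _attempt_fetch_job(), which returns a Results object when the job finished with results, None when the job was cancelled or completed without a results_uuid, and the string sentinel "continue" to request another polling iteration. A condensed restatement of the loop contract (names assumed in scope as in the hunk above):

    while True:
        outcome = handler._attempt_fetch_job(job_info, remote_job_data_fetcher, object_fetcher)
        if outcome != "continue":
            break  # outcome is either a Results object or None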
edsl/jobs/buckets/BucketCollection.py
CHANGED
@@ -96,6 +96,36 @@ class BucketCollection(UserDict):
        else:
            self[model] = self.services_to_buckets[self.models_to_services[model.model]]

+    def update_from_key_lookup(self, key_lookup: "KeyLookup") -> None:
+        """Updates the bucket collection rates based on model RPM/TPM from KeyLookup"""
+
+        for model_name, service in self.models_to_services.items():
+            if service in key_lookup and not self.infinity_buckets:
+
+                if key_lookup[service].rpm is not None:
+                    new_rps = key_lookup[service].rpm / 60.0
+                    new_requests_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="requests",
+                        capacity=new_rps,
+                        refill_rate=new_rps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].requests_bucket = (
+                        new_requests_bucket
+                    )
+
+                if key_lookup[service].tpm is not None:
+                    new_tps = key_lookup[service].tpm / 60.0
+                    new_tokens_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="tokens",
+                        capacity=new_tps,
+                        refill_rate=new_tps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].tokens_bucket = new_tokens_bucket
+
    def visualize(self) -> dict:
        """Visualize the token and request buckets for each model."""
        plots = {}
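Service limits in the KeyLookup are expressed per minute (RPM/TPM), while TokenBucket capacities and refill rates are per second, hence the division by 60. A worked example with illustrative numbers:

    rpm = 480                # requests per minute from the KeyLookup entry
    new_rps = rpm / 60.0     # 8.0 -> both capacity and refill_rate of the requests bucket
    tpm = 300_000            # tokens per minute
    new_tps = tpm / 60.0     # 5000.0 tokens per second for the tokens bucket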
edsl/jobs/data_structures.py
CHANGED
@@ -32,6 +32,7 @@ class RunParameters(Base):
    remote_inference_results_visibility: Optional[VisibilityType] = "unlisted"
    skip_retry: bool = False
    raise_validation_errors: bool = False
+    background: bool = False
    disable_remote_cache: bool = False
    disable_remote_inference: bool = False
    job_uuid: Optional[str] = None
edsl/language_models/LanguageModel.py
CHANGED
@@ -518,7 +518,11 @@ class LanguageModel(
        """
        from edsl.language_models.model import get_model_class

-
+        # breakpoint()
+
+        model_class = get_model_class(
+            data["model"], service_name=data.get("inference_service", None)
+        )
        return model_class(**data)

    def __repr__(self) -> str:
@@ -574,7 +578,6 @@ class LanguageModel(
        return Model(skip_api_key_check=True)

    def from_cache(self, cache: "Cache") -> LanguageModel:
-
        from copy import deepcopy
        from types import MethodType
        from edsl import Cache
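Deserialization now threads the stored service name into get_model_class(), so a model name offered by more than one inference service (this release also registers a GrokService, per the file list) resolves to the class for the service it was serialized from. A minimal sketch, assuming a serialized dict of roughly this shape; the exact set of keys is not shown in the diff:

    from edsl.language_models.model import get_model_class

    data = {"model": "test", "inference_service": "test", "parameters": {}}  # illustrative keys
    model_class = get_model_class(data["model"], service_name=data.get("inference_service", None))
    model = model_class(**data)  # as in from_dict() above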
edsl/language_models/key_management/KeyLookupBuilder.py
CHANGED
@@ -61,7 +61,14 @@ class KeyLookupBuilder:
    DEFAULT_RPM = int(CONFIG.get("EDSL_SERVICE_RPM_BASELINE"))
    DEFAULT_TPM = int(CONFIG.get("EDSL_SERVICE_TPM_BASELINE"))

-    def __init__(
+    def __init__(
+        self,
+        fetch_order: Optional[tuple[str]] = None,
+        coop: Optional["Coop"] = None,
+    ):
+        from edsl.coop import Coop
+
+        # Fetch order goes from lowest priority to highest priority
        if fetch_order is None:
            self.fetch_order = ("config", "env")
        else:
@@ -70,6 +77,11 @@ class KeyLookupBuilder:
        if not isinstance(self.fetch_order, tuple):
            raise ValueError("fetch_order must be a tuple")

+        if coop is None:
+            self.coop = Coop()
+        else:
+            self.coop = coop
+
        self.limit_data = {}
        self.key_data = {}
        self.id_data = {}
@@ -131,7 +143,8 @@ class KeyLookupBuilder:
                service=service,
                rpm=self.DEFAULT_RPM,
                tpm=self.DEFAULT_TPM,
-
+                rpm_source="default",
+                tpm_source="default",
            )

        if limit_entry.rpm is None:
@@ -145,7 +158,8 @@ class KeyLookupBuilder:
            tpm=int(limit_entry.tpm),
            api_id=api_id,
            token_source=api_key_entry.source,
-
+            rpm_source=limit_entry.rpm_source,
+            tpm_source=limit_entry.tpm_source,
            id_source=id_source,
        )

@@ -156,10 +170,7 @@ class KeyLookupBuilder:
        return dict(list(os.environ.items()))

    def _coop_key_value_pairs(self):
-
-
-        c = Coop()
-        return dict(list(c.fetch_rate_limit_config_vars().items()))
+        return dict(list(self.coop.fetch_rate_limit_config_vars().items()))

    def _config_key_value_pairs(self):
        from edsl.config import CONFIG
@@ -169,7 +180,7 @@ class KeyLookupBuilder:
    @staticmethod
    def extract_service(key: str) -> str:
        """Extract the service and limit type from the key"""
-        limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_")
+        limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_", 1)
        return service_raw.lower(), limit_type.lower()

    def get_key_value_pairs(self) -> dict:
@@ -187,17 +198,17 @@ class KeyLookupBuilder:
            d[k] = (v, source)
        return d

-    def _entry_type(self, key
+    def _entry_type(self, key: str) -> str:
        """Determine the type of entry from a key.

        >>> builder = KeyLookupBuilder()
-        >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI"
+        >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI")
        'limit'
-        >>> builder._entry_type("OPENAI_API_KEY"
+        >>> builder._entry_type("OPENAI_API_KEY")
        'api_key'
-        >>> builder._entry_type("AWS_ACCESS_KEY_ID"
+        >>> builder._entry_type("AWS_ACCESS_KEY_ID")
        'api_id'
-        >>> builder._entry_type("UNKNOWN_KEY"
+        >>> builder._entry_type("UNKNOWN_KEY")
        'unknown'
        """
        if key.startswith("EDSL_SERVICE_"):
@@ -243,11 +254,13 @@ class KeyLookupBuilder:
        service, limit_type = self.extract_service(key)
        if service in self.limit_data:
            setattr(self.limit_data[service], limit_type.lower(), value)
+            setattr(self.limit_data[service], f"{limit_type}_source", source)
        else:
            new_limit_entry = LimitEntry(
-                service=service, rpm=None, tpm=None,
+                service=service, rpm=None, tpm=None, rpm_source=None, tpm_source=None
            )
            setattr(new_limit_entry, limit_type.lower(), value)
+            setattr(new_limit_entry, f"{limit_type}_source", source)
            self.limit_data[service] = new_limit_entry

    def _add_api_key(self, key: str, value: str, source: str) -> None:
@@ -265,13 +278,27 @@ class KeyLookupBuilder:
        else:
            self.key_data[service].append(new_entry)

-    def
-    """
-
+    def update_from_dict(self, d: dict) -> None:
+        """
+        Update data from a dictionary of key-value pairs.
+        Each key is a key name, and each value is a tuple of (value, source).
+
+        >>> builder = KeyLookupBuilder()
+        >>> builder.update_from_dict({"OPENAI_API_KEY": ("sk-1234", "custodial_keys")})
+        >>> 'sk-1234' == builder.key_data["openai"][-1].value
+        True
+        >>> 'custodial_keys' == builder.key_data["openai"][-1].source
+        True
+        """
+        for key, value_pair in d.items():
            value, source = value_pair
-            if
+            if self._entry_type(key) == "limit":
                self._add_limit(key, value, source)
-            elif
+            elif self._entry_type(key) == "api_key":
                self._add_api_key(key, value, source)
-            elif
+            elif self._entry_type(key) == "api_id":
                self._add_id(key, value, source)
+
+    def process_key_value_pairs(self) -> None:
+        """Process all key-value pairs from the configured sources."""
+        self.update_from_dict(self.get_key_value_pairs())
edsl/language_models/key_management/models.py
CHANGED
@@ -40,18 +40,23 @@ class LimitEntry:
    60
    >>> limit.tpm
    100000
-    >>> limit.
+    >>> limit.rpm_source
    'config'
+    >>> limit.tpm_source
+    'env'
    """

    service: str
    rpm: int
    tpm: int
-
+    rpm_source: Optional[str] = None
+    tpm_source: Optional[str] = None

    @classmethod
    def example(cls):
-        return LimitEntry(
+        return LimitEntry(
+            service="openai", rpm=60, tpm=100000, rpm_source="config", tpm_source="env"
+        )


 @dataclass
@@ -108,7 +113,8 @@ class LanguageModelInput:
    tpm: int
    api_id: Optional[str] = None
    token_source: Optional[str] = None
-
+    rpm_source: Optional[str] = None
+    tpm_source: Optional[str] = None
    id_source: Optional[str] = None

    def to_dict(self):
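With the new rpm_source/tpm_source fields, each limit entry records where its rate came from ("config", "env", and "default" all appear in these hunks), and LanguageModelInput carries the provenance forward. A short sketch mirroring the updated doctest:

    from edsl.language_models.key_management.models import LimitEntry

    entry = LimitEntry(service="openai", rpm=60, tpm=100000,
                       rpm_source="config", tpm_source="env")
    # The RPM came from the edsl config while the TPM came from an
    # environment variable; KeyLookupBuilder copies both into LanguageModelInput.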
|