edsl 0.1.45__py3-none-any.whl → 0.1.47__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
- edsl/Base.py +87 -16
- edsl/__version__.py +1 -1
- edsl/agents/PromptConstructor.py +26 -79
- edsl/agents/QuestionInstructionPromptBuilder.py +70 -32
- edsl/agents/QuestionTemplateReplacementsBuilder.py +12 -2
- edsl/coop/coop.py +289 -147
- edsl/data/Cache.py +2 -0
- edsl/data/CacheEntry.py +10 -2
- edsl/data/RemoteCacheSync.py +10 -9
- edsl/inference_services/AvailableModelFetcher.py +1 -1
- edsl/inference_services/PerplexityService.py +9 -5
- edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
- edsl/jobs/Jobs.py +35 -17
- edsl/jobs/JobsComponentConstructor.py +2 -1
- edsl/jobs/JobsPrompts.py +49 -26
- edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
- edsl/jobs/data_structures.py +3 -0
- edsl/jobs/interviews/Interview.py +6 -3
- edsl/language_models/LanguageModel.py +7 -1
- edsl/questions/QuestionBase.py +5 -0
- edsl/questions/question_base_gen_mixin.py +2 -0
- edsl/questions/question_registry.py +6 -7
- edsl/results/DatasetExportMixin.py +124 -6
- edsl/results/Results.py +59 -0
- edsl/scenarios/FileStore.py +112 -7
- edsl/scenarios/ScenarioList.py +283 -21
- edsl/study/Study.py +2 -2
- edsl/surveys/Survey.py +15 -20
- {edsl-0.1.45.dist-info → edsl-0.1.47.dist-info}/METADATA +4 -3
- {edsl-0.1.45.dist-info → edsl-0.1.47.dist-info}/RECORD +32 -44
- edsl/auto/AutoStudy.py +0 -130
- edsl/auto/StageBase.py +0 -243
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -74
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -218
- edsl/base/Base.py +0 -279
- {edsl-0.1.45.dist-info → edsl-0.1.47.dist-info}/LICENSE +0 -0
- {edsl-0.1.45.dist-info → edsl-0.1.47.dist-info}/WHEEL +0 -0
edsl/data/Cache.py
CHANGED
@@ -173,6 +173,7 @@ class Cache(Base):
         user_prompt: str,
         response: dict,
         iteration: int,
+        service: str,
     ) -> str:
         """
         Add a new key-value pair to the cache.
@@ -204,6 +205,7 @@ class Cache(Base):
             user_prompt=user_prompt,
             output=json.dumps(response),
             iteration=iteration,
+            service=service,
         )
         key = entry.key
         self.new_entries[key] = entry
edsl/data/CacheEntry.py
CHANGED
@@ -16,7 +16,7 @@ class CacheEntry(RepresentationMixin):
     """

     key_fields = ["model", "parameters", "system_prompt", "user_prompt", "iteration"]
-    all_fields = key_fields + ["timestamp", "output"]
+    all_fields = key_fields + ["timestamp", "output", "service"]

     def __init__(
         self,
@@ -28,6 +28,7 @@ class CacheEntry(RepresentationMixin):
         iteration: Optional[int] = None,
         output: str,
         timestamp: Optional[int] = None,
+        service: Optional[str] = None,
     ):
         self.model = model
         self.parameters = parameters
@@ -38,6 +39,7 @@ class CacheEntry(RepresentationMixin):
         self.timestamp = timestamp or int(
             datetime.datetime.now(datetime.timezone.utc).timestamp()
         )
+        self.service = service
         self._check_types()

     def _check_types(self):
@@ -59,6 +61,8 @@ class CacheEntry(RepresentationMixin):
         # TODO: should probably be float
         if not isinstance(self.timestamp, int):
             raise TypeError(f"`timestamp` should be an integer")
+        if self.service is not None and not isinstance(self.service, str):
+            raise TypeError("`service` should be either a string or None")

     @classmethod
     def gen_key(
@@ -94,6 +98,7 @@ class CacheEntry(RepresentationMixin):
             "output": self.output,
             "iteration": self.iteration,
             "timestamp": self.timestamp,
+            "service": self.service,
         }
         # if add_edsl_version:
         #     from edsl import __version__
@@ -144,7 +149,8 @@ class CacheEntry(RepresentationMixin):
             f"user_prompt={repr(self.user_prompt)}, "
             f"output={repr(self.output)}, "
             f"iteration={self.iteration}, "
-            f"timestamp={self.timestamp}
+            f"timestamp={self.timestamp}, "
+            f"service={repr(self.service)})"
         )

     @classmethod
@@ -164,6 +170,7 @@ class CacheEntry(RepresentationMixin):
             output="The fox says 'hello'",
             iteration=1,
             timestamp=int(datetime.datetime.now(datetime.timezone.utc).timestamp()),
+            service="openai",
         )

     @classmethod
@@ -184,6 +191,7 @@ class CacheEntry(RepresentationMixin):
         input = cls.example().to_dict()
         _ = input.pop("timestamp")
         _ = input.pop("output")
+        _ = input.pop("service")
         return input

     @classmethod
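Taken together, the Cache.py and CacheEntry.py changes above thread an optional `service` field through cache writes, so each entry records which inference service produced it. A minimal sketch of constructing an entry under the new signature (the field values are illustrative, mirroring the updated `CacheEntry.example()`):

```python
from edsl.data.CacheEntry import CacheEntry

# Illustrative values only; the keyword names follow the updated __init__ above.
entry = CacheEntry(
    model="gpt-4o",
    parameters={"temperature": 0.5},
    system_prompt="You are a helpful assistant.",
    user_prompt="How are you?",
    output="I am fine.",
    iteration=1,
    service="openai",  # new optional field; None remains accepted
)
print(entry.to_dict()["service"])  # -> "openai"
```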
edsl/data/RemoteCacheSync.py
CHANGED
@@ -100,7 +100,7 @@ class RemoteCacheSync(AbstractContextManager):

     def _get_cache_difference(self) -> CacheDifference:
         """Retrieves differences between local and remote caches."""
-        diff = self.coop.
+        diff = self.coop.legacy_remote_cache_get_diff(self.cache.keys())
         return CacheDifference(
             client_missing_entries=diff.get("client_missing_cacheentries", []),
             server_missing_keys=diff.get("server_missing_cacheentry_keys", []),
@@ -112,7 +112,7 @@ class RemoteCacheSync(AbstractContextManager):
         missing_count = len(diff.client_missing_entries)

         if missing_count == 0:
-
+            # self._output("No new entries to add to local cache.")
             return

         # self._output(
@@ -154,22 +154,23 @@ class RemoteCacheSync(AbstractContextManager):
         upload_count = len(entries_to_upload)

         if upload_count > 0:
+            pass
             # self._output(
             #     f"Updating remote cache with {upload_count:,} new "
             #     f"{'entry' if upload_count == 1 else 'entries'}..."
             # )

-            self.coop.remote_cache_create_many(
-
-
-
-            )
+            # self.coop.remote_cache_create_many(
+            #     entries_to_upload,
+            #     visibility="private",
+            #     description=self.remote_cache_description,
+            # )
             # self._output("Remote cache updated!")
         # else:
-
+        #     self._output("No new entries to add to remote cache.")

         # self._output(
-
+        #     f"There are {len(self.cache.keys()):,} entries in the local cache."
         # )
edsl/inference_services/AvailableModelFetcher.py
CHANGED
@@ -69,7 +69,7 @@ class AvailableModelFetcher:

         Returns a list of [model, service_name, index] entries.
         """
-        if service == "azure":
+        if service == "azure" or service == "bedrock":
            force_refresh = True  # Azure models are listed inside the .env AZURE_ENDPOINT_URL_AND_KEY variable

         if service:  # they passed a specific service
edsl/inference_services/PerplexityService.py
CHANGED
@@ -29,9 +29,12 @@ class PerplexityService(OpenAIService):
     @classmethod
     def available(cls) -> List[str]:
         return [
-            "
-            "
-            "
+            "sonar-deep-research",
+            "sonar-reasoning-pro",
+            "sonar-reasoning",
+            "sonar-pro",
+            "sonar",
+            "r1-1776",
         ]

     @classmethod
@@ -65,10 +68,10 @@ class PerplexityService(OpenAIService):
         }

     def sync_client(self):
-        return cls.sync_client()
+        return cls.sync_client(api_key=self.api_token)

     def async_client(self):
-        return cls.async_client()
+        return cls.async_client(api_key=self.api_token)

     @classmethod
     def available(cls) -> list[str]:
@@ -149,6 +152,7 @@ class PerplexityService(OpenAIService):
                 # "logprobs": self.logprobs,
                 # "top_logprobs": self.top_logprobs if self.logprobs else None,
             }
+            print("calling the model", flush=True)
             try:
                 response = await client.chat.completions.create(**params)
             except Exception as e:
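The Perplexity model list above is replaced by the newer sonar family. A hedged sketch of selecting one of the newly listed models (this assumes the usual top-level `Model` constructor accepts the model name directly):

```python
from edsl import Model

# "sonar-pro" is one of the models now returned by PerplexityService.available().
m = Model("sonar-pro")
print(m.model)  # -> "sonar-pro"
```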
edsl/jobs/AnswerQuestionFunctionConstructor.py
CHANGED
@@ -66,10 +66,14 @@ class SkipHandler:
             )
         )

+
         def cancel_between(start, end):
             """Cancel the tasks for questions between the start and end indices."""
             for i in range(start, end):
-
+                #print(f"Cancelling task {i}")
+                #self.interview.tasks[i].cancel()
+                #self.interview.tasks[i].set_result("skipped")
+                self.interview.skip_flags[self.interview.survey.questions[i].question_name] = True

         if (next_question_index := next_question.next_q) == EndOfSurvey:
             cancel_between(
@@ -80,6 +84,8 @@ class SkipHandler:
         if next_question_index > (current_question_index + 1):
             cancel_between(current_question_index + 1, next_question_index)

+
+

 class AnswerQuestionFunctionConstructor:
     """Constructs a function that answers a question and records the answer."""
@@ -161,6 +167,11 @@ class AnswerQuestionFunctionConstructor:
         async def attempt_answer():
             invigilator = self.invigilator_fetcher(question)

+            if self.interview.skip_flags.get(question.question_name, False):
+                return invigilator.get_failed_task_result(
+                    failure_reason="Question skipped."
+                )
+
             if self.skip_handler.should_skip(question):
                 return invigilator.get_failed_task_result(
                     failure_reason="Question skipped."
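The SkipHandler change above replaces task cancellation with a flag-based skip: questions routed past are marked in `interview.skip_flags`, and the answer coroutine returns a failed-task result when its flag is set. A toy sketch of that flow (hypothetical stand-in names, not the real Interview machinery):

```python
# Toy illustration of the skip-flag mechanism introduced in the diff above.
skip_flags = {"q0": False, "q1": False, "q2": False}
question_names = list(skip_flags)

def cancel_between(start: int, end: int) -> None:
    # Instead of cancelling asyncio tasks, mark the questions as skipped.
    for i in range(start, end):
        skip_flags[question_names[i]] = True

def attempt_answer(name: str) -> dict:
    if skip_flags.get(name, False):
        return {"answer": None, "comment": "Question skipped."}
    return {"answer": "...", "comment": None}

cancel_between(1, 3)
print(attempt_answer("q1"))  # -> {'answer': None, 'comment': 'Question skipped.'}
```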
edsl/jobs/Jobs.py
CHANGED
@@ -119,6 +119,19 @@ class Jobs(Base):
         :param agents: a list of agents
         :param models: a list of models
         :param scenarios: a list of scenarios
+
+
+        >>> from edsl.surveys.Survey import Survey
+        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
+        >>> q = QuestionFreeText(question_name="name", question_text="What is your name?")
+        >>> s = Survey(questions=[q])
+        >>> j = Jobs(survey = s)
+        >>> q = QuestionFreeText(question_name="{{ bad_name }}", question_text="What is your name?")
+        >>> s = Survey(questions=[q])
+        >>> j = Jobs(survey = s)
+        Traceback (most recent call last):
+        ...
+        ValueError: At least some question names are not valid: ['{{ bad_name }}']
         """
         self.run_config = RunConfig(
             environment=RunEnvironment(), parameters=RunParameters()
@@ -129,6 +142,13 @@ class Jobs(Base):
         self.scenarios: ScenarioList = scenarios
         self.models: ModelList = models

+        try:
+            assert self.survey.question_names_valid()
+        except Exception as e:
+            invalid_question_names = [q.question_name for q in self.survey.questions if not q.is_valid_question_name()]
+            raise ValueError(f"At least some question names are not valid: {invalid_question_names}")
+
+
     def add_running_env(self, running_env: RunEnvironment):
         self.run_config.add_environment(running_env)
         return self
@@ -277,7 +297,7 @@ class Jobs(Base):

         return JobsComponentConstructor(self).by(*args)

-    def prompts(self) -> "Dataset":
+    def prompts(self, iterations=1) -> "Dataset":
         """Return a Dataset of prompts that will be used.


@@ -285,7 +305,7 @@ class Jobs(Base):
         >>> Jobs.example().prompts()
         Dataset(...)
         """
-        return JobsPrompts(self).prompts()
+        return JobsPrompts(self).prompts(iterations=iterations)

     def show_prompts(self, all: bool = False) -> None:
         """Print the prompts."""
@@ -418,11 +438,9 @@ class Jobs(Base):
         BucketCollection(...)
         """
         bc = BucketCollection.from_models(self.models)
-
+
         if self.run_config.environment.key_lookup is not None:
-            bc.update_from_key_lookup(
-                self.run_config.environment.key_lookup
-            )
+            bc.update_from_key_lookup(self.run_config.environment.key_lookup)
         return bc

     def html(self):
@@ -484,25 +502,24 @@ class Jobs(Base):
     def _start_remote_inference_job(
         self, job_handler: Optional[JobsRemoteInferenceHandler] = None
     ) -> Union["Results", None]:
-
         if job_handler is None:
             job_handler = self._create_remote_inference_handler()
-
+
         job_info = job_handler.create_remote_inference_job(
-
-
-
+            iterations=self.run_config.parameters.n,
+            remote_inference_description=self.run_config.parameters.remote_inference_description,
+            remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
+            fresh=self.run_config.parameters.fresh,
         )
         return job_info
-
-    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:

+    def _create_remote_inference_handler(self) -> JobsRemoteInferenceHandler:
         from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
-
+
         return JobsRemoteInferenceHandler(
             self, verbose=self.run_config.parameters.verbose
         )
-
+
     def _remote_results(
         self,
         config: RunConfig,
@@ -516,7 +533,8 @@ class Jobs(Base):
         if jh.use_remote_inference(self.run_config.parameters.disable_remote_inference):
             job_info: RemoteJobInfo = self._start_remote_inference_job(jh)
             if background:
-                from edsl.results.Results import Results
+                from edsl.results.Results import Results
+
                 results = Results.from_job_info(job_info)
                 return results
             else:
@@ -603,7 +621,7 @@ class Jobs(Base):
         # first try to run the job remotely
         if (results := self._remote_results(config)) is not None:
             return results
-
+
         self._check_if_local_keys_ok()

         if config.environment.bucket_collection is None:
edsl/jobs/JobsComponentConstructor.py
CHANGED
@@ -153,7 +153,8 @@ class JobsComponentConstructor:
         For example, if the user passes in 3 agents,
         and there are 2 existing agents, this will create 6 new agents
         >>> from edsl.jobs import Jobs
-        >>>
+        >>> from edsl.surveys.Survey import Survey
+        >>> JobsComponentConstructor(Jobs(survey = Survey.example()))._merge_objects([1,2,3], [4,5,6])
         [5, 6, 7, 6, 7, 8, 7, 8, 9]
         """
         new_objects = JobsComponentConstructor._get_empty_container_object(
edsl/jobs/JobsPrompts.py
CHANGED
@@ -18,6 +18,7 @@ from edsl.data.CacheEntry import CacheEntry

 logger = logging.getLogger(__name__)

+
 class JobsPrompts:
     def __init__(self, jobs: "Jobs"):
         self.interviews = jobs.interviews()
@@ -26,7 +27,9 @@ class JobsPrompts:
         self.survey = jobs.survey
         self._price_lookup = None
         self._agent_lookup = {agent: idx for idx, agent in enumerate(self.agents)}
-        self._scenario_lookup = {
+        self._scenario_lookup = {
+            scenario: idx for idx, scenario in enumerate(self.scenarios)
+        }

     @property
     def price_lookup(self):
@@ -37,7 +40,7 @@ class JobsPrompts:
             self._price_lookup = c.fetch_prices()
         return self._price_lookup

-    def prompts(self) -> "Dataset":
+    def prompts(self, iterations=1) -> "Dataset":
         """Return a Dataset of prompts that will be used.

         >>> from edsl.jobs import Jobs
@@ -54,11 +57,11 @@ class JobsPrompts:
         models = []
         costs = []
         cache_keys = []
-
+
         for interview_index, interview in enumerate(interviews):
             logger.info(f"Processing interview {interview_index} of {len(interviews)}")
             interview_start = time.time()
-
+
             # Fetch invigilators timing
             invig_start = time.time()
             invigilators = [
@@ -66,8 +69,10 @@ class JobsPrompts:
                 for question in interview.survey.questions
             ]
             invig_end = time.time()
-            logger.debug(
-
+            logger.debug(
+                f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s"
+            )
+
             # Process prompts timing
             prompts_start = time.time()
             for _, invigilator in enumerate(invigilators):
@@ -75,13 +80,15 @@ class JobsPrompts:
                 get_prompts_start = time.time()
                 prompts = invigilator.get_prompts()
                 get_prompts_end = time.time()
-                logger.debug(
-
+                logger.debug(
+                    f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s"
+                )
+
                 user_prompt = prompts["user_prompt"]
                 system_prompt = prompts["system_prompt"]
                 user_prompts.append(user_prompt)
                 system_prompts.append(system_prompt)
-
+
                 # Index lookups timing
                 index_start = time.time()
                 agent_index = self._agent_lookup[invigilator.agent]
@@ -90,14 +97,18 @@ class JobsPrompts:
                 scenario_index = self._scenario_lookup[invigilator.scenario]
                 scenario_indices.append(scenario_index)
                 index_end = time.time()
-                logger.debug(
-
+                logger.debug(
+                    f"Time taken for index lookups: {index_end - index_start:.4f}s"
+                )
+
                 # Model and question name assignment timing
                 assign_start = time.time()
                 models.append(invigilator.model.model)
                 question_names.append(invigilator.question.question_name)
                 assign_end = time.time()
-                logger.debug(
+                logger.debug(
+                    f"Time taken for assignments: {assign_end - assign_start:.4f}s"
+                )

                 # Cost estimation timing
                 cost_start = time.time()
@@ -109,32 +120,44 @@ class JobsPrompts:
                     model=invigilator.model.model,
                 )
                 cost_end = time.time()
-                logger.debug(
+                logger.debug(
+                    f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s"
+                )
                 costs.append(prompt_cost["cost_usd"])

                 # Cache key generation timing
                 cache_key_gen_start = time.time()
-
-
-
-
-
-
-
+                for iteration in range(iterations):
+                    cache_key = CacheEntry.gen_key(
+                        model=invigilator.model.model,
+                        parameters=invigilator.model.parameters,
+                        system_prompt=system_prompt,
+                        user_prompt=user_prompt,
+                        iteration=iteration,
+                    )
+                    cache_keys.append(cache_key)
+
                 cache_key_gen_end = time.time()
-
-
+                logger.debug(
+                    f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s"
+                )
                 logger.debug("-" * 50)  # Separator between iterations

             prompts_end = time.time()
-            logger.info(
-
+            logger.info(
+                f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s"
+            )
+
             interview_end = time.time()
-            logger.info(
+            logger.info(
+                f"Overall time taken for interview: {interview_end - interview_start:.4f}s"
+            )
             logger.info("Time breakdown:")
             logger.info(f"  Invigilators: {invig_end - invig_start:.4f}s")
             logger.info(f"  Prompts processing: {prompts_end - prompts_start:.4f}s")
-            logger.info(
+            logger.info(
+                f"  Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s"
+            )

         d = Dataset(
             [
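With the `iterations` parameter threaded into `JobsPrompts.prompts()`, one cache key is now generated per iteration via `CacheEntry.gen_key()`. A hedged sketch of that per-iteration key generation in isolation (argument values are illustrative):

```python
from edsl.data.CacheEntry import CacheEntry

# Illustrative arguments; gen_key hashes the key fields listed on CacheEntry above.
keys = [
    CacheEntry.gen_key(
        model="gpt-4o",
        parameters={"temperature": 0.5},
        system_prompt="You are a helpful assistant.",
        user_prompt="How are you?",
        iteration=iteration,
    )
    for iteration in range(2)
]
print(len(set(keys)))  # -> 2, since the iteration number is one of the key fields
```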
edsl/jobs/JobsRemoteInferenceHandler.py
CHANGED
@@ -24,7 +24,7 @@ from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
 class RemoteJobConstants:
     """Constants for remote job handling."""

-    REMOTE_JOB_POLL_INTERVAL =
+    REMOTE_JOB_POLL_INTERVAL = 4
     REMOTE_JOB_VERBOSE = False
     DISCORD_URL = "https://discord.com/invite/mxAYkjfy9m"

@@ -88,8 +88,8 @@ class JobsRemoteInferenceHandler:
         iterations: int = 1,
         remote_inference_description: Optional[str] = None,
         remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
+        fresh: Optional[bool] = False,
     ) -> RemoteJobInfo:
-
         from edsl.config import CONFIG
         from edsl.coop.coop import Coop

@@ -106,6 +106,7 @@ class JobsRemoteInferenceHandler:
             status="queued",
             iterations=iterations,
             initial_results_visibility=remote_inference_results_visibility,
+            fresh=fresh,
         )
         logger.update(
             "Your survey is running at the Expected Parrot server...",
@@ -277,9 +278,7 @@ class JobsRemoteInferenceHandler:
         job_in_queue = True
         while job_in_queue:
             result = self._attempt_fetch_job(
-                job_info,
-                remote_job_data_fetcher,
-                object_fetcher
+                job_info, remote_job_data_fetcher, object_fetcher
             )
             if result != "continue":
                 return result
edsl/jobs/data_structures.py
CHANGED
@@ -36,6 +36,9 @@ class RunParameters(Base):
     disable_remote_cache: bool = False
     disable_remote_inference: bool = False
     job_uuid: Optional[str] = None
+    fresh: Optional[
+        bool
+    ] = False  # if True, will not use cache and will save new results to cache

     def to_dict(self, add_edsl_version=False) -> dict:
         d = asdict(self)
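`RunParameters` gains a `fresh` flag, and `JobsRemoteInferenceHandler` forwards it when creating the remote job. Assuming `Jobs.run()` populates `RunParameters` from its keyword arguments, as the other parameters suggest, usage would look roughly like:

```python
from edsl.jobs import Jobs

job = Jobs.example()
# Hypothetical call: per the comment on RunParameters.fresh above, fresh=True
# bypasses cached responses and stores fresh ones in the cache.
results = job.run(fresh=True)
```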
edsl/jobs/interviews/Interview.py
CHANGED
@@ -238,9 +238,6 @@ class Interview:
         >>> run_config = RunConfig(parameters = RunParameters(), environment = RunEnvironment())
         >>> run_config.parameters.stop_on_exception = True
         >>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
-        Traceback (most recent call last):
-        ...
-        asyncio.exceptions.CancelledError
         """
         from edsl.jobs.Jobs import RunConfig, RunParameters, RunEnvironment

@@ -262,6 +259,8 @@ class Interview:
         if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
             model_buckets = ModelBuckets.infinity_bucket()

+        self.skip_flags = {q.question_name: False for q in self.survey.questions}
+
         # was "self.tasks" - is that necessary?
         self.tasks = self.task_manager.build_question_tasks(
             answer_func=AnswerQuestionFunctionConstructor(
@@ -310,6 +309,10 @@ class Interview:
         def handle_task(task, invigilator):
             try:
                 result: Answers = task.result()
+                if result == "skipped":
+                    result = invigilator.get_failed_task_result(
+                        failure_reason="Task was skipped."
+                    )
             except asyncio.CancelledError as e:  # task was cancelled
                 result = invigilator.get_failed_task_result(
                     failure_reason="Task was cancelled."
edsl/language_models/LanguageModel.py
CHANGED
@@ -379,8 +379,10 @@ class LanguageModel(
         cached_response, cache_key = cache.fetch(**cache_call_params)

         if cache_used := cached_response is not None:
+            # print("cache used")
             response = json.loads(cached_response)
         else:
+            # print("cache not used")
             f = (
                 self.remote_async_execute_model_call
                 if hasattr(self, "remote") and self.remote
@@ -396,11 +398,14 @@ class LanguageModel(
             TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
             response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
             new_cache_key = cache.store(
-                **cache_call_params, response=response
+                **cache_call_params, response=response, service=self._inference_service_
             )  # store the response in the cache
             assert new_cache_key == cache_key  # should be the same

+        # breakpoint()
+
         cost = self.cost(response)
+        # breakpoint()
         return ModelResponse(
             response=response,
             cache_used=cache_used,
@@ -465,6 +470,7 @@ class LanguageModel(
             model_outputs=model_outputs,
             edsl_dict=edsl_dict,
         )
+        # breakpoint()
         return agent_response_dict

     get_response = sync_wrapper(async_get_response)
edsl/questions/QuestionBase.py
CHANGED
@@ -18,6 +18,7 @@ from edsl.questions.SimpleAskMixin import SimpleAskMixin
 from edsl.questions.QuestionBasePromptsMixin import QuestionBasePromptsMixin
 from edsl.questions.question_base_gen_mixin import QuestionBaseGenMixin
 from edsl.utilities.remove_edsl_version import remove_edsl_version
+from edsl.utilities.utilities import is_valid_variable_name

 if TYPE_CHECKING:
     from edsl.questions.response_validator_abc import ResponseValidatorABC
@@ -56,6 +57,10 @@ class QuestionBase(
     _answering_instructions = None
     _question_presentation = None

+    def is_valid_question_name(self) -> bool:
+        """Check if the question name is valid."""
+        return is_valid_variable_name(self.question_name)
+
     @property
     def response_validator(self) -> "ResponseValidatorABC":
         """Return the response validator."""
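The new `is_valid_question_name()` helper is what the Jobs constructor uses to report offending names. A small sketch:

```python
from edsl.questions.QuestionFreeText import QuestionFreeText

good = QuestionFreeText(question_name="name", question_text="What is your name?")
bad = QuestionFreeText(question_name="{{ bad_name }}", question_text="What is your name?")

print(good.is_valid_question_name())  # -> True
print(bad.is_valid_question_name())   # -> False, "{{ bad_name }}" is not a valid identifier
```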
edsl/questions/question_base_gen_mixin.py
CHANGED
@@ -140,6 +140,8 @@ class QuestionBaseGenMixin:
             k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
         }

+        strings_only_replacement_dict['scenario'] = strings_only_replacement_dict
+
         def _has_unrendered_variables(template_str: str, env: Environment) -> bool:
             """Check if the template string has any unrendered variables."""
             if not isinstance(template_str, str):