edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -3,21 +3,12 @@ from __future__ import annotations
|
|
3
3
|
import os
|
4
4
|
import time
|
5
5
|
import requests
|
6
|
-
import warnings
|
7
6
|
from abc import ABC, abstractmethod
|
8
7
|
from dataclasses import dataclass
|
9
|
-
|
10
|
-
from typing import Any, List, DefaultDict, Optional, Dict
|
11
8
|
from collections import defaultdict
|
9
|
+
from typing import Any, Dict, Optional
|
12
10
|
from uuid import UUID
|
13
11
|
|
14
|
-
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
15
|
-
|
16
|
-
InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
|
17
|
-
|
18
|
-
from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
|
19
|
-
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
20
|
-
|
21
12
|
|
22
13
|
@dataclass
|
23
14
|
class ModelInfo:
|
@@ -28,11 +19,44 @@ class ModelInfo:
|
|
28
19
|
token_usage_info: dict
|
29
20
|
|
30
21
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
22
|
+
class StatisticsTracker:
|
23
|
+
def __init__(self, total_interviews: int, distinct_models: list[str]):
|
24
|
+
self.start_time = time.time()
|
25
|
+
self.total_interviews = total_interviews
|
26
|
+
self.completed_count = 0
|
27
|
+
self.completed_by_model = defaultdict(int)
|
28
|
+
self.distinct_models = distinct_models
|
29
|
+
self.total_exceptions = 0
|
30
|
+
self.unfixed_exceptions = 0
|
31
|
+
|
32
|
+
def add_completed_interview(
|
33
|
+
self, model: str, num_exceptions: int = 0, num_unfixed: int = 0
|
34
|
+
):
|
35
|
+
self.completed_count += 1
|
36
|
+
self.completed_by_model[model] += 1
|
37
|
+
self.total_exceptions += num_exceptions
|
38
|
+
self.unfixed_exceptions += num_unfixed
|
39
|
+
|
40
|
+
def get_elapsed_time(self) -> float:
|
41
|
+
return time.time() - self.start_time
|
42
|
+
|
43
|
+
def get_average_time_per_interview(self) -> float:
|
44
|
+
return (
|
45
|
+
self.get_elapsed_time() / self.completed_count
|
46
|
+
if self.completed_count > 0
|
47
|
+
else 0
|
48
|
+
)
|
49
|
+
|
50
|
+
def get_throughput(self) -> float:
|
51
|
+
elapsed = self.get_elapsed_time()
|
52
|
+
return self.completed_count / elapsed if elapsed > 0 else 0
|
53
|
+
|
54
|
+
def get_estimated_time_remaining(self) -> float:
|
55
|
+
if self.completed_count == 0:
|
56
|
+
return 0
|
57
|
+
avg_time = self.get_average_time_per_interview()
|
58
|
+
remaining = self.total_interviews - self.completed_count
|
59
|
+
return avg_time * remaining
|
36
60
|
|
37
61
|
|
38
62
|
class JobsRunnerStatusBase(ABC):
|
@@ -46,48 +70,39 @@ class JobsRunnerStatusBase(ABC):
|
|
46
70
|
api_key: str = None,
|
47
71
|
):
|
48
72
|
self.jobs_runner = jobs_runner
|
49
|
-
|
50
|
-
# The uuid of the job on Coop
|
51
73
|
self.job_uuid = job_uuid
|
52
|
-
|
53
74
|
self.base_url = f"{endpoint_url}"
|
54
|
-
|
55
|
-
self.start_time = time.time()
|
56
|
-
self.completed_interviews = []
|
57
75
|
self.refresh_rate = refresh_rate
|
58
76
|
self.statistics = [
|
59
77
|
"elapsed_time",
|
60
78
|
"total_interviews_requested",
|
61
79
|
"completed_interviews",
|
62
|
-
# "percent_complete",
|
63
80
|
"average_time_per_interview",
|
64
|
-
# "task_remaining",
|
65
81
|
"estimated_time_remaining",
|
66
82
|
"exceptions",
|
67
83
|
"unfixed_exceptions",
|
68
84
|
"throughput",
|
69
85
|
]
|
70
|
-
self.num_total_interviews = n * len(self.jobs_runner
|
86
|
+
self.num_total_interviews = n * len(self.jobs_runner)
|
71
87
|
|
72
88
|
self.distinct_models = list(
|
73
|
-
set(
|
89
|
+
set(model.model for model in self.jobs_runner.jobs.models)
|
74
90
|
)
|
75
91
|
|
76
|
-
self.
|
92
|
+
self.stats_tracker = StatisticsTracker(
|
93
|
+
total_interviews=self.num_total_interviews,
|
94
|
+
distinct_models=self.distinct_models,
|
95
|
+
)
|
77
96
|
|
78
97
|
self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
|
79
98
|
|
80
99
|
@abstractmethod
|
81
100
|
def has_ep_api_key(self):
|
82
|
-
"""
|
83
|
-
Checks if the user has an Expected Parrot API key.
|
84
|
-
"""
|
101
|
+
"""Checks if the user has an Expected Parrot API key."""
|
85
102
|
pass
|
86
103
|
|
87
104
|
def get_status_dict(self) -> Dict[str, Any]:
|
88
|
-
"""
|
89
|
-
Converts current status into a JSON-serializable dictionary.
|
90
|
-
"""
|
105
|
+
"""Converts current status into a JSON-serializable dictionary."""
|
91
106
|
# Get all statistics
|
92
107
|
stats = {}
|
93
108
|
for stat_name in self.statistics:
|
@@ -95,38 +110,41 @@ class JobsRunnerStatusBase(ABC):
|
|
95
110
|
name, value = list(stat.items())[0]
|
96
111
|
stats[name] = value
|
97
112
|
|
98
|
-
# Calculate overall progress
|
99
|
-
total_interviews = len(self.jobs_runner.total_interviews)
|
100
|
-
completed = len(self.completed_interviews)
|
101
|
-
|
102
113
|
# Get model-specific progress
|
103
114
|
model_progress = {}
|
115
|
+
target_per_model = int(self.num_total_interviews / len(self.distinct_models))
|
116
|
+
|
104
117
|
for model in self.distinct_models:
|
105
|
-
|
106
|
-
target_for_model = int(
|
107
|
-
self.num_total_interviews / len(self.distinct_models)
|
108
|
-
)
|
118
|
+
completed = self.stats_tracker.completed_by_model[model]
|
109
119
|
model_progress[model] = {
|
110
|
-
"completed":
|
111
|
-
"total":
|
120
|
+
"completed": completed,
|
121
|
+
"total": target_per_model,
|
112
122
|
"percent": (
|
113
|
-
(
|
114
|
-
if target_for_model > 0
|
115
|
-
else 0
|
123
|
+
(completed / target_per_model * 100) if target_per_model > 0 else 0
|
116
124
|
),
|
117
125
|
}
|
118
126
|
|
119
127
|
status_dict = {
|
120
128
|
"overall_progress": {
|
121
|
-
"completed":
|
122
|
-
"total":
|
129
|
+
"completed": self.stats_tracker.completed_count,
|
130
|
+
"total": self.num_total_interviews,
|
123
131
|
"percent": (
|
124
|
-
(
|
132
|
+
(
|
133
|
+
self.stats_tracker.completed_count
|
134
|
+
/ self.num_total_interviews
|
135
|
+
* 100
|
136
|
+
)
|
137
|
+
if self.num_total_interviews > 0
|
138
|
+
else 0
|
125
139
|
),
|
126
140
|
},
|
127
141
|
"language_model_progress": model_progress,
|
128
142
|
"statistics": stats,
|
129
|
-
"status":
|
143
|
+
"status": (
|
144
|
+
"completed"
|
145
|
+
if self.stats_tracker.completed_count >= self.num_total_interviews
|
146
|
+
else "running"
|
147
|
+
),
|
130
148
|
}
|
131
149
|
|
132
150
|
model_queues = {}
|
@@ -152,99 +170,68 @@ class JobsRunnerStatusBase(ABC):
|
|
152
170
|
status_dict["language_model_queues"] = model_queues
|
153
171
|
return status_dict
|
154
172
|
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
173
|
+
def add_completed_interview(self, result):
|
174
|
+
"""Records a completed interview without storing the full interview data."""
|
175
|
+
self.stats_tracker.add_completed_interview(
|
176
|
+
model=result.model.model,
|
177
|
+
num_exceptions=(
|
178
|
+
len(result.exceptions) if hasattr(result, "exceptions") else 0
|
179
|
+
),
|
180
|
+
num_unfixed=(
|
181
|
+
result.exceptions.num_unfixed() if hasattr(result, "exceptions") else 0
|
182
|
+
),
|
183
|
+
)
|
159
184
|
|
160
|
-
|
161
|
-
"""
|
162
|
-
|
185
|
+
def _compute_statistic(self, stat_name: str):
|
186
|
+
"""Computes individual statistics based on the stats tracker."""
|
187
|
+
if stat_name == "elapsed_time":
|
188
|
+
value = self.stats_tracker.get_elapsed_time()
|
189
|
+
return {"elapsed_time": (value, 1, "sec.")}
|
163
190
|
|
164
|
-
|
165
|
-
|
166
|
-
"""
|
167
|
-
Updates the current status of the job.
|
168
|
-
"""
|
169
|
-
pass
|
191
|
+
elif stat_name == "total_interviews_requested":
|
192
|
+
return {"total_interviews_requested": (self.num_total_interviews, None, "")}
|
170
193
|
|
171
|
-
|
172
|
-
|
194
|
+
elif stat_name == "completed_interviews":
|
195
|
+
return {
|
196
|
+
"completed_interviews": (self.stats_tracker.completed_count, None, "")
|
197
|
+
}
|
173
198
|
|
174
|
-
|
175
|
-
|
199
|
+
elif stat_name == "average_time_per_interview":
|
200
|
+
value = self.stats_tracker.get_average_time_per_interview()
|
201
|
+
return {"average_time_per_interview": (value, 2, "sec.")}
|
176
202
|
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
interviews = self.jobs_runner.total_interviews
|
203
|
+
elif stat_name == "estimated_time_remaining":
|
204
|
+
value = self.stats_tracker.get_estimated_time_remaining()
|
205
|
+
return {"estimated_time_remaining": (value, 1, "sec.")}
|
181
206
|
|
182
|
-
|
183
|
-
"
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
"
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
)
|
192
|
-
"
|
193
|
-
"percent_complete",
|
194
|
-
value=(
|
195
|
-
len(completed_tasks) / len(interviews) * 100
|
196
|
-
if len(interviews) > 0
|
197
|
-
else 0
|
198
|
-
),
|
199
|
-
digits=1,
|
200
|
-
units="%",
|
201
|
-
),
|
202
|
-
"average_time_per_interview": lambda: InterviewStatistic(
|
203
|
-
"average_time_per_interview",
|
204
|
-
value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
|
205
|
-
digits=2,
|
206
|
-
units="sec.",
|
207
|
-
),
|
208
|
-
"task_remaining": lambda: InterviewStatistic(
|
209
|
-
"task_remaining", value=len(interviews) - len(completed_tasks), units=""
|
210
|
-
),
|
211
|
-
"estimated_time_remaining": lambda: InterviewStatistic(
|
212
|
-
"estimated_time_remaining",
|
213
|
-
value=(
|
214
|
-
(len(interviews) - len(completed_tasks))
|
215
|
-
* (elapsed_time / len(completed_tasks))
|
216
|
-
if len(completed_tasks) > 0
|
217
|
-
else 0
|
218
|
-
),
|
219
|
-
digits=1,
|
220
|
-
units="sec.",
|
221
|
-
),
|
222
|
-
"exceptions": lambda: InterviewStatistic(
|
223
|
-
"exceptions",
|
224
|
-
value=sum(len(i.exceptions) for i in interviews),
|
225
|
-
units="",
|
226
|
-
),
|
227
|
-
"unfixed_exceptions": lambda: InterviewStatistic(
|
228
|
-
"unfixed_exceptions",
|
229
|
-
value=sum(i.exceptions.num_unfixed() for i in interviews),
|
230
|
-
units="",
|
231
|
-
),
|
232
|
-
"throughput": lambda: InterviewStatistic(
|
233
|
-
"throughput",
|
234
|
-
value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
|
235
|
-
digits=2,
|
236
|
-
units="interviews/sec.",
|
237
|
-
),
|
238
|
-
}
|
239
|
-
return stat_definitions[stat_name]()
|
207
|
+
elif stat_name == "exceptions":
|
208
|
+
return {"exceptions": (self.stats_tracker.total_exceptions, None, "")}
|
209
|
+
|
210
|
+
elif stat_name == "unfixed_exceptions":
|
211
|
+
return {
|
212
|
+
"unfixed_exceptions": (self.stats_tracker.unfixed_exceptions, None, "")
|
213
|
+
}
|
214
|
+
|
215
|
+
elif stat_name == "throughput":
|
216
|
+
value = self.stats_tracker.get_throughput()
|
217
|
+
return {"throughput": (value, 2, "interviews/sec.")}
|
240
218
|
|
241
219
|
def update_progress(self, stop_event):
|
242
220
|
while not stop_event.is_set():
|
243
221
|
self.send_status_update()
|
244
222
|
time.sleep(self.refresh_rate)
|
245
|
-
|
246
223
|
self.send_status_update()
|
247
224
|
|
225
|
+
@abstractmethod
|
226
|
+
def setup(self):
|
227
|
+
"""Conducts any setup needed prior to sending status updates."""
|
228
|
+
pass
|
229
|
+
|
230
|
+
@abstractmethod
|
231
|
+
def send_status_update(self):
|
232
|
+
"""Updates the current status of the job."""
|
233
|
+
pass
|
234
|
+
|
248
235
|
|
249
236
|
class JobsRunnerStatus(JobsRunnerStatusBase):
|
250
237
|
@property
|
@@ -260,49 +247,35 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
|
|
260
247
|
return f"{self.base_url}/api/v0/local-job/{str(self.job_uuid)}"
|
261
248
|
|
262
249
|
def setup(self) -> None:
|
263
|
-
"""
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
if self.api_key:
|
270
|
-
headers["Authorization"] = f"Bearer {self.api_key}"
|
271
|
-
else:
|
272
|
-
headers["Authorization"] = f"Bearer None"
|
250
|
+
"""Creates a local job on Coop if one does not already exist."""
|
251
|
+
headers = {
|
252
|
+
"Content-Type": "application/json",
|
253
|
+
"Authorization": f"Bearer {self.api_key or 'None'}",
|
254
|
+
}
|
273
255
|
|
274
256
|
if self.job_uuid is None:
|
275
|
-
# Create a new local job
|
276
257
|
response = requests.post(
|
277
258
|
self.create_url,
|
278
259
|
headers=headers,
|
279
260
|
timeout=1,
|
280
261
|
)
|
281
|
-
|
282
|
-
|
283
|
-
|
262
|
+
response.raise_for_status()
|
263
|
+
data = response.json()
|
264
|
+
self.job_uuid = data.get("job_uuid")
|
284
265
|
|
285
266
|
print(f"Running with progress bar. View progress at {self.viewing_url}")
|
286
267
|
|
287
268
|
def send_status_update(self) -> None:
|
288
|
-
"""
|
289
|
-
Sends current status to the web endpoint using the instance's job_uuid.
|
290
|
-
"""
|
269
|
+
"""Sends current status to the web endpoint using the instance's job_uuid."""
|
291
270
|
try:
|
292
|
-
# Get the status dictionary and add the job_id
|
293
271
|
status_dict = self.get_status_dict()
|
294
|
-
|
295
|
-
# Make the UUID JSON serializable
|
296
272
|
status_dict["job_id"] = str(self.job_uuid)
|
297
273
|
|
298
|
-
headers = {
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
else:
|
303
|
-
headers["Authorization"] = f"Bearer None"
|
274
|
+
headers = {
|
275
|
+
"Content-Type": "application/json",
|
276
|
+
"Authorization": f"Bearer {self.api_key or 'None'}",
|
277
|
+
}
|
304
278
|
|
305
|
-
# Send the update
|
306
279
|
response = requests.patch(
|
307
280
|
self.update_url,
|
308
281
|
json=status_dict,
|
@@ -314,14 +287,8 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
|
|
314
287
|
print(f"Failed to send status update for job {self.job_uuid}: {e}")
|
315
288
|
|
316
289
|
def has_ep_api_key(self) -> bool:
|
317
|
-
"""
|
318
|
-
|
319
|
-
"""
|
320
|
-
|
321
|
-
if self.api_key is not None:
|
322
|
-
return True
|
323
|
-
else:
|
324
|
-
return False
|
290
|
+
"""Returns True if the user has an Expected Parrot API key."""
|
291
|
+
return self.api_key is not None
|
325
292
|
|
326
293
|
|
327
294
|
if __name__ == "__main__":
|
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -8,9 +8,10 @@ from edsl.Base import RepresentationMixin
|
|
8
8
|
class TaskHistory(RepresentationMixin):
|
9
9
|
def __init__(
|
10
10
|
self,
|
11
|
-
interviews: List["Interview"],
|
11
|
+
interviews: List["Interview"] = None,
|
12
12
|
include_traceback: bool = False,
|
13
13
|
max_interviews: int = 10,
|
14
|
+
interviews_with_exceptions_only: bool = False,
|
14
15
|
):
|
15
16
|
"""
|
16
17
|
The structure of a TaskHistory exception
|
@@ -20,13 +21,33 @@ class TaskHistory(RepresentationMixin):
|
|
20
21
|
>>> _ = TaskHistory.example()
|
21
22
|
...
|
22
23
|
"""
|
24
|
+
self.interviews_with_exceptions_only = interviews_with_exceptions_only
|
25
|
+
self._interviews = {}
|
26
|
+
self.total_interviews = []
|
27
|
+
if interviews is not None:
|
28
|
+
for interview in interviews:
|
29
|
+
self.add_interview(interview)
|
23
30
|
|
24
|
-
self.
|
31
|
+
self.include_traceback = include_traceback
|
32
|
+
self._interviews = {
|
33
|
+
index: interview for index, interview in enumerate(self.total_interviews)
|
34
|
+
}
|
35
|
+
self.max_interviews = max_interviews
|
36
|
+
|
37
|
+
# self.total_interviews = interviews
|
25
38
|
self.include_traceback = include_traceback
|
26
39
|
|
27
|
-
self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
|
40
|
+
# self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
|
28
41
|
self.max_interviews = max_interviews
|
29
42
|
|
43
|
+
def add_interview(self, interview: "Interview"):
|
44
|
+
"""Add a single interview to the history"""
|
45
|
+
if self.interviews_with_exceptions_only and interview.exceptions == {}:
|
46
|
+
return
|
47
|
+
|
48
|
+
self.total_interviews.append(interview)
|
49
|
+
self._interviews[len(self._interviews)] = interview
|
50
|
+
|
30
51
|
@classmethod
|
31
52
|
def example(cls):
|
32
53
|
""" """
|
@@ -44,6 +44,8 @@ if TYPE_CHECKING:
|
|
44
44
|
from edsl.questions.QuestionBase import QuestionBase
|
45
45
|
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
46
46
|
|
47
|
+
from edsl.enums import InferenceServiceType
|
48
|
+
|
47
49
|
from edsl.utilities.decorators import (
|
48
50
|
sync_wrapper,
|
49
51
|
jupyter_nb_handler,
|
@@ -155,7 +157,9 @@ class LanguageModel(
|
|
155
157
|
return klc.get(("config", "env"))
|
156
158
|
|
157
159
|
def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
|
158
|
-
|
160
|
+
"""Set the key lookup, later"""
|
161
|
+
if hasattr(self, "_api_token"):
|
162
|
+
del self._api_token
|
159
163
|
self.key_lookup = key_lookup
|
160
164
|
|
161
165
|
def ask_question(self, question: "QuestionBase") -> str:
|
@@ -493,7 +497,10 @@ class LanguageModel(
|
|
493
497
|
>>> m.to_dict()
|
494
498
|
{'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
|
495
499
|
"""
|
496
|
-
d = {
|
500
|
+
d = {
|
501
|
+
"model": self.model,
|
502
|
+
"parameters": self.parameters,
|
503
|
+
}
|
497
504
|
if add_edsl_version:
|
498
505
|
from edsl import __version__
|
499
506
|
|
@@ -505,7 +512,7 @@ class LanguageModel(
|
|
505
512
|
@remove_edsl_version
|
506
513
|
def from_dict(cls, data: dict) -> Type[LanguageModel]:
|
507
514
|
"""Convert dictionary to a LanguageModel child instance."""
|
508
|
-
from edsl.language_models.
|
515
|
+
from edsl.language_models.model import get_model_class
|
509
516
|
|
510
517
|
model_class = get_model_class(data["model"])
|
511
518
|
return model_class(**data)
|
@@ -553,7 +560,7 @@ class LanguageModel(
|
|
553
560
|
Exception report saved to ...
|
554
561
|
Also see: ...
|
555
562
|
"""
|
556
|
-
from edsl.language_models.
|
563
|
+
from edsl.language_models.model import Model
|
557
564
|
|
558
565
|
if test_model:
|
559
566
|
m = Model(
|
@@ -563,6 +570,54 @@ class LanguageModel(
|
|
563
570
|
else:
|
564
571
|
return Model(skip_api_key_check=True)
|
565
572
|
|
573
|
+
def from_cache(self, cache: "Cache") -> LanguageModel:
|
574
|
+
|
575
|
+
from copy import deepcopy
|
576
|
+
from types import MethodType
|
577
|
+
from edsl import Cache
|
578
|
+
|
579
|
+
new_instance = deepcopy(self)
|
580
|
+
print("Cache entries", len(cache))
|
581
|
+
new_instance.cache = Cache(
|
582
|
+
data={k: v for k, v in cache.items() if v.model == self.model}
|
583
|
+
)
|
584
|
+
print("Cache entries with same model", len(new_instance.cache))
|
585
|
+
|
586
|
+
new_instance.user_prompts = [
|
587
|
+
ce.user_prompt for ce in new_instance.cache.values()
|
588
|
+
]
|
589
|
+
new_instance.system_prompts = [
|
590
|
+
ce.system_prompt for ce in new_instance.cache.values()
|
591
|
+
]
|
592
|
+
|
593
|
+
async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
|
594
|
+
cache_call_params = {
|
595
|
+
"model": str(self.model),
|
596
|
+
"parameters": self.parameters,
|
597
|
+
"system_prompt": system_prompt,
|
598
|
+
"user_prompt": user_prompt,
|
599
|
+
"iteration": 1,
|
600
|
+
}
|
601
|
+
cached_response, cache_key = cache.fetch(**cache_call_params)
|
602
|
+
response = json.loads(cached_response)
|
603
|
+
cost = 0
|
604
|
+
return ModelResponse(
|
605
|
+
response=response,
|
606
|
+
cache_used=True,
|
607
|
+
cache_key=cache_key,
|
608
|
+
cached_response=cached_response,
|
609
|
+
cost=cost,
|
610
|
+
)
|
611
|
+
|
612
|
+
# Bind the new method to the copied instance
|
613
|
+
setattr(
|
614
|
+
new_instance,
|
615
|
+
"async_execute_model_call",
|
616
|
+
MethodType(async_execute_model_call, new_instance),
|
617
|
+
)
|
618
|
+
|
619
|
+
return new_instance
|
620
|
+
|
566
621
|
|
567
622
|
if __name__ == "__main__":
|
568
623
|
"""Run the module's test suite."""
|
@@ -1,18 +1,22 @@
|
|
1
|
-
from typing import Optional, List
|
1
|
+
from typing import Optional, List, TYPE_CHECKING
|
2
2
|
from collections import UserList
|
3
3
|
|
4
4
|
from edsl.Base import Base
|
5
|
-
from edsl.language_models.
|
5
|
+
from edsl.language_models.model import Model
|
6
6
|
|
7
|
-
#
|
7
|
+
#
|
8
8
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
9
9
|
from edsl.utilities.is_valid_variable_name import is_valid_variable_name
|
10
10
|
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from edsl.inference_services.data_structures import AvailableModels
|
13
|
+
from edsl.language_models import LanguageModel
|
14
|
+
|
11
15
|
|
12
16
|
class ModelList(Base, UserList):
|
13
17
|
__documentation__ = """https://docs.expectedparrot.com/en/latest/language_models.html#module-edsl.language_models.ModelList"""
|
14
18
|
|
15
|
-
def __init__(self, data: Optional[
|
19
|
+
def __init__(self, data: Optional["LanguageModel"] = None):
|
16
20
|
"""Initialize the ScenarioList class.
|
17
21
|
|
18
22
|
>>> from edsl import Model
|
@@ -33,9 +37,6 @@ class ModelList(Base, UserList):
|
|
33
37
|
"""
|
34
38
|
return set([model.model for model in self])
|
35
39
|
|
36
|
-
def rich_print(self):
|
37
|
-
pass
|
38
|
-
|
39
40
|
def __repr__(self):
|
40
41
|
return f"ModelList({super().__repr__()})"
|
41
42
|
|
@@ -87,7 +88,7 @@ class ModelList(Base, UserList):
|
|
87
88
|
.table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
|
88
89
|
)
|
89
90
|
|
90
|
-
def to_list(self):
|
91
|
+
def to_list(self) -> list:
|
91
92
|
return self.to_scenario_list().to_list()
|
92
93
|
|
93
94
|
def to_dict(self, sort=False, add_edsl_version=True):
|
@@ -120,6 +121,16 @@ class ModelList(Base, UserList):
|
|
120
121
|
args = args[0]
|
121
122
|
return ModelList([Model(model_name, **kwargs) for model_name in args])
|
122
123
|
|
124
|
+
@classmethod
|
125
|
+
def from_available_models(self, available_models_list: "AvailableModels"):
|
126
|
+
"""Create a ModelList from an AvailableModels object"""
|
127
|
+
return ModelList(
|
128
|
+
[
|
129
|
+
Model(model.model_name, service_name=model.service_name)
|
130
|
+
for model in available_models_list
|
131
|
+
]
|
132
|
+
)
|
133
|
+
|
123
134
|
@classmethod
|
124
135
|
@remove_edsl_version
|
125
136
|
def from_dict(cls, data):
|
edsl/language_models/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
1
|
from edsl.language_models.LanguageModel import LanguageModel
|
2
|
-
from edsl.language_models.
|
2
|
+
from edsl.language_models.model import Model
|