edsl 0.1.39__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +0 -28
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +17 -9
- edsl/agents/Invigilator.py +14 -13
- edsl/agents/InvigilatorBase.py +1 -4
- edsl/agents/PromptConstructor.py +22 -42
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/coop/coop.py +5 -21
- edsl/data/Cache.py +18 -29
- edsl/data/CacheHandler.py +2 -0
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/enums.py +0 -7
- edsl/inference_services/AnthropicService.py +16 -38
- edsl/inference_services/AvailableModelFetcher.py +1 -7
- edsl/inference_services/GoogleService.py +1 -5
- edsl/inference_services/InferenceServicesCollection.py +2 -18
- edsl/inference_services/OpenAIService.py +31 -46
- edsl/inference_services/TestService.py +3 -1
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/inference_services/data_structures.py +2 -74
- edsl/jobs/AnswerQuestionFunctionConstructor.py +113 -148
- edsl/jobs/FetchInvigilator.py +3 -10
- edsl/jobs/InterviewsConstructor.py +4 -6
- edsl/jobs/Jobs.py +233 -299
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +136 -160
- edsl/jobs/interviews/Interview.py +42 -80
- edsl/jobs/runners/JobsRunnerAsyncio.py +358 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/TaskHistory.py +3 -24
- edsl/language_models/LanguageModel.py +4 -59
- edsl/language_models/ModelList.py +8 -19
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/registry.py +180 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +26 -35
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +49 -52
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +15 -9
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +18 -6
- edsl/questions/{response_validator_factory.py → ResponseValidatorFactory.py} +1 -7
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/results/DatasetExportMixin.py +119 -60
- edsl/results/Result.py +3 -109
- edsl/results/Results.py +39 -50
- edsl/scenarios/FileStore.py +0 -32
- edsl/scenarios/ScenarioList.py +7 -35
- edsl/scenarios/handlers/csv.py +0 -11
- edsl/surveys/Survey.py +20 -71
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/RECORD +78 -84
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +1 -1
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/model.py +0 -256
- edsl/questions/data_structures.py +0 -20
- edsl/results/file_exports.py +0 -252
- /edsl/agents/{question_option_processor.py → QuestionOptionProcessor.py} +0 -0
- /edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +0 -0
- /edsl/questions/{loop_processor.py → LoopProcessor.py} +0 -0
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- /edsl/results/{results_selector.py → Selector.py} +0 -0
- /edsl/scenarios/{directory_scanner.py → DirectoryScanner.py} +0 -0
- /edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +0 -0
- /edsl/scenarios/{scenario_selector.py → ScenarioSelector.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
@@ -3,12 +3,21 @@ from __future__ import annotations
|
|
3
3
|
import os
|
4
4
|
import time
|
5
5
|
import requests
|
6
|
+
import warnings
|
6
7
|
from abc import ABC, abstractmethod
|
7
8
|
from dataclasses import dataclass
|
9
|
+
|
10
|
+
from typing import Any, List, DefaultDict, Optional, Dict
|
8
11
|
from collections import defaultdict
|
9
|
-
from typing import Any, Dict, Optional
|
10
12
|
from uuid import UUID
|
11
13
|
|
14
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
15
|
+
|
16
|
+
InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
|
17
|
+
|
18
|
+
from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
|
19
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
20
|
+
|
12
21
|
|
13
22
|
@dataclass
|
14
23
|
class ModelInfo:
|
@@ -19,44 +28,11 @@ class ModelInfo:
|
|
19
28
|
token_usage_info: dict
|
20
29
|
|
21
30
|
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
self.completed_by_model = defaultdict(int)
|
28
|
-
self.distinct_models = distinct_models
|
29
|
-
self.total_exceptions = 0
|
30
|
-
self.unfixed_exceptions = 0
|
31
|
-
|
32
|
-
def add_completed_interview(
|
33
|
-
self, model: str, num_exceptions: int = 0, num_unfixed: int = 0
|
34
|
-
):
|
35
|
-
self.completed_count += 1
|
36
|
-
self.completed_by_model[model] += 1
|
37
|
-
self.total_exceptions += num_exceptions
|
38
|
-
self.unfixed_exceptions += num_unfixed
|
39
|
-
|
40
|
-
def get_elapsed_time(self) -> float:
|
41
|
-
return time.time() - self.start_time
|
42
|
-
|
43
|
-
def get_average_time_per_interview(self) -> float:
|
44
|
-
return (
|
45
|
-
self.get_elapsed_time() / self.completed_count
|
46
|
-
if self.completed_count > 0
|
47
|
-
else 0
|
48
|
-
)
|
49
|
-
|
50
|
-
def get_throughput(self) -> float:
|
51
|
-
elapsed = self.get_elapsed_time()
|
52
|
-
return self.completed_count / elapsed if elapsed > 0 else 0
|
53
|
-
|
54
|
-
def get_estimated_time_remaining(self) -> float:
|
55
|
-
if self.completed_count == 0:
|
56
|
-
return 0
|
57
|
-
avg_time = self.get_average_time_per_interview()
|
58
|
-
remaining = self.total_interviews - self.completed_count
|
59
|
-
return avg_time * remaining
|
31
|
+
@dataclass
|
32
|
+
class ModelTokenUsageStats:
|
33
|
+
token_usage_type: str
|
34
|
+
details: List[dict]
|
35
|
+
cost: str
|
60
36
|
|
61
37
|
|
62
38
|
class JobsRunnerStatusBase(ABC):
|
@@ -70,39 +46,48 @@ class JobsRunnerStatusBase(ABC):
|
|
70
46
|
api_key: str = None,
|
71
47
|
):
|
72
48
|
self.jobs_runner = jobs_runner
|
49
|
+
|
50
|
+
# The uuid of the job on Coop
|
73
51
|
self.job_uuid = job_uuid
|
52
|
+
|
74
53
|
self.base_url = f"{endpoint_url}"
|
54
|
+
|
55
|
+
self.start_time = time.time()
|
56
|
+
self.completed_interviews = []
|
75
57
|
self.refresh_rate = refresh_rate
|
76
58
|
self.statistics = [
|
77
59
|
"elapsed_time",
|
78
60
|
"total_interviews_requested",
|
79
61
|
"completed_interviews",
|
62
|
+
# "percent_complete",
|
80
63
|
"average_time_per_interview",
|
64
|
+
# "task_remaining",
|
81
65
|
"estimated_time_remaining",
|
82
66
|
"exceptions",
|
83
67
|
"unfixed_exceptions",
|
84
68
|
"throughput",
|
85
69
|
]
|
86
|
-
self.num_total_interviews = n * len(self.jobs_runner)
|
70
|
+
self.num_total_interviews = n * len(self.jobs_runner.interviews)
|
87
71
|
|
88
72
|
self.distinct_models = list(
|
89
|
-
set(model.model for
|
73
|
+
set(i.model.model for i in self.jobs_runner.interviews)
|
90
74
|
)
|
91
75
|
|
92
|
-
self.
|
93
|
-
total_interviews=self.num_total_interviews,
|
94
|
-
distinct_models=self.distinct_models,
|
95
|
-
)
|
76
|
+
self.completed_interview_by_model = defaultdict(list)
|
96
77
|
|
97
78
|
self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
|
98
79
|
|
99
80
|
@abstractmethod
|
100
81
|
def has_ep_api_key(self):
|
101
|
-
"""
|
82
|
+
"""
|
83
|
+
Checks if the user has an Expected Parrot API key.
|
84
|
+
"""
|
102
85
|
pass
|
103
86
|
|
104
87
|
def get_status_dict(self) -> Dict[str, Any]:
|
105
|
-
"""
|
88
|
+
"""
|
89
|
+
Converts current status into a JSON-serializable dictionary.
|
90
|
+
"""
|
106
91
|
# Get all statistics
|
107
92
|
stats = {}
|
108
93
|
for stat_name in self.statistics:
|
@@ -110,46 +95,42 @@ class JobsRunnerStatusBase(ABC):
|
|
110
95
|
name, value = list(stat.items())[0]
|
111
96
|
stats[name] = value
|
112
97
|
|
98
|
+
# Calculate overall progress
|
99
|
+
total_interviews = len(self.jobs_runner.total_interviews)
|
100
|
+
completed = len(self.completed_interviews)
|
101
|
+
|
113
102
|
# Get model-specific progress
|
114
103
|
model_progress = {}
|
115
|
-
target_per_model = int(self.num_total_interviews / len(self.distinct_models))
|
116
|
-
|
117
104
|
for model in self.distinct_models:
|
118
|
-
|
105
|
+
completed_for_model = len(self.completed_interview_by_model[model])
|
106
|
+
target_for_model = int(
|
107
|
+
self.num_total_interviews / len(self.distinct_models)
|
108
|
+
)
|
119
109
|
model_progress[model] = {
|
120
|
-
"completed":
|
121
|
-
"total":
|
110
|
+
"completed": completed_for_model,
|
111
|
+
"total": target_for_model,
|
122
112
|
"percent": (
|
123
|
-
(
|
113
|
+
(completed_for_model / target_for_model * 100)
|
114
|
+
if target_for_model > 0
|
115
|
+
else 0
|
124
116
|
),
|
125
117
|
}
|
126
118
|
|
127
119
|
status_dict = {
|
128
120
|
"overall_progress": {
|
129
|
-
"completed":
|
130
|
-
"total":
|
121
|
+
"completed": completed,
|
122
|
+
"total": total_interviews,
|
131
123
|
"percent": (
|
132
|
-
(
|
133
|
-
self.stats_tracker.completed_count
|
134
|
-
/ self.num_total_interviews
|
135
|
-
* 100
|
136
|
-
)
|
137
|
-
if self.num_total_interviews > 0
|
138
|
-
else 0
|
124
|
+
(completed / total_interviews * 100) if total_interviews > 0 else 0
|
139
125
|
),
|
140
126
|
},
|
141
127
|
"language_model_progress": model_progress,
|
142
128
|
"statistics": stats,
|
143
|
-
"status":
|
144
|
-
"completed"
|
145
|
-
if self.stats_tracker.completed_count >= self.num_total_interviews
|
146
|
-
else "running"
|
147
|
-
),
|
129
|
+
"status": "completed" if completed >= total_interviews else "running",
|
148
130
|
}
|
149
131
|
|
150
132
|
model_queues = {}
|
151
|
-
|
152
|
-
for model, bucket in self.jobs_runner.environment.bucket_collection.items():
|
133
|
+
for model, bucket in self.jobs_runner.bucket_collection.items():
|
153
134
|
model_name = model.model
|
154
135
|
model_queues[model_name] = {
|
155
136
|
"language_model_name": model_name,
|
@@ -171,67 +152,98 @@ class JobsRunnerStatusBase(ABC):
|
|
171
152
|
status_dict["language_model_queues"] = model_queues
|
172
153
|
return status_dict
|
173
154
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
num_exceptions=(
|
179
|
-
len(result.exceptions) if hasattr(result, "exceptions") else 0
|
180
|
-
),
|
181
|
-
num_unfixed=(
|
182
|
-
result.exceptions.num_unfixed() if hasattr(result, "exceptions") else 0
|
183
|
-
),
|
184
|
-
)
|
185
|
-
|
186
|
-
def _compute_statistic(self, stat_name: str):
|
187
|
-
"""Computes individual statistics based on the stats tracker."""
|
188
|
-
if stat_name == "elapsed_time":
|
189
|
-
value = self.stats_tracker.get_elapsed_time()
|
190
|
-
return {"elapsed_time": (value, 1, "sec.")}
|
191
|
-
|
192
|
-
elif stat_name == "total_interviews_requested":
|
193
|
-
return {"total_interviews_requested": (self.num_total_interviews, None, "")}
|
155
|
+
@abstractmethod
|
156
|
+
def setup(self):
|
157
|
+
"""
|
158
|
+
Conducts any setup that needs to happen prior to sending status updates.
|
194
159
|
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
}
|
160
|
+
Ex. For a local job, creates a job in the Coop database.
|
161
|
+
"""
|
162
|
+
pass
|
199
163
|
|
200
|
-
|
201
|
-
|
202
|
-
|
164
|
+
@abstractmethod
|
165
|
+
def send_status_update(self):
|
166
|
+
"""
|
167
|
+
Updates the current status of the job.
|
168
|
+
"""
|
169
|
+
pass
|
203
170
|
|
204
|
-
|
205
|
-
|
206
|
-
return {"estimated_time_remaining": (value, 1, "sec.")}
|
171
|
+
def add_completed_interview(self, result):
|
172
|
+
self.completed_interviews.append(result.interview_hash)
|
207
173
|
|
208
|
-
|
209
|
-
|
174
|
+
relevant_model = result.model.model
|
175
|
+
self.completed_interview_by_model[relevant_model].append(result.interview_hash)
|
210
176
|
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
177
|
+
def _compute_statistic(self, stat_name: str):
|
178
|
+
completed_tasks = self.completed_interviews
|
179
|
+
elapsed_time = time.time() - self.start_time
|
180
|
+
interviews = self.jobs_runner.total_interviews
|
215
181
|
|
216
|
-
|
217
|
-
|
218
|
-
|
182
|
+
stat_definitions = {
|
183
|
+
"elapsed_time": lambda: InterviewStatistic(
|
184
|
+
"elapsed_time", value=elapsed_time, digits=1, units="sec."
|
185
|
+
),
|
186
|
+
"total_interviews_requested": lambda: InterviewStatistic(
|
187
|
+
"total_interviews_requested", value=len(interviews), units=""
|
188
|
+
),
|
189
|
+
"completed_interviews": lambda: InterviewStatistic(
|
190
|
+
"completed_interviews", value=len(completed_tasks), units=""
|
191
|
+
),
|
192
|
+
"percent_complete": lambda: InterviewStatistic(
|
193
|
+
"percent_complete",
|
194
|
+
value=(
|
195
|
+
len(completed_tasks) / len(interviews) * 100
|
196
|
+
if len(interviews) > 0
|
197
|
+
else 0
|
198
|
+
),
|
199
|
+
digits=1,
|
200
|
+
units="%",
|
201
|
+
),
|
202
|
+
"average_time_per_interview": lambda: InterviewStatistic(
|
203
|
+
"average_time_per_interview",
|
204
|
+
value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
|
205
|
+
digits=2,
|
206
|
+
units="sec.",
|
207
|
+
),
|
208
|
+
"task_remaining": lambda: InterviewStatistic(
|
209
|
+
"task_remaining", value=len(interviews) - len(completed_tasks), units=""
|
210
|
+
),
|
211
|
+
"estimated_time_remaining": lambda: InterviewStatistic(
|
212
|
+
"estimated_time_remaining",
|
213
|
+
value=(
|
214
|
+
(len(interviews) - len(completed_tasks))
|
215
|
+
* (elapsed_time / len(completed_tasks))
|
216
|
+
if len(completed_tasks) > 0
|
217
|
+
else 0
|
218
|
+
),
|
219
|
+
digits=1,
|
220
|
+
units="sec.",
|
221
|
+
),
|
222
|
+
"exceptions": lambda: InterviewStatistic(
|
223
|
+
"exceptions",
|
224
|
+
value=sum(len(i.exceptions) for i in interviews),
|
225
|
+
units="",
|
226
|
+
),
|
227
|
+
"unfixed_exceptions": lambda: InterviewStatistic(
|
228
|
+
"unfixed_exceptions",
|
229
|
+
value=sum(i.exceptions.num_unfixed() for i in interviews),
|
230
|
+
units="",
|
231
|
+
),
|
232
|
+
"throughput": lambda: InterviewStatistic(
|
233
|
+
"throughput",
|
234
|
+
value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
|
235
|
+
digits=2,
|
236
|
+
units="interviews/sec.",
|
237
|
+
),
|
238
|
+
}
|
239
|
+
return stat_definitions[stat_name]()
|
219
240
|
|
220
241
|
def update_progress(self, stop_event):
|
221
242
|
while not stop_event.is_set():
|
222
243
|
self.send_status_update()
|
223
244
|
time.sleep(self.refresh_rate)
|
224
|
-
self.send_status_update()
|
225
245
|
|
226
|
-
|
227
|
-
def setup(self):
|
228
|
-
"""Conducts any setup needed prior to sending status updates."""
|
229
|
-
pass
|
230
|
-
|
231
|
-
@abstractmethod
|
232
|
-
def send_status_update(self):
|
233
|
-
"""Updates the current status of the job."""
|
234
|
-
pass
|
246
|
+
self.send_status_update()
|
235
247
|
|
236
248
|
|
237
249
|
class JobsRunnerStatus(JobsRunnerStatusBase):
|
@@ -248,35 +260,49 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
|
|
248
260
|
return f"{self.base_url}/api/v0/local-job/{str(self.job_uuid)}"
|
249
261
|
|
250
262
|
def setup(self) -> None:
|
251
|
-
"""
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
}
|
263
|
+
"""
|
264
|
+
Creates a local job on Coop if one does not already exist.
|
265
|
+
"""
|
266
|
+
|
267
|
+
headers = {"Content-Type": "application/json"}
|
268
|
+
|
269
|
+
if self.api_key:
|
270
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
271
|
+
else:
|
272
|
+
headers["Authorization"] = f"Bearer None"
|
256
273
|
|
257
274
|
if self.job_uuid is None:
|
275
|
+
# Create a new local job
|
258
276
|
response = requests.post(
|
259
277
|
self.create_url,
|
260
278
|
headers=headers,
|
261
279
|
timeout=1,
|
262
280
|
)
|
263
|
-
|
264
|
-
|
265
|
-
|
281
|
+
response.raise_for_status()
|
282
|
+
data = response.json()
|
283
|
+
self.job_uuid = data.get("job_uuid")
|
266
284
|
|
267
285
|
print(f"Running with progress bar. View progress at {self.viewing_url}")
|
268
286
|
|
269
287
|
def send_status_update(self) -> None:
|
270
|
-
"""
|
288
|
+
"""
|
289
|
+
Sends current status to the web endpoint using the instance's job_uuid.
|
290
|
+
"""
|
271
291
|
try:
|
292
|
+
# Get the status dictionary and add the job_id
|
272
293
|
status_dict = self.get_status_dict()
|
294
|
+
|
295
|
+
# Make the UUID JSON serializable
|
273
296
|
status_dict["job_id"] = str(self.job_uuid)
|
274
297
|
|
275
|
-
headers = {
|
276
|
-
|
277
|
-
|
278
|
-
|
298
|
+
headers = {"Content-Type": "application/json"}
|
299
|
+
|
300
|
+
if self.api_key:
|
301
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
302
|
+
else:
|
303
|
+
headers["Authorization"] = f"Bearer None"
|
279
304
|
|
305
|
+
# Send the update
|
280
306
|
response = requests.patch(
|
281
307
|
self.update_url,
|
282
308
|
json=status_dict,
|
@@ -288,8 +314,14 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
|
|
288
314
|
print(f"Failed to send status update for job {self.job_uuid}: {e}")
|
289
315
|
|
290
316
|
def has_ep_api_key(self) -> bool:
|
291
|
-
"""
|
292
|
-
|
317
|
+
"""
|
318
|
+
Returns True if the user has an Expected Parrot API key. Otherwise, returns False.
|
319
|
+
"""
|
320
|
+
|
321
|
+
if self.api_key is not None:
|
322
|
+
return True
|
323
|
+
else:
|
324
|
+
return False
|
293
325
|
|
294
326
|
|
295
327
|
if __name__ == "__main__":
|
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -8,10 +8,9 @@ from edsl.Base import RepresentationMixin
|
|
8
8
|
class TaskHistory(RepresentationMixin):
|
9
9
|
def __init__(
|
10
10
|
self,
|
11
|
-
interviews: List["Interview"]
|
11
|
+
interviews: List["Interview"],
|
12
12
|
include_traceback: bool = False,
|
13
13
|
max_interviews: int = 10,
|
14
|
-
interviews_with_exceptions_only: bool = False,
|
15
14
|
):
|
16
15
|
"""
|
17
16
|
The structure of a TaskHistory exception
|
@@ -21,33 +20,13 @@ class TaskHistory(RepresentationMixin):
|
|
21
20
|
>>> _ = TaskHistory.example()
|
22
21
|
...
|
23
22
|
"""
|
24
|
-
self.interviews_with_exceptions_only = interviews_with_exceptions_only
|
25
|
-
self._interviews = {}
|
26
|
-
self.total_interviews = []
|
27
|
-
if interviews is not None:
|
28
|
-
for interview in interviews:
|
29
|
-
self.add_interview(interview)
|
30
23
|
|
31
|
-
self.
|
32
|
-
self._interviews = {
|
33
|
-
index: interview for index, interview in enumerate(self.total_interviews)
|
34
|
-
}
|
35
|
-
self.max_interviews = max_interviews
|
36
|
-
|
37
|
-
# self.total_interviews = interviews
|
24
|
+
self.total_interviews = interviews
|
38
25
|
self.include_traceback = include_traceback
|
39
26
|
|
40
|
-
|
27
|
+
self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
|
41
28
|
self.max_interviews = max_interviews
|
42
29
|
|
43
|
-
def add_interview(self, interview: "Interview"):
|
44
|
-
"""Add a single interview to the history"""
|
45
|
-
if self.interviews_with_exceptions_only and interview.exceptions == {}:
|
46
|
-
return
|
47
|
-
|
48
|
-
self.total_interviews.append(interview)
|
49
|
-
self._interviews[len(self._interviews)] = interview
|
50
|
-
|
51
30
|
@classmethod
|
52
31
|
def example(cls):
|
53
32
|
""" """
|
@@ -44,8 +44,6 @@ if TYPE_CHECKING:
|
|
44
44
|
from edsl.questions.QuestionBase import QuestionBase
|
45
45
|
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
46
46
|
|
47
|
-
from edsl.enums import InferenceServiceType
|
48
|
-
|
49
47
|
from edsl.utilities.decorators import (
|
50
48
|
sync_wrapper,
|
51
49
|
jupyter_nb_handler,
|
@@ -157,9 +155,7 @@ class LanguageModel(
|
|
157
155
|
return klc.get(("config", "env"))
|
158
156
|
|
159
157
|
def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
|
160
|
-
|
161
|
-
if hasattr(self, "_api_token"):
|
162
|
-
del self._api_token
|
158
|
+
del self._api_token
|
163
159
|
self.key_lookup = key_lookup
|
164
160
|
|
165
161
|
def ask_question(self, question: "QuestionBase") -> str:
|
@@ -497,10 +493,7 @@ class LanguageModel(
|
|
497
493
|
>>> m.to_dict()
|
498
494
|
{'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
|
499
495
|
"""
|
500
|
-
d = {
|
501
|
-
"model": self.model,
|
502
|
-
"parameters": self.parameters,
|
503
|
-
}
|
496
|
+
d = {"model": self.model, "parameters": self.parameters}
|
504
497
|
if add_edsl_version:
|
505
498
|
from edsl import __version__
|
506
499
|
|
@@ -512,7 +505,7 @@ class LanguageModel(
|
|
512
505
|
@remove_edsl_version
|
513
506
|
def from_dict(cls, data: dict) -> Type[LanguageModel]:
|
514
507
|
"""Convert dictionary to a LanguageModel child instance."""
|
515
|
-
from edsl.language_models.
|
508
|
+
from edsl.language_models.registry import get_model_class
|
516
509
|
|
517
510
|
model_class = get_model_class(data["model"])
|
518
511
|
return model_class(**data)
|
@@ -560,7 +553,7 @@ class LanguageModel(
|
|
560
553
|
Exception report saved to ...
|
561
554
|
Also see: ...
|
562
555
|
"""
|
563
|
-
from edsl.language_models.
|
556
|
+
from edsl.language_models.registry import Model
|
564
557
|
|
565
558
|
if test_model:
|
566
559
|
m = Model(
|
@@ -570,54 +563,6 @@ class LanguageModel(
|
|
570
563
|
else:
|
571
564
|
return Model(skip_api_key_check=True)
|
572
565
|
|
573
|
-
def from_cache(self, cache: "Cache") -> LanguageModel:
|
574
|
-
|
575
|
-
from copy import deepcopy
|
576
|
-
from types import MethodType
|
577
|
-
from edsl import Cache
|
578
|
-
|
579
|
-
new_instance = deepcopy(self)
|
580
|
-
print("Cache entries", len(cache))
|
581
|
-
new_instance.cache = Cache(
|
582
|
-
data={k: v for k, v in cache.items() if v.model == self.model}
|
583
|
-
)
|
584
|
-
print("Cache entries with same model", len(new_instance.cache))
|
585
|
-
|
586
|
-
new_instance.user_prompts = [
|
587
|
-
ce.user_prompt for ce in new_instance.cache.values()
|
588
|
-
]
|
589
|
-
new_instance.system_prompts = [
|
590
|
-
ce.system_prompt for ce in new_instance.cache.values()
|
591
|
-
]
|
592
|
-
|
593
|
-
async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
|
594
|
-
cache_call_params = {
|
595
|
-
"model": str(self.model),
|
596
|
-
"parameters": self.parameters,
|
597
|
-
"system_prompt": system_prompt,
|
598
|
-
"user_prompt": user_prompt,
|
599
|
-
"iteration": 1,
|
600
|
-
}
|
601
|
-
cached_response, cache_key = cache.fetch(**cache_call_params)
|
602
|
-
response = json.loads(cached_response)
|
603
|
-
cost = 0
|
604
|
-
return ModelResponse(
|
605
|
-
response=response,
|
606
|
-
cache_used=True,
|
607
|
-
cache_key=cache_key,
|
608
|
-
cached_response=cached_response,
|
609
|
-
cost=cost,
|
610
|
-
)
|
611
|
-
|
612
|
-
# Bind the new method to the copied instance
|
613
|
-
setattr(
|
614
|
-
new_instance,
|
615
|
-
"async_execute_model_call",
|
616
|
-
MethodType(async_execute_model_call, new_instance),
|
617
|
-
)
|
618
|
-
|
619
|
-
return new_instance
|
620
|
-
|
621
566
|
|
622
567
|
if __name__ == "__main__":
|
623
568
|
"""Run the module's test suite."""
|
@@ -1,22 +1,18 @@
|
|
1
|
-
from typing import Optional, List
|
1
|
+
from typing import Optional, List
|
2
2
|
from collections import UserList
|
3
3
|
|
4
4
|
from edsl.Base import Base
|
5
|
-
from edsl.language_models.
|
5
|
+
from edsl.language_models.registry import Model
|
6
6
|
|
7
|
-
#
|
7
|
+
# from edsl.language_models import LanguageModel
|
8
8
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
9
9
|
from edsl.utilities.is_valid_variable_name import is_valid_variable_name
|
10
10
|
|
11
|
-
if TYPE_CHECKING:
|
12
|
-
from edsl.inference_services.data_structures import AvailableModels
|
13
|
-
from edsl.language_models import LanguageModel
|
14
|
-
|
15
11
|
|
16
12
|
class ModelList(Base, UserList):
|
17
13
|
__documentation__ = """https://docs.expectedparrot.com/en/latest/language_models.html#module-edsl.language_models.ModelList"""
|
18
14
|
|
19
|
-
def __init__(self, data: Optional[
|
15
|
+
def __init__(self, data: Optional[list] = None):
|
20
16
|
"""Initialize the ScenarioList class.
|
21
17
|
|
22
18
|
>>> from edsl import Model
|
@@ -37,6 +33,9 @@ class ModelList(Base, UserList):
|
|
37
33
|
"""
|
38
34
|
return set([model.model for model in self])
|
39
35
|
|
36
|
+
def rich_print(self):
|
37
|
+
pass
|
38
|
+
|
40
39
|
def __repr__(self):
|
41
40
|
return f"ModelList({super().__repr__()})"
|
42
41
|
|
@@ -88,7 +87,7 @@ class ModelList(Base, UserList):
|
|
88
87
|
.table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
|
89
88
|
)
|
90
89
|
|
91
|
-
def to_list(self)
|
90
|
+
def to_list(self):
|
92
91
|
return self.to_scenario_list().to_list()
|
93
92
|
|
94
93
|
def to_dict(self, sort=False, add_edsl_version=True):
|
@@ -121,16 +120,6 @@ class ModelList(Base, UserList):
|
|
121
120
|
args = args[0]
|
122
121
|
return ModelList([Model(model_name, **kwargs) for model_name in args])
|
123
122
|
|
124
|
-
@classmethod
|
125
|
-
def from_available_models(self, available_models_list: "AvailableModels"):
|
126
|
-
"""Create a ModelList from an AvailableModels object"""
|
127
|
-
return ModelList(
|
128
|
-
[
|
129
|
-
Model(model.model_name, service_name=model.service_name)
|
130
|
-
for model in available_models_list
|
131
|
-
]
|
132
|
-
)
|
133
|
-
|
134
123
|
@classmethod
|
135
124
|
@remove_edsl_version
|
136
125
|
def from_dict(cls, data):
|
edsl/language_models/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
1
|
from edsl.language_models.LanguageModel import LanguageModel
|
2
|
-
from edsl.language_models.
|
2
|
+
from edsl.language_models.registry import Model
|