edsl 0.1.27.dev2__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +99 -22
- edsl/BaseDiff.py +260 -0
- edsl/__init__.py +4 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +26 -5
- edsl/agents/AgentList.py +62 -7
- edsl/agents/Invigilator.py +4 -9
- edsl/agents/InvigilatorBase.py +5 -5
- edsl/agents/descriptors.py +3 -1
- edsl/conjure/AgentConstructionMixin.py +152 -0
- edsl/conjure/Conjure.py +56 -0
- edsl/conjure/InputData.py +628 -0
- edsl/conjure/InputDataCSV.py +48 -0
- edsl/conjure/InputDataMixinQuestionStats.py +182 -0
- edsl/conjure/InputDataPyRead.py +91 -0
- edsl/conjure/InputDataSPSS.py +8 -0
- edsl/conjure/InputDataStata.py +8 -0
- edsl/conjure/QuestionOptionMixin.py +76 -0
- edsl/conjure/QuestionTypeMixin.py +23 -0
- edsl/conjure/RawQuestion.py +65 -0
- edsl/conjure/SurveyResponses.py +7 -0
- edsl/conjure/__init__.py +9 -4
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/naming_utilities.py +263 -0
- edsl/conjure/utilities.py +165 -28
- edsl/conversation/Conversation.py +238 -0
- edsl/conversation/car_buying.py +58 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/coop.py +191 -12
- edsl/coop/utils.py +20 -2
- edsl/data/Cache.py +55 -17
- edsl/data/CacheHandler.py +10 -9
- edsl/inference_services/AnthropicService.py +1 -0
- edsl/inference_services/DeepInfraService.py +20 -13
- edsl/inference_services/GoogleService.py +7 -1
- edsl/inference_services/InferenceServicesCollection.py +33 -7
- edsl/inference_services/OpenAIService.py +17 -10
- edsl/inference_services/models_available_cache.py +69 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/Jobs.py +240 -36
- edsl/jobs/buckets/BucketCollection.py +9 -3
- edsl/jobs/interviews/Interview.py +4 -1
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +24 -10
- edsl/jobs/interviews/retry_management.py +4 -4
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -45
- edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
- edsl/language_models/LanguageModel.py +37 -44
- edsl/language_models/ModelList.py +96 -0
- edsl/language_models/registry.py +14 -0
- edsl/language_models/repair.py +95 -24
- edsl/notebooks/Notebook.py +119 -31
- edsl/questions/QuestionBase.py +109 -12
- edsl/questions/descriptors.py +5 -2
- edsl/questions/question_registry.py +7 -0
- edsl/results/Result.py +20 -8
- edsl/results/Results.py +85 -11
- edsl/results/ResultsDBMixin.py +3 -6
- edsl/results/ResultsExportMixin.py +47 -16
- edsl/results/ResultsToolsMixin.py +5 -5
- edsl/scenarios/Scenario.py +59 -5
- edsl/scenarios/ScenarioList.py +97 -40
- edsl/study/ObjectEntry.py +97 -0
- edsl/study/ProofOfWork.py +110 -0
- edsl/study/SnapShot.py +77 -0
- edsl/study/Study.py +491 -0
- edsl/study/__init__.py +2 -0
- edsl/surveys/Survey.py +79 -31
- edsl/surveys/SurveyExportMixin.py +21 -3
- edsl/utilities/__init__.py +1 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/interface.py +24 -28
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/utilities.py +57 -2
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/METADATA +43 -17
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/RECORD +83 -55
- edsl-0.1.28.dist-info/entry_points.txt +3 -0
- edsl/conjure/RawResponseColumn.py +0 -327
- edsl/conjure/SurveyBuilder.py +0 -308
- edsl/conjure/SurveyBuilderCSV.py +0 -78
- edsl/conjure/SurveyBuilderSPSS.py +0 -118
- edsl/data/RemoteDict.py +0 -103
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/LICENSE +0 -0
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerAsyncio.py
CHANGED
@@ -14,7 +14,8 @@ from edsl.results import Results, Result
 
 from edsl.jobs.interviews.Interview import Interview
 from edsl.utilities.decorators import jupyter_nb_handler
-
+
+# from edsl.jobs.Jobs import Jobs
 from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
 from edsl.language_models import LanguageModel
 from edsl.data.Cache import Cache
@@ -22,6 +23,8 @@ from edsl.data.Cache import Cache
 from edsl.jobs.tasks.TaskHistory import TaskHistory
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 
+import time
+
 
 class JobsRunnerAsyncio(JobsRunnerStatusMixin):
     """A class for running a collection of interviews asynchronously.
@@ -32,12 +35,12 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
 
     def __init__(self, jobs: Jobs):
         self.jobs = jobs
-
+        # this creates the interviews, which can take a while
         self.interviews: List["Interview"] = jobs.interviews()
         self.bucket_collection: "BucketCollection" = jobs.bucket_collection
         self.total_interviews: List["Interview"] = []
 
-    async def
+    async def run_async_generator(
         self,
         cache: Cache,
         n: int = 1,
@@ -96,6 +99,20 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
             )  # set the cache for the first interview
             self.total_interviews.append(interview)
 
+    async def run_async(self, cache=None) -> Results:
+        if cache is None:
+            self.cache = Cache()
+        else:
+            self.cache = cache
+        data = []
+        async for result in self.run_async_generator(cache=self.cache):
+            data.append(result)
+        return Results(survey=self.jobs.survey, data=data)
+
+    def simple_run(self):
+        data = asyncio.run(self.run_async())
+        return Results(survey=self.jobs.survey, data=data)
+
     async def _build_interview_task(
         self,
         *,
@@ -187,39 +204,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
         self.cache = cache
         self.sidecar_model = sidecar_model
 
-
-
-
-        @contextmanager
-        def no_op_cm():
-            """A no-op context manager with a dummy update method."""
-            yield DummyLive()
-
-        class DummyLive:
-            def update(self, *args, **kwargs):
-                """A dummy update method that does nothing."""
-                pass
-
-        progress_bar_context = (
-            Live(generate_table(), console=console, refresh_per_second=5)
-            if progress_bar
-            else no_op_cm()
-        )
-
-        with cache as c:
-            with progress_bar_context as live:
-
-                async def update_progress_bar():
-                    """Updates the progress bar at fixed intervals."""
-                    while True:
-                        live.update(generate_table())
-                        await asyncio.sleep(0.1)  # Update interval
-                        if self.completed:
-                            break
+        if not progress_bar:
+            # print("Running without progress bar")
+            with cache as c:
 
                 async def process_results():
                     """Processes results from interviews."""
-                    async for result in self.
+                    async for result in self.run_async_generator(
                         n=n,
                         debug=debug,
                         stop_on_exception=stop_on_exception,
@@ -227,25 +218,74 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
                         sidecar_model=sidecar_model,
                     ):
                         self.results.append(result)
-                        live.update(generate_table())
                     self.completed = True
 
-
+                await asyncio.gather(process_results())
+
+            results = Results(survey=self.jobs.survey, data=self.results)
+        else:
+            # print("Running with progress bar")
+
+            def generate_table():
+                return self.status_table(self.results, self.elapsed_time)
 
-
-
-
+            @contextmanager
+            def no_op_cm():
+                """A no-op context manager with a dummy update method."""
+                yield DummyLive()
+
+            class DummyLive:
+                def update(self, *args, **kwargs):
+                    """A dummy update method that does nothing."""
                     pass
-                finally:
-                    progress_task.cancel()  # Cancel the progress_task when process_results is done
-                    await progress_task
 
-
+            progress_bar_context = (
+                Live(generate_table(), console=console, refresh_per_second=5)
+                if progress_bar
+                else no_op_cm()
+            )
 
-
-
+            with cache as c:
+                with progress_bar_context as live:
+
+                    async def update_progress_bar():
+                        """Updates the progress bar at fixed intervals."""
+                        while True:
+                            live.update(generate_table())
+                            await asyncio.sleep(0.00001)  # Update interval
+                            if self.completed:
+                                break
+
+                    async def process_results():
+                        """Processes results from interviews."""
+                        async for result in self.run_async_generator(
+                            n=n,
+                            debug=debug,
+                            stop_on_exception=stop_on_exception,
+                            cache=c,
+                            sidecar_model=sidecar_model,
+                        ):
+                            self.results.append(result)
+                            live.update(generate_table())
+                        self.completed = True
+
+                    progress_task = asyncio.create_task(update_progress_bar())
+
+                    try:
+                        await asyncio.gather(process_results(), progress_task)
+                    except asyncio.CancelledError:
+                        pass
+                    finally:
+                        progress_task.cancel()  # Cancel the progress_task when process_results is done
+                        await progress_task
+
+                    await asyncio.sleep(1)  # short delay to show the final status
+
+                    # one more update
+                    live.update(generate_table())
+
+            results = Results(survey=self.jobs.survey, data=self.results)
 
-        results = Results(survey=self.jobs.survey, data=self.results)
         task_history = TaskHistory(self.total_interviews, include_traceback=False)
         results.task_history = task_history
 
@@ -259,6 +299,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
             for interview in self.total_interviews
             if interview.has_exceptions
         ]
+        from edsl.jobs.Jobs import Jobs
+
         results.failed_jobs = Jobs.from_interviews(
             [interview for interview in failed_interviews]
         )
edsl/jobs/runners/JobsRunnerStatusData.py
CHANGED
@@ -81,7 +81,7 @@ class JobsRunnerStatusData:
         >>> completed_tasks = []
         >>> elapsed_time = 0
         >>> JobsRunnerStatusData().generate_status_summary(completed_tasks, elapsed_time, interviews)
-        {'Elapsed time': '0.0 sec.', 'Total interviews requested': '1 ', 'Completed interviews': '0 ', 'Percent complete': '0 %', 'Average time per interview': 'NA', 'Task remaining': '1 ', 'Estimated time remaining': 'NA', 'model_queues': [{'model_name': '
+        {'Elapsed time': '0.0 sec.', 'Total interviews requested': '1 ', 'Completed interviews': '0 ', 'Percent complete': '0 %', 'Average time per interview': 'NA', 'Task remaining': '1 ', 'Estimated time remaining': 'NA', 'model_queues': [{'model_name': '...', 'TPM_limit_k': ..., 'RPM_limit_k': ..., 'num_tasks_waiting': 0, 'token_usage_info': [{'cache_status': 'new_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}, {'cache_status': 'cached_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}]}]}
         """
 
         models_to_tokens = defaultdict(InterviewTokenUsage)
@@ -176,7 +176,7 @@ class JobsRunnerStatusData:
         >>> model = interviews[0].model
         >>> num_waiting = 0
         >>> JobsRunnerStatusData()._get_model_info(model, num_waiting, models_to_tokens)
-        {'model_name': 'gpt-4-1106-preview', 'TPM_limit_k':
+        {'model_name': 'gpt-4-1106-preview', 'TPM_limit_k': ..., 'RPM_limit_k': ..., 'num_tasks_waiting': 0, 'token_usage_info': [{'cache_status': 'new_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}, {'cache_status': 'cached_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}]}
         """
 
         prices = get_token_pricing(model.model)
@@ -234,4 +234,4 @@ class JobsRunnerStatusData:
 if __name__ == "__main__":
     import doctest
 
-    doctest.testmod()
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
edsl/jobs/tasks/QuestionTaskCreator.py
CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 from typing import Callable, Union, List
 from collections import UserList, UserDict
-
+import time
 
 from edsl.jobs.buckets import ModelBuckets
 from edsl.exceptions import InterviewErrorPriorTaskCanceled
@@ -132,7 +132,7 @@ class QuestionTaskCreator(UserList):
            self.waiting = True
            self.task_status = TaskStatus.WAITING_FOR_REQUEST_CAPACITY
 
-        await self.
+        await self.tokens_bucket.get_tokens(1)
 
        self.task_status = TaskStatus.API_CALL_IN_PROGRESS
        try:
@@ -151,6 +151,8 @@ class QuestionTaskCreator(UserList):
            self.requests_bucket.add_tokens(1)
            self.from_cache = True
 
+        _ = results.pop("cached_response", None)
+
        tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
 
        # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
edsl/language_models/LanguageModel.py
CHANGED
@@ -142,7 +142,7 @@ class LanguageModel(
     def has_valid_api_key(self) -> bool:
         """Check if the model has a valid API key.
 
-        >>> LanguageModel.example().has_valid_api_key()
+        >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
         True
 
         This method is used to check if the model has a valid API key.
@@ -159,7 +159,9 @@ class LanguageModel(
 
     def __hash__(self):
         """Allow the model to be used as a key in a dictionary."""
-
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self.to_dict())
 
     def __eq__(self, other):
         """Check is two models are the same.
@@ -207,8 +209,8 @@ class LanguageModel(
         """Model's tokens-per-minute limit.
 
         >>> m = LanguageModel.example()
-        >>> m.TPM
-
+        >>> m.TPM > 0
+        True
         """
         self._set_rate_limits()
         return self._safety_factor * self.__rate_limits["tpm"]
@@ -285,28 +287,6 @@ class LanguageModel(
         """
         raise NotImplementedError
 
-    def _update_response_with_tracking(
-        self, response: dict, start_time: int, cached_response=False, cache_key=None
-    ):
-        """Update the response with tracking information.
-
-        >>> m = LanguageModel.example()
-        >>> m._update_response_with_tracking(response={"response": "Hello"}, start_time=0, cached_response=False, cache_key=None)
-        {'response': 'Hello', 'elapsed_time': ..., 'timestamp': ..., 'cached_response': False, 'cache_key': None}
-
-
-        """
-        end_time = time.time()
-        response.update(
-            {
-                "elapsed_time": end_time - start_time,
-                "timestamp": end_time,
-                "cached_response": cached_response,
-                "cache_key": cache_key,
-            }
-        )
-        return response
-
     async def async_get_raw_response(
         self,
         user_prompt: str,
@@ -314,7 +294,7 @@ class LanguageModel(
         cache,
         iteration: int = 0,
         encoded_image=None,
-    ) -> dict
+    ) -> tuple[dict, bool, str]:
         """Handle caching of responses.
 
         :param user_prompt: The user's prompt.
@@ -322,8 +302,7 @@ class LanguageModel(
         :param iteration: The iteration number.
         :param cache: The cache to use.
 
-        If the cache isn't being used, it just returns a 'fresh' call to the LLM
-        but appends some tracking information to the response (using the _update_response_with_tracking method).
+        If the cache isn't being used, it just returns a 'fresh' call to the LLM.
         But if cache is being used, it first checks the database to see if the response is already there.
         If it is, it returns the cached response, but again appends some tracking information.
         If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
@@ -334,7 +313,7 @@ class LanguageModel(
         >>> from edsl import Cache
         >>> m = LanguageModel.example(test_model = True)
         >>> m.get_raw_response(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-        {'message': '{"answer": "Hello world"}',
+        ({'message': '{"answer": "Hello world"}'}, False, '24ff6ac2bc2f1729f817f261e0792577')
         """
         start_time = time.time()
 
@@ -379,12 +358,7 @@ class LanguageModel(
             )
             cache_used = False
 
-        return
-            response=response,
-            start_time=start_time,
-            cached_response=cache_used,
-            cache_key=cache_key,
-        )
+        return response, cache_used, cache_key
 
     get_raw_response = sync_wrapper(async_get_raw_response)
 
@@ -427,14 +401,18 @@ class LanguageModel(
         if encoded_image:
             params["encoded_image"] = encoded_image
 
-        raw_response = await self.async_get_raw_response(
+        raw_response, cache_used, cache_key = await self.async_get_raw_response(
+            **params
+        )
         response = self.parse_response(raw_response)
 
         try:
             dict_response = json.loads(response)
         except json.JSONDecodeError as e:
             # TODO: Turn into logs to generate issues
-            dict_response, success = await repair(
+            dict_response, success = await repair(
+                bad_json=response, error_message=str(e), cache=cache
+            )
             if not success:
                 raise Exception(
                     f"""Even the repair failed. The error was: {e}. The response was: {response}."""
@@ -442,7 +420,8 @@ class LanguageModel(
 
         dict_response.update(
             {
-                "
+                "cached_used": cache_used,
+                "cache_key": cache_key,
                 "usage": raw_response.get("usage", {}),
                 "raw_model_response": raw_response,
             }
@@ -458,15 +437,18 @@ class LanguageModel(
     #######################
     # SERIALIZATION METHODS
     #######################
+    def _to_dict(self) -> dict[str, Any]:
+        return {"model": self.model, "parameters": self.parameters}
+
     @add_edsl_version
     def to_dict(self) -> dict[str, Any]:
         """Convert instance to a dictionary.
 
         >>> m = LanguageModel.example()
         >>> m.to_dict()
-        {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}}
+        {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
-        return
+        return self._to_dict()
 
     @classmethod
     @remove_edsl_version
@@ -519,8 +501,18 @@ class LanguageModel(
         return table
 
     @classmethod
-    def example(cls, test_model=False):
-        """Return a default instance of the class.
+    def example(cls, test_model: bool = False, canned_response: str = "Hello world"):
+        """Return a default instance of the class.
+
+        >>> from edsl.language_models import LanguageModel
+        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
+        >>> isinstance(m, LanguageModel)
+        True
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
+        >>> q.by(m).run(cache = False).select('example').first()
+        'WOWZA!'
+        """
         from edsl import Model
 
         class TestLanguageModelGood(LanguageModel):
@@ -533,7 +525,8 @@ class LanguageModel(
             self, user_prompt: str, system_prompt: str
         ) -> dict[str, Any]:
             await asyncio.sleep(0.1)
-            return {"message": """{"answer": "Hello world"}"""}
+            # return {"message": """{"answer": "Hello, world"}"""}
+            return {"message": f'{{"answer": "{canned_response}"}}'}
 
         def parse_response(self, raw_response: dict[str, Any]) -> str:
             return raw_response["message"]
edsl/language_models/ModelList.py
ADDED
@@ -0,0 +1,96 @@
+from typing import Optional
+from collections import UserList
+from edsl import Model
+
+from edsl.language_models import LanguageModel
+from edsl.Base import Base
+from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+from edsl.utilities import is_valid_variable_name
+from edsl.utilities.utilities import dict_hash
+
+
+class ModelList(Base, UserList):
+    def __init__(self, data: Optional[list] = None):
+        """Initialize the ScenarioList class.
+
+        >>> from edsl import Model
+        >>> m = ModelList(Model.available())
+
+        """
+        if data is not None:
+            super().__init__(data)
+        else:
+            super().__init__([])
+
+    @property
+    def names(self):
+        """
+
+        >>> ModelList.example().names
+        {'...'}
+        """
+        return set([model.model for model in self])
+
+    def rich_print(self):
+        pass
+
+    def __repr__(self):
+        return f"ModelList({super().__repr__()})"
+
+    def __hash__(self):
+        """Return a hash of the ModelList. This is used for comparison of ModelLists.
+
+        >>> hash(ModelList.example())
+        1423518243781418961
+
+        """
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self._to_dict(sort=True))
+
+    def _to_dict(self, sort=False):
+        if sort:
+            model_list = sorted([model for model in self], key=lambda x: hash(x))
+            return {"models": [model._to_dict() for model in model_list]}
+        else:
+            return {"models": [model._to_dict() for model in self]}
+
+    @classmethod
+    def from_names(self, *args):
+        """A a model list from a list of names"""
+        if len(args) == 1 and isinstance(args[0], list):
+            args = args[0]
+        return ModelList([Model(model_name) for model_name in args])
+
+    @add_edsl_version
+    def to_dict(self):
+        """
+        Convert the ModelList to a dictionary.
+        >>> ModelList.example().to_dict()
+        {'models': [...], 'edsl_version': '...', 'edsl_class_name': 'ModelList'}
+        """
+        return self._to_dict()
+
+    @classmethod
+    @remove_edsl_version
+    def from_dict(cls, data):
+        """
+        Create a ModelList from a dictionary.
+
+        >>> newm = ModelList.from_dict(ModelList.example().to_dict())
+        >>> assert ModelList.example() == newm
+        """
+        return cls(data=[LanguageModel.from_dict(model) for model in data["models"]])
+
+    def code(self):
+        pass
+
+    @classmethod
+    def example(cl):
+        return ModelList([LanguageModel.example() for _ in range(3)])
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
edsl/language_models/registry.py
CHANGED
@@ -38,6 +38,20 @@ class Model(metaclass=Meta):
         factory = registry.create_model_factory(model_name)
         return factory(*args, **kwargs)
 
+    @classmethod
+    def add_model(cls, service_name, model_name):
+        from edsl.inference_services.registry import default
+
+        registry = default
+        registry.add_model(service_name, model_name)
+
+    @classmethod
+    def services(cls, registry=None):
+        from edsl.inference_services.registry import default
+
+        registry = registry or default
+        return [r._inference_service_ for r in registry.services]
+
     @classmethod
     def available(cls, search_term=None, name_only=False, registry=None):
         from edsl.inference_services.registry import default