edsl 0.1.27.dev2__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (88)
  1. edsl/Base.py +99 -22
  2. edsl/BaseDiff.py +260 -0
  3. edsl/__init__.py +4 -0
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +26 -5
  6. edsl/agents/AgentList.py +62 -7
  7. edsl/agents/Invigilator.py +4 -9
  8. edsl/agents/InvigilatorBase.py +5 -5
  9. edsl/agents/descriptors.py +3 -1
  10. edsl/conjure/AgentConstructionMixin.py +152 -0
  11. edsl/conjure/Conjure.py +56 -0
  12. edsl/conjure/InputData.py +628 -0
  13. edsl/conjure/InputDataCSV.py +48 -0
  14. edsl/conjure/InputDataMixinQuestionStats.py +182 -0
  15. edsl/conjure/InputDataPyRead.py +91 -0
  16. edsl/conjure/InputDataSPSS.py +8 -0
  17. edsl/conjure/InputDataStata.py +8 -0
  18. edsl/conjure/QuestionOptionMixin.py +76 -0
  19. edsl/conjure/QuestionTypeMixin.py +23 -0
  20. edsl/conjure/RawQuestion.py +65 -0
  21. edsl/conjure/SurveyResponses.py +7 -0
  22. edsl/conjure/__init__.py +9 -4
  23. edsl/conjure/examples/placeholder.txt +0 -0
  24. edsl/conjure/naming_utilities.py +263 -0
  25. edsl/conjure/utilities.py +165 -28
  26. edsl/conversation/Conversation.py +238 -0
  27. edsl/conversation/car_buying.py +58 -0
  28. edsl/conversation/mug_negotiation.py +81 -0
  29. edsl/conversation/next_speaker_utilities.py +93 -0
  30. edsl/coop/coop.py +191 -12
  31. edsl/coop/utils.py +20 -2
  32. edsl/data/Cache.py +55 -17
  33. edsl/data/CacheHandler.py +10 -9
  34. edsl/inference_services/AnthropicService.py +1 -0
  35. edsl/inference_services/DeepInfraService.py +20 -13
  36. edsl/inference_services/GoogleService.py +7 -1
  37. edsl/inference_services/InferenceServicesCollection.py +33 -7
  38. edsl/inference_services/OpenAIService.py +17 -10
  39. edsl/inference_services/models_available_cache.py +69 -0
  40. edsl/inference_services/rate_limits_cache.py +25 -0
  41. edsl/inference_services/write_available.py +10 -0
  42. edsl/jobs/Jobs.py +240 -36
  43. edsl/jobs/buckets/BucketCollection.py +9 -3
  44. edsl/jobs/interviews/Interview.py +4 -1
  45. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +24 -10
  46. edsl/jobs/interviews/retry_management.py +4 -4
  47. edsl/jobs/runners/JobsRunnerAsyncio.py +87 -45
  48. edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
  49. edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
  50. edsl/language_models/LanguageModel.py +37 -44
  51. edsl/language_models/ModelList.py +96 -0
  52. edsl/language_models/registry.py +14 -0
  53. edsl/language_models/repair.py +95 -24
  54. edsl/notebooks/Notebook.py +119 -31
  55. edsl/questions/QuestionBase.py +109 -12
  56. edsl/questions/descriptors.py +5 -2
  57. edsl/questions/question_registry.py +7 -0
  58. edsl/results/Result.py +20 -8
  59. edsl/results/Results.py +85 -11
  60. edsl/results/ResultsDBMixin.py +3 -6
  61. edsl/results/ResultsExportMixin.py +47 -16
  62. edsl/results/ResultsToolsMixin.py +5 -5
  63. edsl/scenarios/Scenario.py +59 -5
  64. edsl/scenarios/ScenarioList.py +97 -40
  65. edsl/study/ObjectEntry.py +97 -0
  66. edsl/study/ProofOfWork.py +110 -0
  67. edsl/study/SnapShot.py +77 -0
  68. edsl/study/Study.py +491 -0
  69. edsl/study/__init__.py +2 -0
  70. edsl/surveys/Survey.py +79 -31
  71. edsl/surveys/SurveyExportMixin.py +21 -3
  72. edsl/utilities/__init__.py +1 -0
  73. edsl/utilities/gcp_bucket/__init__.py +0 -0
  74. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  75. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  76. edsl/utilities/interface.py +24 -28
  77. edsl/utilities/repair_functions.py +28 -0
  78. edsl/utilities/utilities.py +57 -2
  79. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/METADATA +43 -17
  80. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/RECORD +83 -55
  81. edsl-0.1.28.dist-info/entry_points.txt +3 -0
  82. edsl/conjure/RawResponseColumn.py +0 -327
  83. edsl/conjure/SurveyBuilder.py +0 -308
  84. edsl/conjure/SurveyBuilderCSV.py +0 -78
  85. edsl/conjure/SurveyBuilderSPSS.py +0 -118
  86. edsl/data/RemoteDict.py +0 -103
  87. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/LICENSE +0 -0
  88. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerAsyncio.py

@@ -14,7 +14,8 @@ from edsl.results import Results, Result
 
 from edsl.jobs.interviews.Interview import Interview
 from edsl.utilities.decorators import jupyter_nb_handler
-from edsl.jobs.Jobs import Jobs
+
+# from edsl.jobs.Jobs import Jobs
 from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
 from edsl.language_models import LanguageModel
 from edsl.data.Cache import Cache
@@ -22,6 +23,8 @@ from edsl.data.Cache import Cache
 from edsl.jobs.tasks.TaskHistory import TaskHistory
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 
+import time
+
 
 class JobsRunnerAsyncio(JobsRunnerStatusMixin):
     """A class for running a collection of interviews asynchronously.
@@ -32,12 +35,12 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
 
     def __init__(self, jobs: Jobs):
         self.jobs = jobs
-
+        # this creates the interviews, which can take a while
         self.interviews: List["Interview"] = jobs.interviews()
         self.bucket_collection: "BucketCollection" = jobs.bucket_collection
         self.total_interviews: List["Interview"] = []
 
-    async def run_async(
+    async def run_async_generator(
         self,
         cache: Cache,
         n: int = 1,
@@ -96,6 +99,20 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
                    )  # set the cache for the first interview
                    self.total_interviews.append(interview)
 
+    async def run_async(self, cache=None) -> Results:
+        if cache is None:
+            self.cache = Cache()
+        else:
+            self.cache = cache
+        data = []
+        async for result in self.run_async_generator(cache=self.cache):
+            data.append(result)
+        return Results(survey=self.jobs.survey, data=data)
+
+    def simple_run(self):
+        data = asyncio.run(self.run_async())
+        return Results(survey=self.jobs.survey, data=data)
+
     async def _build_interview_task(
         self,
         *,
@@ -187,39 +204,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
         self.cache = cache
         self.sidecar_model = sidecar_model
 
-        def generate_table():
-            return self.status_table(self.results, self.elapsed_time)
-
-        @contextmanager
-        def no_op_cm():
-            """A no-op context manager with a dummy update method."""
-            yield DummyLive()
-
-        class DummyLive:
-            def update(self, *args, **kwargs):
-                """A dummy update method that does nothing."""
-                pass
-
-        progress_bar_context = (
-            Live(generate_table(), console=console, refresh_per_second=5)
-            if progress_bar
-            else no_op_cm()
-        )
-
-        with cache as c:
-            with progress_bar_context as live:
-
-                async def update_progress_bar():
-                    """Updates the progress bar at fixed intervals."""
-                    while True:
-                        live.update(generate_table())
-                        await asyncio.sleep(0.1)  # Update interval
-                        if self.completed:
-                            break
+        if not progress_bar:
+            # print("Running without progress bar")
+            with cache as c:
 
                 async def process_results():
                     """Processes results from interviews."""
-                    async for result in self.run_async(
+                    async for result in self.run_async_generator(
                         n=n,
                         debug=debug,
                         stop_on_exception=stop_on_exception,
@@ -227,25 +218,74 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
                         sidecar_model=sidecar_model,
                     ):
                         self.results.append(result)
-                        live.update(generate_table())
                    self.completed = True
 
-                progress_task = asyncio.create_task(update_progress_bar())
+                await asyncio.gather(process_results())
+
+            results = Results(survey=self.jobs.survey, data=self.results)
+        else:
+            # print("Running with progress bar")
+
+            def generate_table():
+                return self.status_table(self.results, self.elapsed_time)
 
-                try:
-                    await asyncio.gather(process_results(), progress_task)
-                except asyncio.CancelledError:
+            @contextmanager
+            def no_op_cm():
+                """A no-op context manager with a dummy update method."""
+                yield DummyLive()
+
+            class DummyLive:
+                def update(self, *args, **kwargs):
+                    """A dummy update method that does nothing."""
                    pass
-                finally:
-                    progress_task.cancel()  # Cancel the progress_task when process_results is done
-                    await progress_task
 
-                await asyncio.sleep(1)  # short delay to show the final status
+            progress_bar_context = (
+                Live(generate_table(), console=console, refresh_per_second=5)
+                if progress_bar
+                else no_op_cm()
+            )
 
-                # one more update
-                live.update(generate_table())
+            with cache as c:
+                with progress_bar_context as live:
+
+                    async def update_progress_bar():
+                        """Updates the progress bar at fixed intervals."""
+                        while True:
+                            live.update(generate_table())
+                            await asyncio.sleep(0.00001)  # Update interval
+                            if self.completed:
+                                break
+
+                    async def process_results():
+                        """Processes results from interviews."""
+                        async for result in self.run_async_generator(
+                            n=n,
+                            debug=debug,
+                            stop_on_exception=stop_on_exception,
+                            cache=c,
+                            sidecar_model=sidecar_model,
+                        ):
+                            self.results.append(result)
+                            live.update(generate_table())
+                        self.completed = True
+
+                    progress_task = asyncio.create_task(update_progress_bar())
+
+                    try:
+                        await asyncio.gather(process_results(), progress_task)
+                    except asyncio.CancelledError:
+                        pass
+                    finally:
+                        progress_task.cancel()  # Cancel the progress_task when process_results is done
+                        await progress_task
+
+                    await asyncio.sleep(1)  # short delay to show the final status
+
+                    # one more update
+                    live.update(generate_table())
+
+            results = Results(survey=self.jobs.survey, data=self.results)
 
-        results = Results(survey=self.jobs.survey, data=self.results)
         task_history = TaskHistory(self.total_interviews, include_traceback=False)
         results.task_history = task_history
 
@@ -259,6 +299,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
            for interview in self.total_interviews
            if interview.has_exceptions
        ]
+        from edsl.jobs.Jobs import Jobs
+
        results.failed_jobs = Jobs.from_interviews(
            [interview for interview in failed_interviews]
        )
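
Net effect of the JobsRunnerAsyncio changes: the result-yielding coroutine formerly named run_async is now run_async_generator, a new run_async collects its output into a Results object, and the rich progress-bar machinery is only constructed when progress_bar=True. A minimal usage sketch of the new entry points (the construction of jobs is assumed and not shown in this diff):

import asyncio

from edsl.data.Cache import Cache
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio

async def main(jobs):
    runner = JobsRunnerAsyncio(jobs)
    # run_async drives run_async_generator and wraps the collected
    # results in a Results object.
    return await runner.run_async(cache=Cache())

# Synchronous convenience wrapper added in the same hunk:
# results = JobsRunnerAsyncio(jobs).simple_run()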
edsl/jobs/runners/JobsRunnerStatusData.py

@@ -81,7 +81,7 @@ class JobsRunnerStatusData:
        >>> completed_tasks = []
        >>> elapsed_time = 0
        >>> JobsRunnerStatusData().generate_status_summary(completed_tasks, elapsed_time, interviews)
-        {'Elapsed time': '0.0 sec.', 'Total interviews requested': '1 ', 'Completed interviews': '0 ', 'Percent complete': '0 %', 'Average time per interview': 'NA', 'Task remaining': '1 ', 'Estimated time remaining': 'NA', 'model_queues': [{'model_name': 'gpt-4-1106-preview', 'TPM_limit_k': 1600.0, 'RPM_limit_k': 8.0, 'num_tasks_waiting': 0, 'token_usage_info': [{'cache_status': 'new_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}, {'cache_status': 'cached_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}]}]}
+        {'Elapsed time': '0.0 sec.', 'Total interviews requested': '1 ', 'Completed interviews': '0 ', 'Percent complete': '0 %', 'Average time per interview': 'NA', 'Task remaining': '1 ', 'Estimated time remaining': 'NA', 'model_queues': [{'model_name': '...', 'TPM_limit_k': ..., 'RPM_limit_k': ..., 'num_tasks_waiting': 0, 'token_usage_info': [{'cache_status': 'new_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}, {'cache_status': 'cached_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}]}]}
        """
 
        models_to_tokens = defaultdict(InterviewTokenUsage)
@@ -176,7 +176,7 @@ class JobsRunnerStatusData:
        >>> model = interviews[0].model
        >>> num_waiting = 0
        >>> JobsRunnerStatusData()._get_model_info(model, num_waiting, models_to_tokens)
-        {'model_name': 'gpt-4-1106-preview', 'TPM_limit_k': 1600.0, 'RPM_limit_k': 8.0, 'num_tasks_waiting': 0, 'token_usage_info': [{'cache_status': 'new_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}, {'cache_status': 'cached_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}]}
+        {'model_name': 'gpt-4-1106-preview', 'TPM_limit_k': ..., 'RPM_limit_k': ..., 'num_tasks_waiting': 0, 'token_usage_info': [{'cache_status': 'new_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}, {'cache_status': 'cached_token_usage', 'details': [{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], 'cost': '$0.00000'}]}
        """
 
        prices = get_token_pricing(model.model)
@@ -234,4 +234,4 @@ class JobsRunnerStatusData:
 if __name__ == "__main__":
     import doctest
 
-    doctest.testmod()
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
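
The two doctest changes above replace hard-coded gpt-4 rate-limit numbers with '...' placeholders; those placeholders only match when the module's doctests run with doctest.ELLIPSIS enabled, hence the testmod change. A standalone illustration of the flag (standard-library behavior, not edsl-specific):

import doctest

def model_info():
    """
    >>> model_info()
    {'model_name': '...', 'TPM_limit_k': ...}
    """
    return {"model_name": "gpt-4-1106-preview", "TPM_limit_k": 1600.0}

# Passes with the flag; the same doctest fails under a plain testmod().
doctest.testmod(optionflags=doctest.ELLIPSIS)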
edsl/jobs/tasks/QuestionTaskCreator.py

@@ -1,7 +1,7 @@
 import asyncio
 from typing import Callable, Union, List
 from collections import UserList, UserDict
-
+import time
 
 from edsl.jobs.buckets import ModelBuckets
 from edsl.exceptions import InterviewErrorPriorTaskCanceled
@@ -132,7 +132,7 @@ class QuestionTaskCreator(UserList):
        self.waiting = True
        self.task_status = TaskStatus.WAITING_FOR_REQUEST_CAPACITY
 
-        await self.requests_bucket.get_tokens(1)
+        await self.tokens_bucket.get_tokens(1)
 
        self.task_status = TaskStatus.API_CALL_IN_PROGRESS
        try:
@@ -151,6 +151,8 @@ class QuestionTaskCreator(UserList):
            self.requests_bucket.add_tokens(1)
            self.from_cache = True
 
+        _ = results.pop("cached_response", None)
+
        tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
 
        # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
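
The one-line bucket change above makes the task await the tokens bucket at this point instead of the requests bucket. A minimal sketch of the two-bucket (requests-per-minute plus tokens-per-minute) rate-limiting pattern; the bucket class and names here are illustrative, not edsl's actual ModelBuckets implementation:

import asyncio

class SimpleBucket:
    """Illustrative token bucket; refills continuously up to capacity."""

    def __init__(self, capacity: float, refill_per_sec: float):
        self.capacity = capacity
        self.refill_per_sec = refill_per_sec
        self.tokens = capacity

    async def get_tokens(self, amount: float) -> None:
        # Sleep-poll until enough capacity has accumulated, then draw it down.
        while self.tokens < amount:
            await asyncio.sleep(0.01)
            self.tokens = min(self.capacity, self.tokens + 0.01 * self.refill_per_sec)
        self.tokens -= amount

async def acquire_capacity(requests_bucket, tokens_bucket, estimated_tokens):
    await requests_bucket.get_tokens(1)               # one request slot (RPM)
    await tokens_bucket.get_tokens(estimated_tokens)  # the token budget (TPM)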
edsl/language_models/LanguageModel.py

@@ -142,7 +142,7 @@ class LanguageModel(
    def has_valid_api_key(self) -> bool:
        """Check if the model has a valid API key.
 
-        >>> LanguageModel.example().has_valid_api_key()
+        >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
        True
 
        This method is used to check if the model has a valid API key.
@@ -159,7 +159,9 @@ class LanguageModel(
 
    def __hash__(self):
        """Allow the model to be used as a key in a dictionary."""
-        return hash(self.model + str(self.parameters))
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self.to_dict())
 
    def __eq__(self, other):
        """Check is two models are the same.
@@ -207,8 +209,8 @@ class LanguageModel(
        """Model's tokens-per-minute limit.
 
        >>> m = LanguageModel.example()
-        >>> m.TPM
-        1600000.0
+        >>> m.TPM > 0
+        True
        """
        self._set_rate_limits()
        return self._safety_factor * self.__rate_limits["tpm"]
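
The __hash__ change above delegates to dict_hash over the model's serialized dict, so models that serialize identically hash identically regardless of how the parameters dict happens to be ordered. The real dict_hash lives in edsl.utilities.utilities and its body is not part of this diff; a plausible sketch of the idea:

# Sketch only: assumes dict_hash canonicalizes the dict before hashing.
# The actual implementation in edsl.utilities.utilities may differ.
import hashlib
import json

def dict_hash_sketch(d: dict) -> int:
    canonical = json.dumps(d, sort_keys=True)  # key order no longer matters
    return int.from_bytes(hashlib.sha256(canonical.encode()).digest()[:8], "big")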
@@ -285,28 +287,6 @@ class LanguageModel(
        """
        raise NotImplementedError
 
-    def _update_response_with_tracking(
-        self, response: dict, start_time: int, cached_response=False, cache_key=None
-    ):
-        """Update the response with tracking information.
-
-        >>> m = LanguageModel.example()
-        >>> m._update_response_with_tracking(response={"response": "Hello"}, start_time=0, cached_response=False, cache_key=None)
-        {'response': 'Hello', 'elapsed_time': ..., 'timestamp': ..., 'cached_response': False, 'cache_key': None}
-
-
-        """
-        end_time = time.time()
-        response.update(
-            {
-                "elapsed_time": end_time - start_time,
-                "timestamp": end_time,
-                "cached_response": cached_response,
-                "cache_key": cache_key,
-            }
-        )
-        return response
-
    async def async_get_raw_response(
        self,
        user_prompt: str,
@@ -314,7 +294,7 @@ class LanguageModel(
        cache,
        iteration: int = 0,
        encoded_image=None,
-    ) -> dict[str, Any]:
+    ) -> tuple[dict, bool, str]:
        """Handle caching of responses.
 
        :param user_prompt: The user's prompt.
@@ -322,8 +302,7 @@ class LanguageModel(
        :param iteration: The iteration number.
        :param cache: The cache to use.
 
-        If the cache isn't being used, it just returns a 'fresh' call to the LLM,
-        but appends some tracking information to the response (using the _update_response_with_tracking method).
+        If the cache isn't being used, it just returns a 'fresh' call to the LLM.
        But if cache is being used, it first checks the database to see if the response is already there.
        If it is, it returns the cached response, but again appends some tracking information.
        If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
@@ -334,7 +313,7 @@ class LanguageModel(
        >>> from edsl import Cache
        >>> m = LanguageModel.example(test_model = True)
        >>> m.get_raw_response(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-        {'message': '{"answer": "Hello world"}', 'elapsed_time': ..., 'timestamp': ..., 'cached_response': False, 'cache_key': '24ff6ac2bc2f1729f817f261e0792577'}
+        ({'message': '{"answer": "Hello world"}'}, False, '24ff6ac2bc2f1729f817f261e0792577')
        """
        start_time = time.time()
 
@@ -379,12 +358,7 @@ class LanguageModel(
            )
            cache_used = False
 
-        return self._update_response_with_tracking(
-            response=response,
-            start_time=start_time,
-            cached_response=cache_used,
-            cache_key=cache_key,
-        )
+        return response, cache_used, cache_key
 
    get_raw_response = sync_wrapper(async_get_raw_response)
 
@@ -427,14 +401,18 @@ class LanguageModel(
        if encoded_image:
            params["encoded_image"] = encoded_image
 
-        raw_response = await self.async_get_raw_response(**params)
+        raw_response, cache_used, cache_key = await self.async_get_raw_response(
+            **params
+        )
        response = self.parse_response(raw_response)
 
        try:
            dict_response = json.loads(response)
        except json.JSONDecodeError as e:
            # TODO: Turn into logs to generate issues
-            dict_response, success = await repair(response, str(e))
+            dict_response, success = await repair(
+                bad_json=response, error_message=str(e), cache=cache
+            )
            if not success:
                raise Exception(
                    f"""Even the repair failed. The error was: {e}. The response was: {response}."""
@@ -442,7 +420,8 @@ class LanguageModel(
 
        dict_response.update(
            {
-                "cached_response": raw_response["cached_response"],
+                "cached_used": cache_used,
+                "cache_key": cache_key,
                "usage": raw_response.get("usage", {}),
                "raw_model_response": raw_response,
            }
@@ -458,15 +437,18 @@ class LanguageModel(
    #######################
    # SERIALIZATION METHODS
    #######################
+    def _to_dict(self) -> dict[str, Any]:
+        return {"model": self.model, "parameters": self.parameters}
+
    @add_edsl_version
    def to_dict(self) -> dict[str, Any]:
        """Convert instance to a dictionary.
 
        >>> m = LanguageModel.example()
        >>> m.to_dict()
-        {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}}
+        {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
        """
-        return {"model": self.model, "parameters": self.parameters}
+        return self._to_dict()
 
    @classmethod
    @remove_edsl_version
@@ -519,8 +501,18 @@ class LanguageModel(
        return table
 
    @classmethod
-    def example(cls, test_model=False):
-        """Return a default instance of the class."""
+    def example(cls, test_model: bool = False, canned_response: str = "Hello world"):
+        """Return a default instance of the class.
+
+        >>> from edsl.language_models import LanguageModel
+        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
+        >>> isinstance(m, LanguageModel)
+        True
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
+        >>> q.by(m).run(cache = False).select('example').first()
+        'WOWZA!'
+        """
        from edsl import Model
 
        class TestLanguageModelGood(LanguageModel):
@@ -533,7 +525,8 @@ class LanguageModel(
                self, user_prompt: str, system_prompt: str
            ) -> dict[str, Any]:
                await asyncio.sleep(0.1)
-                return {"message": """{"answer": "Hello world"}"""}
+                # return {"message": """{"answer": "Hello, world"}"""}
+                return {"message": f'{{"answer": "{canned_response}"}}'}
 
            def parse_response(self, raw_response: dict[str, Any]) -> str:
                return raw_response["message"]
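
Usage sketch pulled together from the doctests in the hunks above: the test model now returns a configurable canned response, and get_raw_response returns a (response, cache_used, cache_key) tuple rather than a single dict annotated with tracking fields:

from edsl import Cache
from edsl.language_models import LanguageModel

m = LanguageModel.example(test_model=True, canned_response="WOWZA!")
response, cache_used, cache_key = m.get_raw_response(
    user_prompt="Hello", system_prompt="hello", cache=Cache()
)
print(response)    # {'message': '{"answer": "WOWZA!"}'}
print(cache_used)  # False on a cold cache; True once the entry is cached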
edsl/language_models/ModelList.py (new file)

@@ -0,0 +1,96 @@
+from typing import Optional
+from collections import UserList
+from edsl import Model
+
+from edsl.language_models import LanguageModel
+from edsl.Base import Base
+from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+from edsl.utilities import is_valid_variable_name
+from edsl.utilities.utilities import dict_hash
+
+
+class ModelList(Base, UserList):
+    def __init__(self, data: Optional[list] = None):
+        """Initialize the ScenarioList class.
+
+        >>> from edsl import Model
+        >>> m = ModelList(Model.available())
+
+        """
+        if data is not None:
+            super().__init__(data)
+        else:
+            super().__init__([])
+
+    @property
+    def names(self):
+        """
+
+        >>> ModelList.example().names
+        {'...'}
+        """
+        return set([model.model for model in self])
+
+    def rich_print(self):
+        pass
+
+    def __repr__(self):
+        return f"ModelList({super().__repr__()})"
+
+    def __hash__(self):
+        """Return a hash of the ModelList. This is used for comparison of ModelLists.
+
+        >>> hash(ModelList.example())
+        1423518243781418961
+
+        """
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self._to_dict(sort=True))
+
+    def _to_dict(self, sort=False):
+        if sort:
+            model_list = sorted([model for model in self], key=lambda x: hash(x))
+            return {"models": [model._to_dict() for model in model_list]}
+        else:
+            return {"models": [model._to_dict() for model in self]}
+
+    @classmethod
+    def from_names(self, *args):
+        """A a model list from a list of names"""
+        if len(args) == 1 and isinstance(args[0], list):
+            args = args[0]
+        return ModelList([Model(model_name) for model_name in args])
+
+    @add_edsl_version
+    def to_dict(self):
+        """
+        Convert the ModelList to a dictionary.
+        >>> ModelList.example().to_dict()
+        {'models': [...], 'edsl_version': '...', 'edsl_class_name': 'ModelList'}
+        """
+        return self._to_dict()
+
+    @classmethod
+    @remove_edsl_version
+    def from_dict(cls, data):
+        """
+        Create a ModelList from a dictionary.
+
+        >>> newm = ModelList.from_dict(ModelList.example().to_dict())
+        >>> assert ModelList.example() == newm
+        """
+        return cls(data=[LanguageModel.from_dict(model) for model in data["models"]])
+
+    def code(self):
+        pass
+
+    @classmethod
+    def example(cl):
+        return ModelList([LanguageModel.example() for _ in range(3)])
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
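
A short round-trip sketch grounded in the doctests above (the import path is assumed from the new file's location):

from edsl.language_models.ModelList import ModelList

ml = ModelList.example()   # three copies of the example language model
print(ml.names)            # set of the underlying model names

# Serialization round-trips, as the from_dict doctest asserts:
assert ModelList.from_dict(ml.to_dict()) == ml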
edsl/language_models/registry.py

@@ -38,6 +38,20 @@ class Model(metaclass=Meta):
        factory = registry.create_model_factory(model_name)
        return factory(*args, **kwargs)
 
+    @classmethod
+    def add_model(cls, service_name, model_name):
+        from edsl.inference_services.registry import default
+
+        registry = default
+        registry.add_model(service_name, model_name)
+
+    @classmethod
+    def services(cls, registry=None):
+        from edsl.inference_services.registry import default
+
+        registry = registry or default
+        return [r._inference_service_ for r in registry.services]
+
    @classmethod
    def available(cls, search_term=None, name_only=False, registry=None):
        from edsl.inference_services.registry import default
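
The two new classmethods expose the inference-service registry directly from Model. A hedged sketch; "custom-model" is a placeholder name, and registering it under "openai" assumes that service is present in the default registry:

from edsl import Model

print(Model.services())                    # identifiers of the registered services
Model.add_model("openai", "custom-model")  # attach an extra model name to a service
m = Model("custom-model")                  # then instantiate it as usual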