edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. edsl/Base.py +28 -0
  2. edsl/__init__.py +1 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -16
  5. edsl/agents/Invigilator.py +13 -14
  6. edsl/agents/InvigilatorBase.py +4 -1
  7. edsl/agents/PromptConstructor.py +42 -22
  8. edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
  9. edsl/auto/AutoStudy.py +18 -5
  10. edsl/auto/StageBase.py +53 -40
  11. edsl/auto/StageQuestions.py +2 -1
  12. edsl/auto/utilities.py +0 -6
  13. edsl/coop/coop.py +21 -5
  14. edsl/data/Cache.py +29 -18
  15. edsl/data/CacheHandler.py +0 -2
  16. edsl/data/RemoteCacheSync.py +154 -46
  17. edsl/data/hack.py +10 -0
  18. edsl/enums.py +7 -0
  19. edsl/inference_services/AnthropicService.py +38 -16
  20. edsl/inference_services/AvailableModelFetcher.py +7 -1
  21. edsl/inference_services/GoogleService.py +5 -1
  22. edsl/inference_services/InferenceServicesCollection.py +18 -2
  23. edsl/inference_services/OpenAIService.py +46 -31
  24. edsl/inference_services/TestService.py +1 -3
  25. edsl/inference_services/TogetherAIService.py +5 -3
  26. edsl/inference_services/data_structures.py +74 -2
  27. edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
  28. edsl/jobs/FetchInvigilator.py +10 -3
  29. edsl/jobs/InterviewsConstructor.py +6 -4
  30. edsl/jobs/Jobs.py +299 -233
  31. edsl/jobs/JobsChecks.py +2 -2
  32. edsl/jobs/JobsPrompts.py +1 -1
  33. edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
  34. edsl/jobs/async_interview_runner.py +138 -0
  35. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  36. edsl/jobs/data_structures.py +120 -0
  37. edsl/jobs/interviews/Interview.py +80 -42
  38. edsl/jobs/results_exceptions_handler.py +98 -0
  39. edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
  40. edsl/jobs/runners/JobsRunnerStatus.py +131 -164
  41. edsl/jobs/tasks/TaskHistory.py +24 -3
  42. edsl/language_models/LanguageModel.py +59 -4
  43. edsl/language_models/ModelList.py +19 -8
  44. edsl/language_models/__init__.py +1 -1
  45. edsl/language_models/model.py +256 -0
  46. edsl/language_models/repair.py +1 -1
  47. edsl/questions/QuestionBase.py +35 -26
  48. edsl/questions/QuestionBasePromptsMixin.py +1 -1
  49. edsl/questions/QuestionBudget.py +1 -1
  50. edsl/questions/QuestionCheckBox.py +2 -2
  51. edsl/questions/QuestionExtract.py +5 -7
  52. edsl/questions/QuestionFreeText.py +1 -1
  53. edsl/questions/QuestionList.py +9 -15
  54. edsl/questions/QuestionMatrix.py +1 -1
  55. edsl/questions/QuestionMultipleChoice.py +1 -1
  56. edsl/questions/QuestionNumerical.py +1 -1
  57. edsl/questions/QuestionRank.py +1 -1
  58. edsl/questions/SimpleAskMixin.py +1 -1
  59. edsl/questions/__init__.py +1 -1
  60. edsl/questions/data_structures.py +20 -0
  61. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
  62. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
  63. edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
  64. edsl/results/DatasetExportMixin.py +60 -119
  65. edsl/results/Result.py +109 -3
  66. edsl/results/Results.py +50 -39
  67. edsl/results/file_exports.py +252 -0
  68. edsl/scenarios/ScenarioList.py +35 -7
  69. edsl/surveys/Survey.py +71 -20
  70. edsl/test_h +1 -0
  71. edsl/utilities/gcp_bucket/example.py +50 -0
  72. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
  73. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
  74. edsl/language_models/registry.py +0 -180
  75. /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
  76. /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
  77. /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
  78. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  79. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  80. /edsl/results/{Selector.py → results_selector.py} +0 -0
  81. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  82. /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
  83. /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
  84. /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
  85. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -3,21 +3,12 @@ from __future__ import annotations
3
3
  import os
4
4
  import time
5
5
  import requests
6
- import warnings
7
6
  from abc import ABC, abstractmethod
8
7
  from dataclasses import dataclass
9
-
10
- from typing import Any, List, DefaultDict, Optional, Dict
11
8
  from collections import defaultdict
9
+ from typing import Any, Dict, Optional
12
10
  from uuid import UUID
13
11
 
14
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
15
-
16
- InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
17
-
18
- from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
19
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
20
-
21
12
 
22
13
  @dataclass
23
14
  class ModelInfo:
@@ -28,11 +19,44 @@ class ModelInfo:
28
19
  token_usage_info: dict
29
20
 
30
21
 
31
- @dataclass
32
- class ModelTokenUsageStats:
33
- token_usage_type: str
34
- details: List[dict]
35
- cost: str
22
+ class StatisticsTracker:
23
+ def __init__(self, total_interviews: int, distinct_models: list[str]):
24
+ self.start_time = time.time()
25
+ self.total_interviews = total_interviews
26
+ self.completed_count = 0
27
+ self.completed_by_model = defaultdict(int)
28
+ self.distinct_models = distinct_models
29
+ self.total_exceptions = 0
30
+ self.unfixed_exceptions = 0
31
+
32
+ def add_completed_interview(
33
+ self, model: str, num_exceptions: int = 0, num_unfixed: int = 0
34
+ ):
35
+ self.completed_count += 1
36
+ self.completed_by_model[model] += 1
37
+ self.total_exceptions += num_exceptions
38
+ self.unfixed_exceptions += num_unfixed
39
+
40
+ def get_elapsed_time(self) -> float:
41
+ return time.time() - self.start_time
42
+
43
+ def get_average_time_per_interview(self) -> float:
44
+ return (
45
+ self.get_elapsed_time() / self.completed_count
46
+ if self.completed_count > 0
47
+ else 0
48
+ )
49
+
50
+ def get_throughput(self) -> float:
51
+ elapsed = self.get_elapsed_time()
52
+ return self.completed_count / elapsed if elapsed > 0 else 0
53
+
54
+ def get_estimated_time_remaining(self) -> float:
55
+ if self.completed_count == 0:
56
+ return 0
57
+ avg_time = self.get_average_time_per_interview()
58
+ remaining = self.total_interviews - self.completed_count
59
+ return avg_time * remaining
36
60
 
37
61
 
38
62
  class JobsRunnerStatusBase(ABC):
@@ -46,48 +70,39 @@ class JobsRunnerStatusBase(ABC):
46
70
  api_key: str = None,
47
71
  ):
48
72
  self.jobs_runner = jobs_runner
49
-
50
- # The uuid of the job on Coop
51
73
  self.job_uuid = job_uuid
52
-
53
74
  self.base_url = f"{endpoint_url}"
54
-
55
- self.start_time = time.time()
56
- self.completed_interviews = []
57
75
  self.refresh_rate = refresh_rate
58
76
  self.statistics = [
59
77
  "elapsed_time",
60
78
  "total_interviews_requested",
61
79
  "completed_interviews",
62
- # "percent_complete",
63
80
  "average_time_per_interview",
64
- # "task_remaining",
65
81
  "estimated_time_remaining",
66
82
  "exceptions",
67
83
  "unfixed_exceptions",
68
84
  "throughput",
69
85
  ]
70
- self.num_total_interviews = n * len(self.jobs_runner.interviews)
86
+ self.num_total_interviews = n * len(self.jobs_runner)
71
87
 
72
88
  self.distinct_models = list(
73
- set(i.model.model for i in self.jobs_runner.interviews)
89
+ set(model.model for model in self.jobs_runner.jobs.models)
74
90
  )
75
91
 
76
- self.completed_interview_by_model = defaultdict(list)
92
+ self.stats_tracker = StatisticsTracker(
93
+ total_interviews=self.num_total_interviews,
94
+ distinct_models=self.distinct_models,
95
+ )
77
96
 
78
97
  self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
79
98
 
80
99
  @abstractmethod
81
100
  def has_ep_api_key(self):
82
- """
83
- Checks if the user has an Expected Parrot API key.
84
- """
101
+ """Checks if the user has an Expected Parrot API key."""
85
102
  pass
86
103
 
87
104
  def get_status_dict(self) -> Dict[str, Any]:
88
- """
89
- Converts current status into a JSON-serializable dictionary.
90
- """
105
+ """Converts current status into a JSON-serializable dictionary."""
91
106
  # Get all statistics
92
107
  stats = {}
93
108
  for stat_name in self.statistics:
@@ -95,38 +110,41 @@ class JobsRunnerStatusBase(ABC):
95
110
  name, value = list(stat.items())[0]
96
111
  stats[name] = value
97
112
 
98
- # Calculate overall progress
99
- total_interviews = len(self.jobs_runner.total_interviews)
100
- completed = len(self.completed_interviews)
101
-
102
113
  # Get model-specific progress
103
114
  model_progress = {}
115
+ target_per_model = int(self.num_total_interviews / len(self.distinct_models))
116
+
104
117
  for model in self.distinct_models:
105
- completed_for_model = len(self.completed_interview_by_model[model])
106
- target_for_model = int(
107
- self.num_total_interviews / len(self.distinct_models)
108
- )
118
+ completed = self.stats_tracker.completed_by_model[model]
109
119
  model_progress[model] = {
110
- "completed": completed_for_model,
111
- "total": target_for_model,
120
+ "completed": completed,
121
+ "total": target_per_model,
112
122
  "percent": (
113
- (completed_for_model / target_for_model * 100)
114
- if target_for_model > 0
115
- else 0
123
+ (completed / target_per_model * 100) if target_per_model > 0 else 0
116
124
  ),
117
125
  }
118
126
 
119
127
  status_dict = {
120
128
  "overall_progress": {
121
- "completed": completed,
122
- "total": total_interviews,
129
+ "completed": self.stats_tracker.completed_count,
130
+ "total": self.num_total_interviews,
123
131
  "percent": (
124
- (completed / total_interviews * 100) if total_interviews > 0 else 0
132
+ (
133
+ self.stats_tracker.completed_count
134
+ / self.num_total_interviews
135
+ * 100
136
+ )
137
+ if self.num_total_interviews > 0
138
+ else 0
125
139
  ),
126
140
  },
127
141
  "language_model_progress": model_progress,
128
142
  "statistics": stats,
129
- "status": "completed" if completed >= total_interviews else "running",
143
+ "status": (
144
+ "completed"
145
+ if self.stats_tracker.completed_count >= self.num_total_interviews
146
+ else "running"
147
+ ),
130
148
  }
131
149
 
132
150
  model_queues = {}
@@ -152,99 +170,68 @@ class JobsRunnerStatusBase(ABC):
152
170
  status_dict["language_model_queues"] = model_queues
153
171
  return status_dict
154
172
 
155
- @abstractmethod
156
- def setup(self):
157
- """
158
- Conducts any setup that needs to happen prior to sending status updates.
173
+ def add_completed_interview(self, result):
174
+ """Records a completed interview without storing the full interview data."""
175
+ self.stats_tracker.add_completed_interview(
176
+ model=result.model.model,
177
+ num_exceptions=(
178
+ len(result.exceptions) if hasattr(result, "exceptions") else 0
179
+ ),
180
+ num_unfixed=(
181
+ result.exceptions.num_unfixed() if hasattr(result, "exceptions") else 0
182
+ ),
183
+ )
159
184
 
160
- Ex. For a local job, creates a job in the Coop database.
161
- """
162
- pass
185
+ def _compute_statistic(self, stat_name: str):
186
+ """Computes individual statistics based on the stats tracker."""
187
+ if stat_name == "elapsed_time":
188
+ value = self.stats_tracker.get_elapsed_time()
189
+ return {"elapsed_time": (value, 1, "sec.")}
163
190
 
164
- @abstractmethod
165
- def send_status_update(self):
166
- """
167
- Updates the current status of the job.
168
- """
169
- pass
191
+ elif stat_name == "total_interviews_requested":
192
+ return {"total_interviews_requested": (self.num_total_interviews, None, "")}
170
193
 
171
- def add_completed_interview(self, result):
172
- self.completed_interviews.append(result.interview_hash)
194
+ elif stat_name == "completed_interviews":
195
+ return {
196
+ "completed_interviews": (self.stats_tracker.completed_count, None, "")
197
+ }
173
198
 
174
- relevant_model = result.model.model
175
- self.completed_interview_by_model[relevant_model].append(result.interview_hash)
199
+ elif stat_name == "average_time_per_interview":
200
+ value = self.stats_tracker.get_average_time_per_interview()
201
+ return {"average_time_per_interview": (value, 2, "sec.")}
176
202
 
177
- def _compute_statistic(self, stat_name: str):
178
- completed_tasks = self.completed_interviews
179
- elapsed_time = time.time() - self.start_time
180
- interviews = self.jobs_runner.total_interviews
203
+ elif stat_name == "estimated_time_remaining":
204
+ value = self.stats_tracker.get_estimated_time_remaining()
205
+ return {"estimated_time_remaining": (value, 1, "sec.")}
181
206
 
182
- stat_definitions = {
183
- "elapsed_time": lambda: InterviewStatistic(
184
- "elapsed_time", value=elapsed_time, digits=1, units="sec."
185
- ),
186
- "total_interviews_requested": lambda: InterviewStatistic(
187
- "total_interviews_requested", value=len(interviews), units=""
188
- ),
189
- "completed_interviews": lambda: InterviewStatistic(
190
- "completed_interviews", value=len(completed_tasks), units=""
191
- ),
192
- "percent_complete": lambda: InterviewStatistic(
193
- "percent_complete",
194
- value=(
195
- len(completed_tasks) / len(interviews) * 100
196
- if len(interviews) > 0
197
- else 0
198
- ),
199
- digits=1,
200
- units="%",
201
- ),
202
- "average_time_per_interview": lambda: InterviewStatistic(
203
- "average_time_per_interview",
204
- value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
205
- digits=2,
206
- units="sec.",
207
- ),
208
- "task_remaining": lambda: InterviewStatistic(
209
- "task_remaining", value=len(interviews) - len(completed_tasks), units=""
210
- ),
211
- "estimated_time_remaining": lambda: InterviewStatistic(
212
- "estimated_time_remaining",
213
- value=(
214
- (len(interviews) - len(completed_tasks))
215
- * (elapsed_time / len(completed_tasks))
216
- if len(completed_tasks) > 0
217
- else 0
218
- ),
219
- digits=1,
220
- units="sec.",
221
- ),
222
- "exceptions": lambda: InterviewStatistic(
223
- "exceptions",
224
- value=sum(len(i.exceptions) for i in interviews),
225
- units="",
226
- ),
227
- "unfixed_exceptions": lambda: InterviewStatistic(
228
- "unfixed_exceptions",
229
- value=sum(i.exceptions.num_unfixed() for i in interviews),
230
- units="",
231
- ),
232
- "throughput": lambda: InterviewStatistic(
233
- "throughput",
234
- value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
235
- digits=2,
236
- units="interviews/sec.",
237
- ),
238
- }
239
- return stat_definitions[stat_name]()
207
+ elif stat_name == "exceptions":
208
+ return {"exceptions": (self.stats_tracker.total_exceptions, None, "")}
209
+
210
+ elif stat_name == "unfixed_exceptions":
211
+ return {
212
+ "unfixed_exceptions": (self.stats_tracker.unfixed_exceptions, None, "")
213
+ }
214
+
215
+ elif stat_name == "throughput":
216
+ value = self.stats_tracker.get_throughput()
217
+ return {"throughput": (value, 2, "interviews/sec.")}
240
218
 
241
219
  def update_progress(self, stop_event):
242
220
  while not stop_event.is_set():
243
221
  self.send_status_update()
244
222
  time.sleep(self.refresh_rate)
245
-
246
223
  self.send_status_update()
247
224
 
225
+ @abstractmethod
226
+ def setup(self):
227
+ """Conducts any setup needed prior to sending status updates."""
228
+ pass
229
+
230
+ @abstractmethod
231
+ def send_status_update(self):
232
+ """Updates the current status of the job."""
233
+ pass
234
+
248
235
 
249
236
  class JobsRunnerStatus(JobsRunnerStatusBase):
250
237
  @property
@@ -260,49 +247,35 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
260
247
  return f"{self.base_url}/api/v0/local-job/{str(self.job_uuid)}"
261
248
 
262
249
  def setup(self) -> None:
263
- """
264
- Creates a local job on Coop if one does not already exist.
265
- """
266
-
267
- headers = {"Content-Type": "application/json"}
268
-
269
- if self.api_key:
270
- headers["Authorization"] = f"Bearer {self.api_key}"
271
- else:
272
- headers["Authorization"] = f"Bearer None"
250
+ """Creates a local job on Coop if one does not already exist."""
251
+ headers = {
252
+ "Content-Type": "application/json",
253
+ "Authorization": f"Bearer {self.api_key or 'None'}",
254
+ }
273
255
 
274
256
  if self.job_uuid is None:
275
- # Create a new local job
276
257
  response = requests.post(
277
258
  self.create_url,
278
259
  headers=headers,
279
260
  timeout=1,
280
261
  )
281
- response.raise_for_status()
282
- data = response.json()
283
- self.job_uuid = data.get("job_uuid")
262
+ response.raise_for_status()
263
+ data = response.json()
264
+ self.job_uuid = data.get("job_uuid")
284
265
 
285
266
  print(f"Running with progress bar. View progress at {self.viewing_url}")
286
267
 
287
268
  def send_status_update(self) -> None:
288
- """
289
- Sends current status to the web endpoint using the instance's job_uuid.
290
- """
269
+ """Sends current status to the web endpoint using the instance's job_uuid."""
291
270
  try:
292
- # Get the status dictionary and add the job_id
293
271
  status_dict = self.get_status_dict()
294
-
295
- # Make the UUID JSON serializable
296
272
  status_dict["job_id"] = str(self.job_uuid)
297
273
 
298
- headers = {"Content-Type": "application/json"}
299
-
300
- if self.api_key:
301
- headers["Authorization"] = f"Bearer {self.api_key}"
302
- else:
303
- headers["Authorization"] = f"Bearer None"
274
+ headers = {
275
+ "Content-Type": "application/json",
276
+ "Authorization": f"Bearer {self.api_key or 'None'}",
277
+ }
304
278
 
305
- # Send the update
306
279
  response = requests.patch(
307
280
  self.update_url,
308
281
  json=status_dict,
@@ -314,14 +287,8 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
314
287
  print(f"Failed to send status update for job {self.job_uuid}: {e}")
315
288
 
316
289
  def has_ep_api_key(self) -> bool:
317
- """
318
- Returns True if the user has an Expected Parrot API key. Otherwise, returns False.
319
- """
320
-
321
- if self.api_key is not None:
322
- return True
323
- else:
324
- return False
290
+ """Returns True if the user has an Expected Parrot API key."""
291
+ return self.api_key is not None
325
292
 
326
293
 
327
294
  if __name__ == "__main__":
@@ -8,9 +8,10 @@ from edsl.Base import RepresentationMixin
8
8
  class TaskHistory(RepresentationMixin):
9
9
  def __init__(
10
10
  self,
11
- interviews: List["Interview"],
11
+ interviews: List["Interview"] = None,
12
12
  include_traceback: bool = False,
13
13
  max_interviews: int = 10,
14
+ interviews_with_exceptions_only: bool = False,
14
15
  ):
15
16
  """
16
17
  The structure of a TaskHistory exception
@@ -20,13 +21,33 @@ class TaskHistory(RepresentationMixin):
20
21
  >>> _ = TaskHistory.example()
21
22
  ...
22
23
  """
24
+ self.interviews_with_exceptions_only = interviews_with_exceptions_only
25
+ self._interviews = {}
26
+ self.total_interviews = []
27
+ if interviews is not None:
28
+ for interview in interviews:
29
+ self.add_interview(interview)
23
30
 
24
- self.total_interviews = interviews
31
+ self.include_traceback = include_traceback
32
+ self._interviews = {
33
+ index: interview for index, interview in enumerate(self.total_interviews)
34
+ }
35
+ self.max_interviews = max_interviews
36
+
37
+ # self.total_interviews = interviews
25
38
  self.include_traceback = include_traceback
26
39
 
27
- self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
40
+ # self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
28
41
  self.max_interviews = max_interviews
29
42
 
43
+ def add_interview(self, interview: "Interview"):
44
+ """Add a single interview to the history"""
45
+ if self.interviews_with_exceptions_only and interview.exceptions == {}:
46
+ return
47
+
48
+ self.total_interviews.append(interview)
49
+ self._interviews[len(self._interviews)] = interview
50
+
30
51
  @classmethod
31
52
  def example(cls):
32
53
  """ """
@@ -44,6 +44,8 @@ if TYPE_CHECKING:
44
44
  from edsl.questions.QuestionBase import QuestionBase
45
45
  from edsl.language_models.key_management.KeyLookup import KeyLookup
46
46
 
47
+ from edsl.enums import InferenceServiceType
48
+
47
49
  from edsl.utilities.decorators import (
48
50
  sync_wrapper,
49
51
  jupyter_nb_handler,
@@ -155,7 +157,9 @@ class LanguageModel(
155
157
  return klc.get(("config", "env"))
156
158
 
157
159
  def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
158
- del self._api_token
160
+ """Set the key lookup, later"""
161
+ if hasattr(self, "_api_token"):
162
+ del self._api_token
159
163
  self.key_lookup = key_lookup
160
164
 
161
165
  def ask_question(self, question: "QuestionBase") -> str:
@@ -493,7 +497,10 @@ class LanguageModel(
493
497
  >>> m.to_dict()
494
498
  {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
495
499
  """
496
- d = {"model": self.model, "parameters": self.parameters}
500
+ d = {
501
+ "model": self.model,
502
+ "parameters": self.parameters,
503
+ }
497
504
  if add_edsl_version:
498
505
  from edsl import __version__
499
506
 
@@ -505,7 +512,7 @@ class LanguageModel(
505
512
  @remove_edsl_version
506
513
  def from_dict(cls, data: dict) -> Type[LanguageModel]:
507
514
  """Convert dictionary to a LanguageModel child instance."""
508
- from edsl.language_models.registry import get_model_class
515
+ from edsl.language_models.model import get_model_class
509
516
 
510
517
  model_class = get_model_class(data["model"])
511
518
  return model_class(**data)
@@ -553,7 +560,7 @@ class LanguageModel(
553
560
  Exception report saved to ...
554
561
  Also see: ...
555
562
  """
556
- from edsl.language_models.registry import Model
563
+ from edsl.language_models.model import Model
557
564
 
558
565
  if test_model:
559
566
  m = Model(
@@ -563,6 +570,54 @@ class LanguageModel(
563
570
  else:
564
571
  return Model(skip_api_key_check=True)
565
572
 
573
+ def from_cache(self, cache: "Cache") -> LanguageModel:
574
+
575
+ from copy import deepcopy
576
+ from types import MethodType
577
+ from edsl import Cache
578
+
579
+ new_instance = deepcopy(self)
580
+ print("Cache entries", len(cache))
581
+ new_instance.cache = Cache(
582
+ data={k: v for k, v in cache.items() if v.model == self.model}
583
+ )
584
+ print("Cache entries with same model", len(new_instance.cache))
585
+
586
+ new_instance.user_prompts = [
587
+ ce.user_prompt for ce in new_instance.cache.values()
588
+ ]
589
+ new_instance.system_prompts = [
590
+ ce.system_prompt for ce in new_instance.cache.values()
591
+ ]
592
+
593
+ async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
594
+ cache_call_params = {
595
+ "model": str(self.model),
596
+ "parameters": self.parameters,
597
+ "system_prompt": system_prompt,
598
+ "user_prompt": user_prompt,
599
+ "iteration": 1,
600
+ }
601
+ cached_response, cache_key = cache.fetch(**cache_call_params)
602
+ response = json.loads(cached_response)
603
+ cost = 0
604
+ return ModelResponse(
605
+ response=response,
606
+ cache_used=True,
607
+ cache_key=cache_key,
608
+ cached_response=cached_response,
609
+ cost=cost,
610
+ )
611
+
612
+ # Bind the new method to the copied instance
613
+ setattr(
614
+ new_instance,
615
+ "async_execute_model_call",
616
+ MethodType(async_execute_model_call, new_instance),
617
+ )
618
+
619
+ return new_instance
620
+
566
621
 
567
622
  if __name__ == "__main__":
568
623
  """Run the module's test suite."""
@@ -1,18 +1,22 @@
1
- from typing import Optional, List
1
+ from typing import Optional, List, TYPE_CHECKING
2
2
  from collections import UserList
3
3
 
4
4
  from edsl.Base import Base
5
- from edsl.language_models.registry import Model
5
+ from edsl.language_models.model import Model
6
6
 
7
- # from edsl.language_models import LanguageModel
7
+ #
8
8
  from edsl.utilities.remove_edsl_version import remove_edsl_version
9
9
  from edsl.utilities.is_valid_variable_name import is_valid_variable_name
10
10
 
11
+ if TYPE_CHECKING:
12
+ from edsl.inference_services.data_structures import AvailableModels
13
+ from edsl.language_models import LanguageModel
14
+
11
15
 
12
16
  class ModelList(Base, UserList):
13
17
  __documentation__ = """https://docs.expectedparrot.com/en/latest/language_models.html#module-edsl.language_models.ModelList"""
14
18
 
15
- def __init__(self, data: Optional[list] = None):
19
+ def __init__(self, data: Optional["LanguageModel"] = None):
16
20
  """Initialize the ScenarioList class.
17
21
 
18
22
  >>> from edsl import Model
@@ -33,9 +37,6 @@ class ModelList(Base, UserList):
33
37
  """
34
38
  return set([model.model for model in self])
35
39
 
36
- def rich_print(self):
37
- pass
38
-
39
40
  def __repr__(self):
40
41
  return f"ModelList({super().__repr__()})"
41
42
 
@@ -87,7 +88,7 @@ class ModelList(Base, UserList):
87
88
  .table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
88
89
  )
89
90
 
90
- def to_list(self):
91
+ def to_list(self) -> list:
91
92
  return self.to_scenario_list().to_list()
92
93
 
93
94
  def to_dict(self, sort=False, add_edsl_version=True):
@@ -120,6 +121,16 @@ class ModelList(Base, UserList):
120
121
  args = args[0]
121
122
  return ModelList([Model(model_name, **kwargs) for model_name in args])
122
123
 
124
+ @classmethod
125
+ def from_available_models(self, available_models_list: "AvailableModels"):
126
+ """Create a ModelList from an AvailableModels object"""
127
+ return ModelList(
128
+ [
129
+ Model(model.model_name, service_name=model.service_name)
130
+ for model in available_models_list
131
+ ]
132
+ )
133
+
123
134
  @classmethod
124
135
  @remove_edsl_version
125
136
  def from_dict(cls, data):
@@ -1,2 +1,2 @@
1
1
  from edsl.language_models.LanguageModel import LanguageModel
2
- from edsl.language_models.registry import Model
2
+ from edsl.language_models.model import Model