edsl 0.1.39__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. edsl/Base.py +0 -28
  2. edsl/__init__.py +1 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +17 -9
  5. edsl/agents/Invigilator.py +14 -13
  6. edsl/agents/InvigilatorBase.py +1 -4
  7. edsl/agents/PromptConstructor.py +22 -42
  8. edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
  9. edsl/auto/AutoStudy.py +5 -18
  10. edsl/auto/StageBase.py +40 -53
  11. edsl/auto/StageQuestions.py +1 -2
  12. edsl/auto/utilities.py +6 -0
  13. edsl/coop/coop.py +5 -21
  14. edsl/data/Cache.py +18 -29
  15. edsl/data/CacheHandler.py +2 -0
  16. edsl/data/RemoteCacheSync.py +46 -154
  17. edsl/enums.py +0 -7
  18. edsl/inference_services/AnthropicService.py +16 -38
  19. edsl/inference_services/AvailableModelFetcher.py +1 -7
  20. edsl/inference_services/GoogleService.py +1 -5
  21. edsl/inference_services/InferenceServicesCollection.py +2 -18
  22. edsl/inference_services/OpenAIService.py +31 -46
  23. edsl/inference_services/TestService.py +3 -1
  24. edsl/inference_services/TogetherAIService.py +3 -5
  25. edsl/inference_services/data_structures.py +2 -74
  26. edsl/jobs/AnswerQuestionFunctionConstructor.py +113 -148
  27. edsl/jobs/FetchInvigilator.py +3 -10
  28. edsl/jobs/InterviewsConstructor.py +4 -6
  29. edsl/jobs/Jobs.py +233 -299
  30. edsl/jobs/JobsChecks.py +2 -2
  31. edsl/jobs/JobsPrompts.py +1 -1
  32. edsl/jobs/JobsRemoteInferenceHandler.py +136 -160
  33. edsl/jobs/interviews/Interview.py +42 -80
  34. edsl/jobs/runners/JobsRunnerAsyncio.py +358 -88
  35. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  36. edsl/jobs/tasks/TaskHistory.py +3 -24
  37. edsl/language_models/LanguageModel.py +4 -59
  38. edsl/language_models/ModelList.py +8 -19
  39. edsl/language_models/__init__.py +1 -1
  40. edsl/language_models/registry.py +180 -0
  41. edsl/language_models/repair.py +1 -1
  42. edsl/questions/QuestionBase.py +26 -35
  43. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +49 -52
  44. edsl/questions/QuestionBasePromptsMixin.py +1 -1
  45. edsl/questions/QuestionBudget.py +1 -1
  46. edsl/questions/QuestionCheckBox.py +2 -2
  47. edsl/questions/QuestionExtract.py +7 -5
  48. edsl/questions/QuestionFreeText.py +1 -1
  49. edsl/questions/QuestionList.py +15 -9
  50. edsl/questions/QuestionMatrix.py +1 -1
  51. edsl/questions/QuestionMultipleChoice.py +1 -1
  52. edsl/questions/QuestionNumerical.py +1 -1
  53. edsl/questions/QuestionRank.py +1 -1
  54. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +18 -6
  55. edsl/questions/{response_validator_factory.py → ResponseValidatorFactory.py} +1 -7
  56. edsl/questions/SimpleAskMixin.py +1 -1
  57. edsl/questions/__init__.py +1 -1
  58. edsl/results/DatasetExportMixin.py +119 -60
  59. edsl/results/Result.py +3 -109
  60. edsl/results/Results.py +39 -50
  61. edsl/scenarios/FileStore.py +0 -32
  62. edsl/scenarios/ScenarioList.py +7 -35
  63. edsl/scenarios/handlers/csv.py +0 -11
  64. edsl/surveys/Survey.py +20 -71
  65. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +1 -1
  66. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/RECORD +78 -84
  67. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +1 -1
  68. edsl/jobs/async_interview_runner.py +0 -138
  69. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  70. edsl/jobs/data_structures.py +0 -120
  71. edsl/jobs/results_exceptions_handler.py +0 -98
  72. edsl/language_models/model.py +0 -256
  73. edsl/questions/data_structures.py +0 -20
  74. edsl/results/file_exports.py +0 -252
  75. /edsl/agents/{question_option_processor.py → QuestionOptionProcessor.py} +0 -0
  76. /edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +0 -0
  77. /edsl/questions/{loop_processor.py → LoopProcessor.py} +0 -0
  78. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  79. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  80. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  81. /edsl/results/{results_selector.py → Selector.py} +0 -0
  82. /edsl/scenarios/{directory_scanner.py → DirectoryScanner.py} +0 -0
  83. /edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +0 -0
  84. /edsl/scenarios/{scenario_selector.py → ScenarioSelector.py} +0 -0
  85. {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
@@ -3,12 +3,21 @@ from __future__ import annotations
3
3
  import os
4
4
  import time
5
5
  import requests
6
+ import warnings
6
7
  from abc import ABC, abstractmethod
7
8
  from dataclasses import dataclass
9
+
10
+ from typing import Any, List, DefaultDict, Optional, Dict
8
11
  from collections import defaultdict
9
- from typing import Any, Dict, Optional
10
12
  from uuid import UUID
11
13
 
14
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
15
+
16
+ InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
17
+
18
+ from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
19
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
20
+
12
21
 
13
22
  @dataclass
14
23
  class ModelInfo:
@@ -19,44 +28,11 @@ class ModelInfo:
19
28
  token_usage_info: dict
20
29
 
21
30
 
22
- class StatisticsTracker:
23
- def __init__(self, total_interviews: int, distinct_models: list[str]):
24
- self.start_time = time.time()
25
- self.total_interviews = total_interviews
26
- self.completed_count = 0
27
- self.completed_by_model = defaultdict(int)
28
- self.distinct_models = distinct_models
29
- self.total_exceptions = 0
30
- self.unfixed_exceptions = 0
31
-
32
- def add_completed_interview(
33
- self, model: str, num_exceptions: int = 0, num_unfixed: int = 0
34
- ):
35
- self.completed_count += 1
36
- self.completed_by_model[model] += 1
37
- self.total_exceptions += num_exceptions
38
- self.unfixed_exceptions += num_unfixed
39
-
40
- def get_elapsed_time(self) -> float:
41
- return time.time() - self.start_time
42
-
43
- def get_average_time_per_interview(self) -> float:
44
- return (
45
- self.get_elapsed_time() / self.completed_count
46
- if self.completed_count > 0
47
- else 0
48
- )
49
-
50
- def get_throughput(self) -> float:
51
- elapsed = self.get_elapsed_time()
52
- return self.completed_count / elapsed if elapsed > 0 else 0
53
-
54
- def get_estimated_time_remaining(self) -> float:
55
- if self.completed_count == 0:
56
- return 0
57
- avg_time = self.get_average_time_per_interview()
58
- remaining = self.total_interviews - self.completed_count
59
- return avg_time * remaining
31
+ @dataclass
32
+ class ModelTokenUsageStats:
33
+ token_usage_type: str
34
+ details: List[dict]
35
+ cost: str
60
36
 
61
37
 
62
38
  class JobsRunnerStatusBase(ABC):
@@ -70,39 +46,48 @@ class JobsRunnerStatusBase(ABC):
70
46
  api_key: str = None,
71
47
  ):
72
48
  self.jobs_runner = jobs_runner
49
+
50
+ # The uuid of the job on Coop
73
51
  self.job_uuid = job_uuid
52
+
74
53
  self.base_url = f"{endpoint_url}"
54
+
55
+ self.start_time = time.time()
56
+ self.completed_interviews = []
75
57
  self.refresh_rate = refresh_rate
76
58
  self.statistics = [
77
59
  "elapsed_time",
78
60
  "total_interviews_requested",
79
61
  "completed_interviews",
62
+ # "percent_complete",
80
63
  "average_time_per_interview",
64
+ # "task_remaining",
81
65
  "estimated_time_remaining",
82
66
  "exceptions",
83
67
  "unfixed_exceptions",
84
68
  "throughput",
85
69
  ]
86
- self.num_total_interviews = n * len(self.jobs_runner)
70
+ self.num_total_interviews = n * len(self.jobs_runner.interviews)
87
71
 
88
72
  self.distinct_models = list(
89
- set(model.model for model in self.jobs_runner.jobs.models)
73
+ set(i.model.model for i in self.jobs_runner.interviews)
90
74
  )
91
75
 
92
- self.stats_tracker = StatisticsTracker(
93
- total_interviews=self.num_total_interviews,
94
- distinct_models=self.distinct_models,
95
- )
76
+ self.completed_interview_by_model = defaultdict(list)
96
77
 
97
78
  self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
98
79
 
99
80
  @abstractmethod
100
81
  def has_ep_api_key(self):
101
- """Checks if the user has an Expected Parrot API key."""
82
+ """
83
+ Checks if the user has an Expected Parrot API key.
84
+ """
102
85
  pass
103
86
 
104
87
  def get_status_dict(self) -> Dict[str, Any]:
105
- """Converts current status into a JSON-serializable dictionary."""
88
+ """
89
+ Converts current status into a JSON-serializable dictionary.
90
+ """
106
91
  # Get all statistics
107
92
  stats = {}
108
93
  for stat_name in self.statistics:
@@ -110,46 +95,42 @@ class JobsRunnerStatusBase(ABC):
110
95
  name, value = list(stat.items())[0]
111
96
  stats[name] = value
112
97
 
98
+ # Calculate overall progress
99
+ total_interviews = len(self.jobs_runner.total_interviews)
100
+ completed = len(self.completed_interviews)
101
+
113
102
  # Get model-specific progress
114
103
  model_progress = {}
115
- target_per_model = int(self.num_total_interviews / len(self.distinct_models))
116
-
117
104
  for model in self.distinct_models:
118
- completed = self.stats_tracker.completed_by_model[model]
105
+ completed_for_model = len(self.completed_interview_by_model[model])
106
+ target_for_model = int(
107
+ self.num_total_interviews / len(self.distinct_models)
108
+ )
119
109
  model_progress[model] = {
120
- "completed": completed,
121
- "total": target_per_model,
110
+ "completed": completed_for_model,
111
+ "total": target_for_model,
122
112
  "percent": (
123
- (completed / target_per_model * 100) if target_per_model > 0 else 0
113
+ (completed_for_model / target_for_model * 100)
114
+ if target_for_model > 0
115
+ else 0
124
116
  ),
125
117
  }
126
118
 
127
119
  status_dict = {
128
120
  "overall_progress": {
129
- "completed": self.stats_tracker.completed_count,
130
- "total": self.num_total_interviews,
121
+ "completed": completed,
122
+ "total": total_interviews,
131
123
  "percent": (
132
- (
133
- self.stats_tracker.completed_count
134
- / self.num_total_interviews
135
- * 100
136
- )
137
- if self.num_total_interviews > 0
138
- else 0
124
+ (completed / total_interviews * 100) if total_interviews > 0 else 0
139
125
  ),
140
126
  },
141
127
  "language_model_progress": model_progress,
142
128
  "statistics": stats,
143
- "status": (
144
- "completed"
145
- if self.stats_tracker.completed_count >= self.num_total_interviews
146
- else "running"
147
- ),
129
+ "status": "completed" if completed >= total_interviews else "running",
148
130
  }
149
131
 
150
132
  model_queues = {}
151
- # for model, bucket in self.jobs_runner.bucket_collection.items():
152
- for model, bucket in self.jobs_runner.environment.bucket_collection.items():
133
+ for model, bucket in self.jobs_runner.bucket_collection.items():
153
134
  model_name = model.model
154
135
  model_queues[model_name] = {
155
136
  "language_model_name": model_name,
@@ -171,67 +152,98 @@ class JobsRunnerStatusBase(ABC):
171
152
  status_dict["language_model_queues"] = model_queues
172
153
  return status_dict
173
154
 
174
- def add_completed_interview(self, result):
175
- """Records a completed interview without storing the full interview data."""
176
- self.stats_tracker.add_completed_interview(
177
- model=result.model.model,
178
- num_exceptions=(
179
- len(result.exceptions) if hasattr(result, "exceptions") else 0
180
- ),
181
- num_unfixed=(
182
- result.exceptions.num_unfixed() if hasattr(result, "exceptions") else 0
183
- ),
184
- )
185
-
186
- def _compute_statistic(self, stat_name: str):
187
- """Computes individual statistics based on the stats tracker."""
188
- if stat_name == "elapsed_time":
189
- value = self.stats_tracker.get_elapsed_time()
190
- return {"elapsed_time": (value, 1, "sec.")}
191
-
192
- elif stat_name == "total_interviews_requested":
193
- return {"total_interviews_requested": (self.num_total_interviews, None, "")}
155
+ @abstractmethod
156
+ def setup(self):
157
+ """
158
+ Conducts any setup that needs to happen prior to sending status updates.
194
159
 
195
- elif stat_name == "completed_interviews":
196
- return {
197
- "completed_interviews": (self.stats_tracker.completed_count, None, "")
198
- }
160
+ Ex. For a local job, creates a job in the Coop database.
161
+ """
162
+ pass
199
163
 
200
- elif stat_name == "average_time_per_interview":
201
- value = self.stats_tracker.get_average_time_per_interview()
202
- return {"average_time_per_interview": (value, 2, "sec.")}
164
+ @abstractmethod
165
+ def send_status_update(self):
166
+ """
167
+ Updates the current status of the job.
168
+ """
169
+ pass
203
170
 
204
- elif stat_name == "estimated_time_remaining":
205
- value = self.stats_tracker.get_estimated_time_remaining()
206
- return {"estimated_time_remaining": (value, 1, "sec.")}
171
+ def add_completed_interview(self, result):
172
+ self.completed_interviews.append(result.interview_hash)
207
173
 
208
- elif stat_name == "exceptions":
209
- return {"exceptions": (self.stats_tracker.total_exceptions, None, "")}
174
+ relevant_model = result.model.model
175
+ self.completed_interview_by_model[relevant_model].append(result.interview_hash)
210
176
 
211
- elif stat_name == "unfixed_exceptions":
212
- return {
213
- "unfixed_exceptions": (self.stats_tracker.unfixed_exceptions, None, "")
214
- }
177
+ def _compute_statistic(self, stat_name: str):
178
+ completed_tasks = self.completed_interviews
179
+ elapsed_time = time.time() - self.start_time
180
+ interviews = self.jobs_runner.total_interviews
215
181
 
216
- elif stat_name == "throughput":
217
- value = self.stats_tracker.get_throughput()
218
- return {"throughput": (value, 2, "interviews/sec.")}
182
+ stat_definitions = {
183
+ "elapsed_time": lambda: InterviewStatistic(
184
+ "elapsed_time", value=elapsed_time, digits=1, units="sec."
185
+ ),
186
+ "total_interviews_requested": lambda: InterviewStatistic(
187
+ "total_interviews_requested", value=len(interviews), units=""
188
+ ),
189
+ "completed_interviews": lambda: InterviewStatistic(
190
+ "completed_interviews", value=len(completed_tasks), units=""
191
+ ),
192
+ "percent_complete": lambda: InterviewStatistic(
193
+ "percent_complete",
194
+ value=(
195
+ len(completed_tasks) / len(interviews) * 100
196
+ if len(interviews) > 0
197
+ else 0
198
+ ),
199
+ digits=1,
200
+ units="%",
201
+ ),
202
+ "average_time_per_interview": lambda: InterviewStatistic(
203
+ "average_time_per_interview",
204
+ value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
205
+ digits=2,
206
+ units="sec.",
207
+ ),
208
+ "task_remaining": lambda: InterviewStatistic(
209
+ "task_remaining", value=len(interviews) - len(completed_tasks), units=""
210
+ ),
211
+ "estimated_time_remaining": lambda: InterviewStatistic(
212
+ "estimated_time_remaining",
213
+ value=(
214
+ (len(interviews) - len(completed_tasks))
215
+ * (elapsed_time / len(completed_tasks))
216
+ if len(completed_tasks) > 0
217
+ else 0
218
+ ),
219
+ digits=1,
220
+ units="sec.",
221
+ ),
222
+ "exceptions": lambda: InterviewStatistic(
223
+ "exceptions",
224
+ value=sum(len(i.exceptions) for i in interviews),
225
+ units="",
226
+ ),
227
+ "unfixed_exceptions": lambda: InterviewStatistic(
228
+ "unfixed_exceptions",
229
+ value=sum(i.exceptions.num_unfixed() for i in interviews),
230
+ units="",
231
+ ),
232
+ "throughput": lambda: InterviewStatistic(
233
+ "throughput",
234
+ value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
235
+ digits=2,
236
+ units="interviews/sec.",
237
+ ),
238
+ }
239
+ return stat_definitions[stat_name]()
219
240
 
220
241
  def update_progress(self, stop_event):
221
242
  while not stop_event.is_set():
222
243
  self.send_status_update()
223
244
  time.sleep(self.refresh_rate)
224
- self.send_status_update()
225
245
 
226
- @abstractmethod
227
- def setup(self):
228
- """Conducts any setup needed prior to sending status updates."""
229
- pass
230
-
231
- @abstractmethod
232
- def send_status_update(self):
233
- """Updates the current status of the job."""
234
- pass
246
+ self.send_status_update()
235
247
 
236
248
 
237
249
  class JobsRunnerStatus(JobsRunnerStatusBase):
@@ -248,35 +260,49 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
248
260
  return f"{self.base_url}/api/v0/local-job/{str(self.job_uuid)}"
249
261
 
250
262
  def setup(self) -> None:
251
- """Creates a local job on Coop if one does not already exist."""
252
- headers = {
253
- "Content-Type": "application/json",
254
- "Authorization": f"Bearer {self.api_key or 'None'}",
255
- }
263
+ """
264
+ Creates a local job on Coop if one does not already exist.
265
+ """
266
+
267
+ headers = {"Content-Type": "application/json"}
268
+
269
+ if self.api_key:
270
+ headers["Authorization"] = f"Bearer {self.api_key}"
271
+ else:
272
+ headers["Authorization"] = f"Bearer None"
256
273
 
257
274
  if self.job_uuid is None:
275
+ # Create a new local job
258
276
  response = requests.post(
259
277
  self.create_url,
260
278
  headers=headers,
261
279
  timeout=1,
262
280
  )
263
- response.raise_for_status()
264
- data = response.json()
265
- self.job_uuid = data.get("job_uuid")
281
+ response.raise_for_status()
282
+ data = response.json()
283
+ self.job_uuid = data.get("job_uuid")
266
284
 
267
285
  print(f"Running with progress bar. View progress at {self.viewing_url}")
268
286
 
269
287
  def send_status_update(self) -> None:
270
- """Sends current status to the web endpoint using the instance's job_uuid."""
288
+ """
289
+ Sends current status to the web endpoint using the instance's job_uuid.
290
+ """
271
291
  try:
292
+ # Get the status dictionary and add the job_id
272
293
  status_dict = self.get_status_dict()
294
+
295
+ # Make the UUID JSON serializable
273
296
  status_dict["job_id"] = str(self.job_uuid)
274
297
 
275
- headers = {
276
- "Content-Type": "application/json",
277
- "Authorization": f"Bearer {self.api_key or 'None'}",
278
- }
298
+ headers = {"Content-Type": "application/json"}
299
+
300
+ if self.api_key:
301
+ headers["Authorization"] = f"Bearer {self.api_key}"
302
+ else:
303
+ headers["Authorization"] = f"Bearer None"
279
304
 
305
+ # Send the update
280
306
  response = requests.patch(
281
307
  self.update_url,
282
308
  json=status_dict,
@@ -288,8 +314,14 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
288
314
  print(f"Failed to send status update for job {self.job_uuid}: {e}")
289
315
 
290
316
  def has_ep_api_key(self) -> bool:
291
- """Returns True if the user has an Expected Parrot API key."""
292
- return self.api_key is not None
317
+ """
318
+ Returns True if the user has an Expected Parrot API key. Otherwise, returns False.
319
+ """
320
+
321
+ if self.api_key is not None:
322
+ return True
323
+ else:
324
+ return False
293
325
 
294
326
 
295
327
  if __name__ == "__main__":
@@ -8,10 +8,9 @@ from edsl.Base import RepresentationMixin
8
8
  class TaskHistory(RepresentationMixin):
9
9
  def __init__(
10
10
  self,
11
- interviews: List["Interview"] = None,
11
+ interviews: List["Interview"],
12
12
  include_traceback: bool = False,
13
13
  max_interviews: int = 10,
14
- interviews_with_exceptions_only: bool = False,
15
14
  ):
16
15
  """
17
16
  The structure of a TaskHistory exception
@@ -21,33 +20,13 @@ class TaskHistory(RepresentationMixin):
21
20
  >>> _ = TaskHistory.example()
22
21
  ...
23
22
  """
24
- self.interviews_with_exceptions_only = interviews_with_exceptions_only
25
- self._interviews = {}
26
- self.total_interviews = []
27
- if interviews is not None:
28
- for interview in interviews:
29
- self.add_interview(interview)
30
23
 
31
- self.include_traceback = include_traceback
32
- self._interviews = {
33
- index: interview for index, interview in enumerate(self.total_interviews)
34
- }
35
- self.max_interviews = max_interviews
36
-
37
- # self.total_interviews = interviews
24
+ self.total_interviews = interviews
38
25
  self.include_traceback = include_traceback
39
26
 
40
- # self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
27
+ self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
41
28
  self.max_interviews = max_interviews
42
29
 
43
- def add_interview(self, interview: "Interview"):
44
- """Add a single interview to the history"""
45
- if self.interviews_with_exceptions_only and interview.exceptions == {}:
46
- return
47
-
48
- self.total_interviews.append(interview)
49
- self._interviews[len(self._interviews)] = interview
50
-
51
30
  @classmethod
52
31
  def example(cls):
53
32
  """ """
@@ -44,8 +44,6 @@ if TYPE_CHECKING:
44
44
  from edsl.questions.QuestionBase import QuestionBase
45
45
  from edsl.language_models.key_management.KeyLookup import KeyLookup
46
46
 
47
- from edsl.enums import InferenceServiceType
48
-
49
47
  from edsl.utilities.decorators import (
50
48
  sync_wrapper,
51
49
  jupyter_nb_handler,
@@ -157,9 +155,7 @@ class LanguageModel(
157
155
  return klc.get(("config", "env"))
158
156
 
159
157
  def set_key_lookup(self, key_lookup: "KeyLookup") -> None:
160
- """Set the key lookup, later"""
161
- if hasattr(self, "_api_token"):
162
- del self._api_token
158
+ del self._api_token
163
159
  self.key_lookup = key_lookup
164
160
 
165
161
  def ask_question(self, question: "QuestionBase") -> str:
@@ -497,10 +493,7 @@ class LanguageModel(
497
493
  >>> m.to_dict()
498
494
  {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
499
495
  """
500
- d = {
501
- "model": self.model,
502
- "parameters": self.parameters,
503
- }
496
+ d = {"model": self.model, "parameters": self.parameters}
504
497
  if add_edsl_version:
505
498
  from edsl import __version__
506
499
 
@@ -512,7 +505,7 @@ class LanguageModel(
512
505
  @remove_edsl_version
513
506
  def from_dict(cls, data: dict) -> Type[LanguageModel]:
514
507
  """Convert dictionary to a LanguageModel child instance."""
515
- from edsl.language_models.model import get_model_class
508
+ from edsl.language_models.registry import get_model_class
516
509
 
517
510
  model_class = get_model_class(data["model"])
518
511
  return model_class(**data)
@@ -560,7 +553,7 @@ class LanguageModel(
560
553
  Exception report saved to ...
561
554
  Also see: ...
562
555
  """
563
- from edsl.language_models.model import Model
556
+ from edsl.language_models.registry import Model
564
557
 
565
558
  if test_model:
566
559
  m = Model(
@@ -570,54 +563,6 @@ class LanguageModel(
570
563
  else:
571
564
  return Model(skip_api_key_check=True)
572
565
 
573
- def from_cache(self, cache: "Cache") -> LanguageModel:
574
-
575
- from copy import deepcopy
576
- from types import MethodType
577
- from edsl import Cache
578
-
579
- new_instance = deepcopy(self)
580
- print("Cache entries", len(cache))
581
- new_instance.cache = Cache(
582
- data={k: v for k, v in cache.items() if v.model == self.model}
583
- )
584
- print("Cache entries with same model", len(new_instance.cache))
585
-
586
- new_instance.user_prompts = [
587
- ce.user_prompt for ce in new_instance.cache.values()
588
- ]
589
- new_instance.system_prompts = [
590
- ce.system_prompt for ce in new_instance.cache.values()
591
- ]
592
-
593
- async def async_execute_model_call(self, user_prompt: str, system_prompt: str):
594
- cache_call_params = {
595
- "model": str(self.model),
596
- "parameters": self.parameters,
597
- "system_prompt": system_prompt,
598
- "user_prompt": user_prompt,
599
- "iteration": 1,
600
- }
601
- cached_response, cache_key = cache.fetch(**cache_call_params)
602
- response = json.loads(cached_response)
603
- cost = 0
604
- return ModelResponse(
605
- response=response,
606
- cache_used=True,
607
- cache_key=cache_key,
608
- cached_response=cached_response,
609
- cost=cost,
610
- )
611
-
612
- # Bind the new method to the copied instance
613
- setattr(
614
- new_instance,
615
- "async_execute_model_call",
616
- MethodType(async_execute_model_call, new_instance),
617
- )
618
-
619
- return new_instance
620
-
621
566
 
622
567
  if __name__ == "__main__":
623
568
  """Run the module's test suite."""
@@ -1,22 +1,18 @@
1
- from typing import Optional, List, TYPE_CHECKING
1
+ from typing import Optional, List
2
2
  from collections import UserList
3
3
 
4
4
  from edsl.Base import Base
5
- from edsl.language_models.model import Model
5
+ from edsl.language_models.registry import Model
6
6
 
7
- #
7
+ # from edsl.language_models import LanguageModel
8
8
  from edsl.utilities.remove_edsl_version import remove_edsl_version
9
9
  from edsl.utilities.is_valid_variable_name import is_valid_variable_name
10
10
 
11
- if TYPE_CHECKING:
12
- from edsl.inference_services.data_structures import AvailableModels
13
- from edsl.language_models import LanguageModel
14
-
15
11
 
16
12
  class ModelList(Base, UserList):
17
13
  __documentation__ = """https://docs.expectedparrot.com/en/latest/language_models.html#module-edsl.language_models.ModelList"""
18
14
 
19
- def __init__(self, data: Optional["LanguageModel"] = None):
15
+ def __init__(self, data: Optional[list] = None):
20
16
  """Initialize the ScenarioList class.
21
17
 
22
18
  >>> from edsl import Model
@@ -37,6 +33,9 @@ class ModelList(Base, UserList):
37
33
  """
38
34
  return set([model.model for model in self])
39
35
 
36
+ def rich_print(self):
37
+ pass
38
+
40
39
  def __repr__(self):
41
40
  return f"ModelList({super().__repr__()})"
42
41
 
@@ -88,7 +87,7 @@ class ModelList(Base, UserList):
88
87
  .table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
89
88
  )
90
89
 
91
- def to_list(self) -> list:
90
+ def to_list(self):
92
91
  return self.to_scenario_list().to_list()
93
92
 
94
93
  def to_dict(self, sort=False, add_edsl_version=True):
@@ -121,16 +120,6 @@ class ModelList(Base, UserList):
121
120
  args = args[0]
122
121
  return ModelList([Model(model_name, **kwargs) for model_name in args])
123
122
 
124
- @classmethod
125
- def from_available_models(self, available_models_list: "AvailableModels"):
126
- """Create a ModelList from an AvailableModels object"""
127
- return ModelList(
128
- [
129
- Model(model.model_name, service_name=model.service_name)
130
- for model in available_models_list
131
- ]
132
- )
133
-
134
123
  @classmethod
135
124
  @remove_edsl_version
136
125
  def from_dict(cls, data):
@@ -1,2 +1,2 @@
1
1
  from edsl.language_models.LanguageModel import LanguageModel
2
- from edsl.language_models.model import Model
2
+ from edsl.language_models.registry import Model