edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. edsl/Base.py +24 -14
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +28 -6
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
  8. edsl/agents/prompt_helpers.py +129 -0
  9. edsl/config.py +26 -34
  10. edsl/coop/coop.py +14 -4
  11. edsl/data_transfer_models.py +26 -73
  12. edsl/enums.py +2 -0
  13. edsl/inference_services/AnthropicService.py +5 -2
  14. edsl/inference_services/AwsBedrock.py +5 -2
  15. edsl/inference_services/AzureAI.py +5 -2
  16. edsl/inference_services/GoogleService.py +108 -33
  17. edsl/inference_services/InferenceServiceABC.py +44 -13
  18. edsl/inference_services/MistralAIService.py +5 -2
  19. edsl/inference_services/OpenAIService.py +10 -6
  20. edsl/inference_services/TestService.py +34 -16
  21. edsl/inference_services/TogetherAIService.py +170 -0
  22. edsl/inference_services/registry.py +2 -0
  23. edsl/jobs/Jobs.py +109 -18
  24. edsl/jobs/buckets/BucketCollection.py +24 -15
  25. edsl/jobs/buckets/TokenBucket.py +64 -10
  26. edsl/jobs/interviews/Interview.py +130 -49
  27. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  28. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  29. edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
  30. edsl/jobs/runners/JobsRunnerStatus.py +332 -0
  31. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  32. edsl/jobs/tasks/TaskHistory.py +17 -0
  33. edsl/language_models/LanguageModel.py +36 -38
  34. edsl/language_models/registry.py +13 -9
  35. edsl/language_models/utilities.py +5 -2
  36. edsl/questions/QuestionBase.py +74 -16
  37. edsl/questions/QuestionBaseGenMixin.py +28 -0
  38. edsl/questions/QuestionBudget.py +93 -41
  39. edsl/questions/QuestionCheckBox.py +1 -1
  40. edsl/questions/QuestionFreeText.py +6 -0
  41. edsl/questions/QuestionMultipleChoice.py +13 -24
  42. edsl/questions/QuestionNumerical.py +5 -4
  43. edsl/questions/Quick.py +41 -0
  44. edsl/questions/ResponseValidatorABC.py +11 -6
  45. edsl/questions/derived/QuestionLinearScale.py +4 -1
  46. edsl/questions/derived/QuestionTopK.py +4 -1
  47. edsl/questions/derived/QuestionYesNo.py +8 -2
  48. edsl/questions/descriptors.py +12 -11
  49. edsl/questions/templates/budget/__init__.py +0 -0
  50. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  51. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  52. edsl/questions/templates/extract/__init__.py +0 -0
  53. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  54. edsl/questions/templates/rank/__init__.py +0 -0
  55. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  56. edsl/results/DatasetExportMixin.py +5 -1
  57. edsl/results/Result.py +1 -1
  58. edsl/results/Results.py +4 -1
  59. edsl/scenarios/FileStore.py +178 -34
  60. edsl/scenarios/Scenario.py +76 -37
  61. edsl/scenarios/ScenarioList.py +19 -2
  62. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  63. edsl/study/Study.py +32 -0
  64. edsl/surveys/DAG.py +62 -0
  65. edsl/surveys/MemoryPlan.py +26 -0
  66. edsl/surveys/Rule.py +34 -1
  67. edsl/surveys/RuleCollection.py +55 -5
  68. edsl/surveys/Survey.py +189 -10
  69. edsl/surveys/base.py +4 -0
  70. edsl/templates/error_reporting/interview_details.html +6 -1
  71. edsl/utilities/utilities.py +9 -1
  72. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
  73. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
  74. edsl/jobs/interviews/retry_management.py +0 -39
  75. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  76. edsl/scenarios/ScenarioImageMixin.py +0 -100
  77. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  78. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerStatus.py (new file)
@@ -0,0 +1,332 @@
+ from __future__ import annotations
+
+ import time
+ from dataclasses import dataclass, asdict
+
+ from typing import List, DefaultDict, Optional, Type, Literal
+ from collections import UserDict, defaultdict
+
+ from rich.text import Text
+ from rich.box import SIMPLE
+ from rich.table import Table
+ from rich.live import Live
+ from rich.panel import Panel
+ from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn
+ from rich.layout import Layout
+ from rich.console import Group
+ from rich import box
+
+ from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+ from edsl.jobs.tokens.TokenUsage import TokenUsage
+ from edsl.enums import get_token_pricing
+ from edsl.jobs.tasks.task_status_enum import TaskStatus
+
+ InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
+
+ from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
+ from edsl.jobs.interviews.InterviewStatisticsCollection import (
+     InterviewStatisticsCollection,
+ )
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+
+
+ @dataclass
+ class ModelInfo:
+     model_name: str
+     TPM_limit_k: float
+     RPM_limit_k: float
+     num_tasks_waiting: int
+     token_usage_info: dict
+
+
+ @dataclass
+ class ModelTokenUsageStats:
+     token_usage_type: str
+     details: List[dict]
+     cost: str
+
+
+ class Stats:
+     def elapsed_time(self):
+         InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
+
+
+ class JobsRunnerStatus:
+     def __init__(
+         self, jobs_runner: "JobsRunnerAsyncio", n: int, refresh_rate: float = 0.25
+     ):
+         self.jobs_runner = jobs_runner
+         self.start_time = time.time()
+         self.completed_interviews = []
+         self.refresh_rate = refresh_rate
+         self.statistics = [
+             "elapsed_time",
+             "total_interviews_requested",
+             "completed_interviews",
+             # "percent_complete",
+             "average_time_per_interview",
+             # "task_remaining",
+             "estimated_time_remaining",
+             "exceptions",
+             "unfixed_exceptions",
+             "throughput",
+         ]
+         self.num_total_interviews = n * len(self.jobs_runner.interviews)
+
+         self.distinct_models = list(
+             set(i.model.model for i in self.jobs_runner.interviews)
+         )
+
+         self.completed_interview_by_model = defaultdict(list)
+
+     def add_completed_interview(self, result):
+         self.completed_interviews.append(result.interview_hash)
+
+         relevant_model = result.model.model
+         self.completed_interview_by_model[relevant_model].append(result.interview_hash)
+
+     def _compute_statistic(self, stat_name: str):
+         completed_tasks = self.completed_interviews
+         elapsed_time = time.time() - self.start_time
+         interviews = self.jobs_runner.total_interviews
+
+         stat_definitions = {
+             "elapsed_time": lambda: InterviewStatistic(
+                 "elapsed_time", value=elapsed_time, digits=1, units="sec."
+             ),
+             "total_interviews_requested": lambda: InterviewStatistic(
+                 "total_interviews_requested", value=len(interviews), units=""
+             ),
+             "completed_interviews": lambda: InterviewStatistic(
+                 "completed_interviews", value=len(completed_tasks), units=""
+             ),
+             "percent_complete": lambda: InterviewStatistic(
+                 "percent_complete",
+                 value=(
+                     len(completed_tasks) / len(interviews) * 100
+                     if len(interviews) > 0
+                     else 0
+                 ),
+                 digits=1,
+                 units="%",
+             ),
+             "average_time_per_interview": lambda: InterviewStatistic(
+                 "average_time_per_interview",
+                 value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
+                 digits=2,
+                 units="sec.",
+             ),
+             "task_remaining": lambda: InterviewStatistic(
+                 "task_remaining", value=len(interviews) - len(completed_tasks), units=""
+             ),
+             "estimated_time_remaining": lambda: InterviewStatistic(
+                 "estimated_time_remaining",
+                 value=(
+                     (len(interviews) - len(completed_tasks))
+                     * (elapsed_time / len(completed_tasks))
+                     if len(completed_tasks) > 0
+                     else 0
+                 ),
+                 digits=1,
+                 units="sec.",
+             ),
+             "exceptions": lambda: InterviewStatistic(
+                 "exceptions",
+                 value=sum(len(i.exceptions) for i in interviews),
+                 units="",
+             ),
+             "unfixed_exceptions": lambda: InterviewStatistic(
+                 "unfixed_exceptions",
+                 value=sum(i.exceptions.num_unfixed() for i in interviews),
+                 units="",
+             ),
+             "throughput": lambda: InterviewStatistic(
+                 "throughput",
+                 value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
+                 digits=2,
+                 units="interviews/sec.",
+             ),
+         }
+         return stat_definitions[stat_name]()
+
+     def create_progress_bar(self):
+         return Progress(
+             TextColumn("[progress.description]{task.description}"),
+             BarColumn(),
+             TaskProgressColumn(),
+             TextColumn("{task.completed}/{task.total}"),
+         )
+
+     def generate_model_queues_table(self):
+         table = Table(show_header=False, box=box.SIMPLE)
+         table.add_column("Info", style="cyan")
+         table.add_column("Value", style="magenta")
+         # table.add_row("Bucket collection", str(self.jobs_runner.bucket_collection))
+         for model, bucket in self.jobs_runner.bucket_collection.items():
+             table.add_row(Text(model.model, style="bold blue"), "")
+             bucket_types = ["requests_bucket", "tokens_bucket"]
+             for bucket_type in bucket_types:
+                 table.add_row(Text(" " + bucket_type, style="green"), "")
+                 # table.add_row(
+                 #     f" Current level (capacity = {round(getattr(bucket, bucket_type).capacity, 3)})",
+                 #     str(round(getattr(bucket, bucket_type).tokens, 3)),
+                 # )
+                 num_requests = getattr(bucket, bucket_type).num_requests
+                 num_released = getattr(bucket, bucket_type).num_released
+                 tokens_returned = getattr(bucket, bucket_type).tokens_returned
+                 # table.add_row(
+                 #     f" Requested",
+                 #     str(num_requests),
+                 # )
+                 # table.add_row(
+                 #     f" Completed",
+                 #     str(num_released),
+                 # )
+                 table.add_row(
+                     " Completed vs. Requested", f"{num_released} vs. {num_requests}"
+                 )
+                 table.add_row(
+                     " Added tokens (from cache)",
+                     str(tokens_returned),
+                 )
+                 if bucket_type == "tokens_bucket":
+                     rate_name = "TPM"
+                 else:
+                     rate_name = "RPM"
+                 target_rate = round(getattr(bucket, bucket_type).target_rate, 1)
+                 table.add_row(
+                     f" Empirical {rate_name} (target = {target_rate})",
+                     str(round(getattr(bucket, bucket_type).get_throughput(), 0)),
+                 )
+
+         return table
+
+     def generate_layout(self):
+         progress = self.create_progress_bar()
+         task_ids = []
+         for model in self.distinct_models:
+             task_id = progress.add_task(
+                 f"[cyan]{model}...",
+                 total=int(self.num_total_interviews / len(self.distinct_models)),
+             )
+             task_ids.append((model, task_id))
+
+         progress_height = min(5, 2 + len(self.distinct_models))
+         layout = Layout()
+
+         # Create the top row with only the progress panel
+         layout.split_column(
+             Layout(
+                 Panel(
+                     progress,
+                     title="Interview Progress",
+                     border_style="cyan",
+                     box=box.ROUNDED,
+                 ),
+                 name="progress",
+                 size=progress_height,  # Adjusted size
+             ),
+             Layout(name="bottom_row"),  # Adjusted size
+         )
+
+         # Split the bottom row into two columns for metrics and model queues
+         layout["bottom_row"].split_row(
+             Layout(
+                 Panel(
+                     self.generate_metrics_table(),
+                     title="Metrics",
+                     border_style="magenta",
+                     box=box.ROUNDED,
+                 ),
+                 name="metrics",
+             ),
+             Layout(
+                 Panel(
+                     self.generate_model_queues_table(),
+                     title="Model Queues",
+                     border_style="yellow",
+                     box=box.ROUNDED,
+                 ),
+                 name="model_queues",
+             ),
+         )
+
+         return layout, progress, task_ids
+
+     def generate_metrics_table(self):
+         table = Table(show_header=True, header_style="bold magenta", box=box.SIMPLE)
+         table.add_column("Metric", style="cyan", no_wrap=True)
+         table.add_column("Value", justify="right")
+
+         for stat_name in self.statistics:
+             pretty_name, value = list(self._compute_statistic(stat_name).items())[0]
+             # breakpoint()
+             table.add_row(pretty_name, value)
+         return table
+
+     def update_progress(self, stop_event):
+         layout, progress, task_ids = self.generate_layout()
+
+         with Live(
+             layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
+         ) as live:
+             while (
+                 len(self.completed_interviews) < len(self.jobs_runner.total_interviews)
+                 and not stop_event.is_set()
+             ):
+                 completed_tasks = len(self.completed_interviews)
+                 total_tasks = len(self.jobs_runner.total_interviews)
+
+                 for model, task_id in task_ids:
+                     completed_tasks = len(self.completed_interview_by_model[model])
+                     progress.update(
+                         task_id,
+                         completed=completed_tasks,
+                         description=f"[cyan]Conducting interviews for {model}...",
+                     )
+
+                 layout["metrics"].update(
+                     Panel(
+                         self.generate_metrics_table(),
+                         title="Metrics",
+                         border_style="magenta",
+                         box=box.ROUNDED,
+                     )
+                 )
+                 layout["model_queues"].update(
+                     Panel(
+                         self.generate_model_queues_table(),
+                         title="Final Model Queues",
+                         border_style="yellow",
+                         box=box.ROUNDED,
+                     )
+                 )
+
+                 time.sleep(self.refresh_rate)
+
+             # Final update
+             for model, task_id in task_ids:
+                 completed_tasks = len(self.completed_interview_by_model[model])
+                 progress.update(
+                     task_id,
+                     completed=completed_tasks,
+                     description=f"[cyan]Conducting interviews for {model}...",
+                 )
+
+             layout["metrics"].update(
+                 Panel(
+                     self.generate_metrics_table(),
+                     title="Final Metrics",
+                     border_style="magenta",
+                     box=box.ROUNDED,
+                 )
+             )
+             live.update(layout)
+             time.sleep(1)  # Show final state for 1 second
+
+
+ if __name__ == "__main__":
+     import doctest
+
+     doctest.testmod(optionflags=doctest.ELLIPSIS)
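
For orientation, this is roughly how the new class is driven: the runner owns a JobsRunnerStatus, renders it on a background thread, and feeds it results as they arrive. A minimal sketch, assuming edsl is installed; the stub driver function and the threading details are illustrative assumptions, and only JobsRunnerStatus(...), add_completed_interview(...), and update_progress(stop_event) come from the file above.

# Hypothetical driver -- a sketch of the usage pattern, not edsl's actual runner code.
import threading

from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus


def run_with_progress(jobs_runner, results, n=1):
    """Render the rich dashboard in a background thread while interviews complete."""
    status = JobsRunnerStatus(jobs_runner, n=n)
    stop_event = threading.Event()
    # update_progress() loops until every interview is done or stop_event is set
    renderer = threading.Thread(
        target=status.update_progress, args=(stop_event,), daemon=True
    )
    renderer.start()
    try:
        for result in results:  # each result carries .interview_hash and .model
            status.add_completed_interview(result)
    finally:
        stop_event.set()
        renderer.join()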
edsl/jobs/tasks/QuestionTaskCreator.py
@@ -156,19 +156,6 @@ class QuestionTaskCreator(UserList):
          self.tokens_bucket.turbo_mode_off()
          self.requests_bucket.turbo_mode_off()

-         # breakpoint()
-         # _ = results.pop("cached_response", None)
-
-         # tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
-
-         # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
-         # usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
-         # prompt_tokens = usage.get("prompt_tokens", 0)
-         # completion_tokens = usage.get("completion_tokens", 0)
-         # tracker.add_tokens(
-         #     prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
-         # )
-
          return results

      @classmethod
@@ -249,6 +236,7 @@ class QuestionTaskCreator(UserList):
                  f"Required tasks failed for {self.question.question_name}"
              ) from e

+         # this only runs if all the dependencies are successful
          return await self._run_focal_task()

edsl/jobs/tasks/TaskHistory.py
@@ -50,6 +50,18 @@ class TaskHistory:
          """
          return [i.exceptions for k, i in self._interviews.items() if i.exceptions != {}]

+     @property
+     def unfixed_exceptions(self):
+         """
+         >>> len(TaskHistory.example().unfixed_exceptions)
+         4
+         """
+         return [
+             i.exceptions
+             for k, i in self._interviews.items()
+             if i.exceptions.num_unfixed() > 0
+         ]
+
      @property
      def indices(self):
          return [k for k, i in self._interviews.items() if i.exceptions != {}]
@@ -78,6 +90,11 @@ class TaskHistory:
          """
          return len(self.exceptions) > 0

+     @property
+     def has_unfixed_exceptions(self) -> bool:
+         """Return True if there are any unfixed exceptions."""
+         return len(self.unfixed_exceptions) > 0
+
      def _repr_html_(self):
          """Return an HTML representation of the TaskHistory."""
          from edsl.utilities.utilities import data_to_html
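
Together, the two additions give callers a cheap gate for error reporting. A minimal sketch: the TaskHistory.example() doctest fixture comes from the diff, while the reporting branch is an illustrative assumption.

from edsl.jobs.tasks.TaskHistory import TaskHistory

history = TaskHistory.example()
if history.has_unfixed_exceptions:
    # unfixed_exceptions only includes interviews where num_unfixed() > 0
    print(f"{len(history.unfixed_exceptions)} interview(s) still have unfixed exceptions")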
edsl/language_models/LanguageModel.py
@@ -164,20 +164,20 @@ class LanguageModel(
          None  # This should be something like ["choices", 0, "message", "content"]
      )
      __rate_limits = None
-     __default_rate_limits = {
-         "rpm": 10_000,
-         "tpm": 2_000_000,
-     }  # TODO: Use the OpenAI Teir 1 rate limits
      _safety_factor = 0.8

-     def __init__(self, tpm=None, rpm=None, **kwargs):
+     def __init__(
+         self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
+     ):
          """Initialize the LanguageModel."""
          self.model = getattr(self, "_model_", None)
          default_parameters = getattr(self, "_parameters_", None)
          parameters = self._overide_default_parameters(kwargs, default_parameters)
          self.parameters = parameters
          self.remote = False
+         self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string

+         # self._rpm / _tpm comes from the class
          if rpm is not None:
              self._rpm = rpm

@@ -286,35 +286,40 @@ class LanguageModel(
          >>> m.RPM
          100
          """
-         self._set_rate_limits(rpm=rpm, tpm=tpm)
+         if rpm is not None:
+             self._rpm = rpm
+         if tpm is not None:
+             self._tpm = tpm
+         return None
+         # self._set_rate_limits(rpm=rpm, tpm=tpm)

-     def _set_rate_limits(self, rpm=None, tpm=None) -> None:
-         """Set the rate limits for the model.
+     # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+     #     """Set the rate limits for the model.

-         If the model does not have rate limits, use the default rate limits."""
-         if rpm is not None and tpm is not None:
-             self.__rate_limits = {"rpm": rpm, "tpm": tpm}
-             return
+     #     If the model does not have rate limits, use the default rate limits."""
+     #     if rpm is not None and tpm is not None:
+     #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+     #         return

-         if self.__rate_limits is None:
-             if hasattr(self, "get_rate_limits"):
-                 self.__rate_limits = self.get_rate_limits()
-             else:
-                 self.__rate_limits = self.__default_rate_limits
+     #     if self.__rate_limits is None:
+     #         if hasattr(self, "get_rate_limits"):
+     #             self.__rate_limits = self.get_rate_limits()
+     #         else:
+     #             self.__rate_limits = self.__default_rate_limits

      @property
      def RPM(self):
          """Model's requests-per-minute limit."""
          # self._set_rate_limits()
          # return self._safety_factor * self.__rate_limits["rpm"]
-         return self.rpm
+         return self._rpm

      @property
      def TPM(self):
          """Model's tokens-per-minute limit."""
          # self._set_rate_limits()
          # return self._safety_factor * self.__rate_limits["tpm"]
-         return self.tpm
+         return self._tpm

      @property
      def rpm(self):
@@ -332,17 +337,6 @@ class LanguageModel(
      def tpm(self, value):
          self._tpm = value

-     @property
-     def TPM(self):
-         """Model's tokens-per-minute limit.
-
-         >>> m = LanguageModel.example()
-         >>> m.TPM > 0
-         True
-         """
-         self._set_rate_limits()
-         return self._safety_factor * self.__rate_limits["tpm"]
-
      @staticmethod
      def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
          """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
@@ -446,7 +440,7 @@ class LanguageModel(
          system_prompt: str,
          cache: "Cache",
          iteration: int = 0,
-         encoded_image=None,
+         files_list=None,
      ) -> ModelResponse:
          """Handle caching of responses.

@@ -468,15 +462,18 @@ class LanguageModel(
          >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
          ModelResponse(...)"""

-         if encoded_image:
-             # the image has is appended to the user_prompt for hash-lookup purposes
-             image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
+         if files_list:
+             files_hash = "+".join([str(hash(file)) for file in files_list])
+             # print(f"Files hash: {files_hash}")
+             user_prompt_with_hashes = user_prompt + f" {files_hash}"
+         else:
+             user_prompt_with_hashes = user_prompt

          cache_call_params = {
              "model": str(self.model),
              "parameters": self.parameters,
              "system_prompt": system_prompt,
-             "user_prompt": user_prompt + "" if not encoded_image else f" {image_hash}",
+             "user_prompt": user_prompt_with_hashes,
              "iteration": iteration,
          }
          cached_response, cache_key = cache.fetch(**cache_call_params)
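
The effect is that the cache key now varies with the attached files: each file's hash is joined with "+" and appended to the user prompt before cache.fetch is called, so the same prompt with different attachments misses the cache. A small illustration of the construction; the stand-in file class is hypothetical, and real edsl file objects are assumed to define a stable, content-based __hash__.

class FakeFile:
    """Hypothetical stand-in for an edsl file object."""

    def __init__(self, content_digest: str):
        self.content_digest = content_digest

    def __hash__(self):
        # assumption: the hash derives from content, so identical files hit the cache
        return int(self.content_digest, 16)


files_list = [FakeFile("abc123"), FakeFile("de9f01")]
files_hash = "+".join([str(hash(file)) for file in files_list])  # as in the diff
user_prompt_with_hashes = "Describe the attachments." + f" {files_hash}"
# user_prompt_with_hashes, not the raw prompt, becomes the cache-key input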
@@ -492,7 +489,8 @@ class LanguageModel(
          params = {
              "user_prompt": user_prompt,
              "system_prompt": system_prompt,
-             **({"encoded_image": encoded_image} if encoded_image else {}),
+             "files_list": files_list
+             # **({"encoded_image": encoded_image} if encoded_image else {}),
          }
          # response = await f(**params)
          response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
@@ -536,7 +534,7 @@ class LanguageModel(
          system_prompt: str,
          cache: "Cache",
          iteration: int = 1,
-         encoded_image=None,
+         files_list: Optional[List["File"]] = None,
      ) -> dict:
          """Get response, parse, and return as string.

@@ -552,7 +550,7 @@ class LanguageModel(
              "system_prompt": system_prompt,
              "iteration": iteration,
              "cache": cache,
-             **({"encoded_image": encoded_image} if encoded_image else {}),
+             "files_list": files_list,
          }
          model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
          model_outputs = await self._async_get_intended_model_call_outcome(**params)
edsl/language_models/registry.py
@@ -2,10 +2,10 @@ import textwrap
  from random import random
  from edsl.config import CONFIG

- if "EDSL_DEFAULT_MODEL" not in CONFIG:
-     default_model = "test"
- else:
-     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
+ # if "EDSL_DEFAULT_MODEL" not in CONFIG:
+ #     default_model = "test"
+ # else:
+ #     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")


  def get_model_class(model_name, registry=None):
@@ -33,20 +33,24 @@ class Meta(type):


  class Model(metaclass=Meta):
-     default_model = default_model
+     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")

-     def __new__(cls, model_name=None, registry=None, *args, **kwargs):
+     def __new__(
+         cls, model_name=None, registry=None, service_name=None, *args, **kwargs
+     ):
          # Map index to the respective subclass
          if model_name is None:
-             model_name = cls.default_model
+             model_name = (
+                 cls.default_model
+             )  # when model_name is None, use the default model, set in the config file
          from edsl.inference_services.registry import default

          registry = registry or default

-         if isinstance(model_name, int):
+         if isinstance(model_name, int):  # can refer to a model by index
              model_name = cls.available(name_only=True)[model_name]

-         factory = registry.create_model_factory(model_name)
+         factory = registry.create_model_factory(model_name, service_name=service_name)
          return factory(*args, **kwargs)

      @classmethod
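
With the new service_name argument threaded through to create_model_factory, a caller can pin a model name to a specific inference service when the same name exists on several (e.g. the TogetherAIService added in this release). A hedged usage sketch; the model and service names are illustrative examples, not a statement of what the registry contains.

from edsl.language_models.registry import Model

m_default = Model()            # None falls back to CONFIG.get("EDSL_DEFAULT_MODEL")
m_by_index = Model(0)          # an int indexes into Model.available(name_only=True)
m_pinned = Model("llama-3-8b", service_name="together")  # names are examples only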
edsl/language_models/utilities.py
@@ -1,5 +1,5 @@
  import asyncio
- from typing import Any
+ from typing import Any, Optional, List
  from edsl import Survey
  from edsl.config import CONFIG
  from edsl.enums import InferenceServiceType
@@ -40,7 +40,10 @@ def create_language_model(
      _tpm = 1000000000000

      async def async_execute_model_call(
-         self, user_prompt: str, system_prompt: str
+         self,
+         user_prompt: str,
+         system_prompt: str,
+         files_list: Optional[List[Any]] = None,
      ) -> dict[str, Any]:
          question_number = int(
              user_prompt.split("XX")[1]
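
The same signature change recurs across every service backend in this release (the +5 -2 edits in AnthropicService, AwsBedrock, AzureAI, MistralAIService, and the rest of the file list): async_execute_model_call now accepts a files_list. A sketch of what a conforming subclass looks like; the class body is illustrative and omits the other hooks a real subclass defines.

from typing import Any, List, Optional

from edsl.language_models.LanguageModel import LanguageModel


class MyModel(LanguageModel):
    """Illustrative subclass; real ones also set _model_, _parameters_, etc."""

    async def async_execute_model_call(
        self,
        user_prompt: str,
        system_prompt: str,
        files_list: Optional[List[Any]] = None,  # new parameter in 0.1.34
    ) -> dict[str, Any]:
        # a real backend would forward the prompts (and any files) to its API
        return {"choices": [{"message": {"content": "..."}}]}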