edsl 0.1.29__py3-none-any.whl → 0.1.29.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. edsl/Base.py +18 -18
  2. edsl/__init__.py +23 -23
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +41 -77
  5. edsl/agents/AgentList.py +9 -19
  6. edsl/agents/Invigilator.py +1 -19
  7. edsl/agents/InvigilatorBase.py +10 -15
  8. edsl/agents/PromptConstructionMixin.py +100 -342
  9. edsl/agents/descriptors.py +1 -2
  10. edsl/config.py +1 -2
  11. edsl/conjure/InputData.py +8 -39
  12. edsl/coop/coop.py +150 -187
  13. edsl/coop/utils.py +75 -43
  14. edsl/data/Cache.py +5 -19
  15. edsl/data/SQLiteDict.py +3 -11
  16. edsl/jobs/Answers.py +1 -15
  17. edsl/jobs/Jobs.py +46 -90
  18. edsl/jobs/buckets/ModelBuckets.py +2 -4
  19. edsl/jobs/buckets/TokenBucket.py +2 -1
  20. edsl/jobs/interviews/Interview.py +9 -3
  21. edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
  22. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +10 -15
  23. edsl/jobs/runners/JobsRunnerAsyncio.py +25 -21
  24. edsl/jobs/tasks/TaskHistory.py +3 -4
  25. edsl/language_models/LanguageModel.py +11 -5
  26. edsl/language_models/ModelList.py +1 -1
  27. edsl/language_models/repair.py +7 -8
  28. edsl/notebooks/Notebook.py +3 -40
  29. edsl/prompts/Prompt.py +19 -31
  30. edsl/questions/QuestionBase.py +13 -38
  31. edsl/questions/QuestionBudget.py +6 -5
  32. edsl/questions/QuestionCheckBox.py +3 -7
  33. edsl/questions/QuestionExtract.py +3 -5
  34. edsl/questions/QuestionFreeText.py +3 -3
  35. edsl/questions/QuestionFunctional.py +3 -0
  36. edsl/questions/QuestionList.py +4 -3
  37. edsl/questions/QuestionMultipleChoice.py +8 -16
  38. edsl/questions/QuestionNumerical.py +3 -4
  39. edsl/questions/QuestionRank.py +3 -5
  40. edsl/questions/__init__.py +3 -4
  41. edsl/questions/descriptors.py +2 -4
  42. edsl/questions/question_registry.py +31 -20
  43. edsl/questions/settings.py +1 -1
  44. edsl/results/Dataset.py +0 -31
  45. edsl/results/Result.py +74 -22
  46. edsl/results/Results.py +47 -97
  47. edsl/results/ResultsDBMixin.py +3 -7
  48. edsl/results/ResultsExportMixin.py +537 -22
  49. edsl/results/ResultsGGMixin.py +3 -3
  50. edsl/results/ResultsToolsMixin.py +5 -5
  51. edsl/scenarios/Scenario.py +6 -5
  52. edsl/scenarios/ScenarioList.py +11 -34
  53. edsl/scenarios/ScenarioListPdfMixin.py +1 -2
  54. edsl/scenarios/__init__.py +0 -1
  55. edsl/study/Study.py +9 -3
  56. edsl/surveys/MemoryPlan.py +4 -11
  57. edsl/surveys/Survey.py +7 -46
  58. edsl/surveys/SurveyExportMixin.py +2 -4
  59. edsl/surveys/SurveyFlowVisualizationMixin.py +4 -6
  60. edsl/tools/plotting.py +2 -4
  61. edsl/utilities/__init__.py +21 -21
  62. edsl/utilities/interface.py +45 -66
  63. edsl/utilities/utilities.py +13 -11
  64. {edsl-0.1.29.dist-info → edsl-0.1.29.dev2.dist-info}/METADATA +10 -11
  65. {edsl-0.1.29.dist-info → edsl-0.1.29.dev2.dist-info}/RECORD +68 -71
  66. edsl-0.1.29.dev2.dist-info/entry_points.txt +3 -0
  67. edsl/base/Base.py +0 -289
  68. edsl/results/DatasetExportMixin.py +0 -493
  69. edsl/scenarios/FileStore.py +0 -140
  70. edsl/scenarios/ScenarioListExportMixin.py +0 -32
  71. {edsl-0.1.29.dist-info → edsl-0.1.29.dev2.dist-info}/LICENSE +0 -0
  72. {edsl-0.1.29.dist-info → edsl-0.1.29.dev2.dist-info}/WHEEL +0 -0
edsl/data/Cache.py CHANGED
@@ -7,13 +7,13 @@ import json
7
7
  import os
8
8
  import warnings
9
9
  from typing import Optional, Union
10
- import time
10
+
11
11
  from edsl.config import CONFIG
12
12
  from edsl.data.CacheEntry import CacheEntry
13
-
14
- # from edsl.data.SQLiteDict import SQLiteDict
13
+ from edsl.data.SQLiteDict import SQLiteDict
15
14
  from edsl.Base import Base
16
15
  from edsl.utilities.utilities import dict_hash
16
+
17
17
  from edsl.utilities.decorators import (
18
18
  add_edsl_version,
19
19
  remove_edsl_version,
@@ -38,7 +38,7 @@ class Cache(Base):
38
38
  self,
39
39
  *,
40
40
  filename: Optional[str] = None,
41
- data: Optional[Union["SQLiteDict", dict]] = None,
41
+ data: Optional[Union[SQLiteDict, dict]] = None,
42
42
  immediate_write: bool = True,
43
43
  method=None,
44
44
  ):
@@ -104,8 +104,6 @@ class Cache(Base):
104
104
 
105
105
  def _perform_checks(self):
106
106
  """Perform checks on the cache."""
107
- from edsl.data.CacheEntry import CacheEntry
108
-
109
107
  if any(not isinstance(value, CacheEntry) for value in self.data.values()):
110
108
  raise Exception("Not all values are CacheEntry instances")
111
109
  if self.method is not None:
@@ -140,8 +138,6 @@ class Cache(Base):
140
138
 
141
139
 
142
140
  """
143
- from edsl.data.CacheEntry import CacheEntry
144
-
145
141
  key = CacheEntry.gen_key(
146
142
  model=model,
147
143
  parameters=parameters,
@@ -175,7 +171,6 @@ class Cache(Base):
175
171
  * If `immediate_write` is True , the key-value pair is added to `self.data`
176
172
  * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
177
173
  """
178
-
179
174
  entry = CacheEntry(
180
175
  model=model,
181
176
  parameters=parameters,
@@ -193,14 +188,13 @@ class Cache(Base):
193
188
  return key
194
189
 
195
190
  def add_from_dict(
196
- self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
191
+ self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
197
192
  ) -> None:
198
193
  """
199
194
  Add entries to the cache from a dictionary.
200
195
 
201
196
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
202
197
  """
203
-
204
198
  for key, value in new_data.items():
205
199
  if key in self.data:
206
200
  if value != self.data[key]:
@@ -237,8 +231,6 @@ class Cache(Base):
237
231
 
238
232
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
239
233
  """
240
- from edsl.data.SQLiteDict import SQLiteDict
241
-
242
234
  db = SQLiteDict(db_path)
243
235
  new_data = {}
244
236
  for key, value in db.items():
@@ -250,8 +242,6 @@ class Cache(Base):
250
242
  """
251
243
  Construct a Cache from a SQLite database.
252
244
  """
253
- from edsl.data.SQLiteDict import SQLiteDict
254
-
255
245
  return cls(data=SQLiteDict(db_path))
256
246
 
257
247
  @classmethod
@@ -278,8 +268,6 @@ class Cache(Base):
278
268
  * If `db_path` is provided, the cache will be stored in an SQLite database.
279
269
  """
280
270
  # if a file doesn't exist at jsonfile, throw an error
281
- from edsl.data.SQLiteDict import SQLiteDict
282
-
283
271
  if not os.path.exists(jsonlfile):
284
272
  raise FileNotFoundError(f"File {jsonlfile} not found")
285
273
 
@@ -298,8 +286,6 @@ class Cache(Base):
298
286
  """
299
287
  ## TODO: Check to make sure not over-writing (?)
300
288
  ## Should be added to SQLiteDict constructor (?)
301
- from edsl.data.SQLiteDict import SQLiteDict
302
-
303
289
  new_data = SQLiteDict(db_path)
304
290
  for key, value in self.data.items():
305
291
  new_data[key] = value
edsl/data/SQLiteDict.py CHANGED
@@ -1,7 +1,9 @@
1
1
  from __future__ import annotations
2
2
  import json
3
+ from sqlalchemy import create_engine
4
+ from sqlalchemy.exc import SQLAlchemyError
5
+ from sqlalchemy.orm import sessionmaker
3
6
  from typing import Any, Generator, Optional, Union
4
-
5
7
  from edsl.config import CONFIG
6
8
  from edsl.data.CacheEntry import CacheEntry
7
9
  from edsl.data.orm import Base, Data
@@ -23,16 +25,10 @@ class SQLiteDict:
23
25
  >>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
24
26
 
25
27
  """
26
- from sqlalchemy.exc import SQLAlchemyError
27
- from sqlalchemy.orm import sessionmaker
28
- from sqlalchemy import create_engine
29
-
30
28
  self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
31
29
  if not self.db_path.startswith("sqlite:///"):
32
30
  self.db_path = f"sqlite:///{self.db_path}"
33
31
  try:
34
- from edsl.data.orm import Base, Data
35
-
36
32
  self.engine = create_engine(self.db_path, echo=False, future=True)
37
33
  Base.metadata.create_all(self.engine)
38
34
  self.Session = sessionmaker(bind=self.engine)
@@ -59,8 +55,6 @@ class SQLiteDict:
59
55
  if not isinstance(value, CacheEntry):
60
56
  raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
61
57
  with self.Session() as db:
62
- from edsl.data.orm import Base, Data
63
-
64
58
  db.merge(Data(key=key, value=json.dumps(value.to_dict())))
65
59
  db.commit()
66
60
 
@@ -75,8 +69,6 @@ class SQLiteDict:
75
69
  True
76
70
  """
77
71
  with self.Session() as db:
78
- from edsl.data.orm import Base, Data
79
-
80
72
  value = db.query(Data).filter_by(key=key).first()
81
73
  if not value:
82
74
  raise KeyError(f"Key '{key}' not found.")
edsl/jobs/Answers.py CHANGED
@@ -8,15 +8,7 @@ class Answers(UserDict):
8
8
  """Helper class to hold the answers to a survey."""
9
9
 
10
10
  def add_answer(self, response, question) -> None:
11
- """Add a response to the answers dictionary.
12
-
13
- >>> from edsl import QuestionFreeText
14
- >>> q = QuestionFreeText.example()
15
- >>> answers = Answers()
16
- >>> answers.add_answer({"answer": "yes"}, q)
17
- >>> answers[q.question_name]
18
- 'yes'
19
- """
11
+ """Add a response to the answers dictionary."""
20
12
  answer = response.get("answer")
21
13
  comment = response.pop("comment", None)
22
14
  # record the answer
@@ -50,9 +42,3 @@ class Answers(UserDict):
50
42
  table.add_row(attr_name, repr(attr_value))
51
43
 
52
44
  return table
53
-
54
-
55
- if __name__ == "__main__":
56
- import doctest
57
-
58
- doctest.testmod()
edsl/jobs/Jobs.py CHANGED
@@ -1,15 +1,30 @@
1
1
  # """The Jobs class is a collection of agents, scenarios and models and one survey."""
2
2
  from __future__ import annotations
3
+ import os
3
4
  import warnings
4
5
  from itertools import product
5
6
  from typing import Optional, Union, Sequence, Generator
6
-
7
+ from edsl import Model
8
+ from edsl.agents import Agent
9
+ from edsl.agents.AgentList import AgentList
7
10
  from edsl.Base import Base
11
+ from edsl.data.Cache import Cache
12
+ from edsl.data.CacheHandler import CacheHandler
13
+ from edsl.results.Dataset import Dataset
8
14
 
15
+ from edsl.exceptions.jobs import MissingRemoteInferenceError
9
16
  from edsl.exceptions import MissingAPIKeyError
10
17
  from edsl.jobs.buckets.BucketCollection import BucketCollection
11
18
  from edsl.jobs.interviews.Interview import Interview
19
+ from edsl.language_models import LanguageModel
20
+ from edsl.results import Results
21
+ from edsl.scenarios import Scenario
22
+ from edsl import ScenarioList
23
+ from edsl.surveys import Survey
12
24
  from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
25
+
26
+ from edsl.language_models.ModelList import ModelList
27
+
13
28
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
14
29
 
15
30
 
@@ -22,10 +37,10 @@ class Jobs(Base):
22
37
 
23
38
  def __init__(
24
39
  self,
25
- survey: "Survey",
26
- agents: Optional[list["Agent"]] = None,
27
- models: Optional[list["LanguageModel"]] = None,
28
- scenarios: Optional[list["Scenario"]] = None,
40
+ survey: Survey,
41
+ agents: Optional[list[Agent]] = None,
42
+ models: Optional[list[LanguageModel]] = None,
43
+ scenarios: Optional[list[Scenario]] = None,
29
44
  ):
30
45
  """Initialize a Jobs instance.
31
46
 
@@ -35,8 +50,8 @@ class Jobs(Base):
35
50
  :param scenarios: a list of scenarios
36
51
  """
37
52
  self.survey = survey
38
- self.agents: "AgentList" = agents
39
- self.scenarios: "ScenarioList" = scenarios
53
+ self.agents: AgentList = agents
54
+ self.scenarios: ScenarioList = scenarios
40
55
  self.models = models
41
56
 
42
57
  self.__bucket_collection = None
@@ -47,8 +62,6 @@ class Jobs(Base):
47
62
 
48
63
  @models.setter
49
64
  def models(self, value):
50
- from edsl import ModelList
51
-
52
65
  if value:
53
66
  if not isinstance(value, ModelList):
54
67
  self._models = ModelList(value)
@@ -63,8 +76,6 @@ class Jobs(Base):
63
76
 
64
77
  @agents.setter
65
78
  def agents(self, value):
66
- from edsl import AgentList
67
-
68
79
  if value:
69
80
  if not isinstance(value, AgentList):
70
81
  self._agents = AgentList(value)
@@ -79,8 +90,6 @@ class Jobs(Base):
79
90
 
80
91
  @scenarios.setter
81
92
  def scenarios(self, value):
82
- from edsl import ScenarioList
83
-
84
93
  if value:
85
94
  if not isinstance(value, ScenarioList):
86
95
  self._scenarios = ScenarioList(value)
@@ -92,10 +101,10 @@ class Jobs(Base):
92
101
  def by(
93
102
  self,
94
103
  *args: Union[
95
- "Agent",
96
- "Scenario",
97
- "LanguageModel",
98
- Sequence[Union["Agent", "Scenario", "LanguageModel"]],
104
+ Agent,
105
+ Scenario,
106
+ LanguageModel,
107
+ Sequence[Union[Agent, Scenario, LanguageModel]],
99
108
  ],
100
109
  ) -> Jobs:
101
110
  """
@@ -135,7 +144,7 @@ class Jobs(Base):
135
144
  setattr(self, objects_key, new_objects) # update the job
136
145
  return self
137
146
 
138
- def prompts(self) -> "Dataset":
147
+ def prompts(self) -> Dataset:
139
148
  """Return a Dataset of prompts that will be used.
140
149
 
141
150
 
@@ -151,7 +160,6 @@ class Jobs(Base):
151
160
  user_prompts = []
152
161
  system_prompts = []
153
162
  scenario_indices = []
154
- from edsl.results.Dataset import Dataset
155
163
 
156
164
  for interview_index, interview in enumerate(interviews):
157
165
  invigilators = list(interview._build_invigilators(debug=False))
@@ -174,10 +182,7 @@ class Jobs(Base):
174
182
 
175
183
  @staticmethod
176
184
  def _get_container_class(object):
177
- from edsl.agents.AgentList import AgentList
178
- from edsl.agents.Agent import Agent
179
- from edsl.scenarios.Scenario import Scenario
180
- from edsl.scenarios.ScenarioList import ScenarioList
185
+ from edsl import AgentList
181
186
 
182
187
  if isinstance(object, Agent):
183
188
  return AgentList
@@ -213,10 +218,6 @@ class Jobs(Base):
213
218
  def _get_current_objects_of_this_type(
214
219
  self, object: Union[Agent, Scenario, LanguageModel]
215
220
  ) -> tuple[list, str]:
216
- from edsl.agents.Agent import Agent
217
- from edsl.scenarios.Scenario import Scenario
218
- from edsl.language_models.LanguageModel import LanguageModel
219
-
220
221
  """Return the current objects of the same type as the first argument.
221
222
 
222
223
  >>> from edsl.jobs import Jobs
@@ -245,9 +246,6 @@ class Jobs(Base):
245
246
  @staticmethod
246
247
  def _get_empty_container_object(object):
247
248
  from edsl import AgentList
248
- from edsl import Agent
249
- from edsl import Scenario
250
- from edsl import ScenarioList
251
249
 
252
250
  if isinstance(object, Agent):
253
251
  return AgentList([])
@@ -312,12 +310,12 @@ class Jobs(Base):
312
310
  with us filling in defaults.
313
311
  """
314
312
  # if no agents, models, or scenarios are set, set them to defaults
315
- from edsl.agents.Agent import Agent
316
- from edsl.language_models.registry import Model
317
- from edsl.scenarios.Scenario import Scenario
318
-
319
313
  self.agents = self.agents or [Agent()]
320
314
  self.models = self.models or [Model()]
315
+ # if remote, set all the models to remote
316
+ if hasattr(self, "remote") and self.remote:
317
+ for model in self.models:
318
+ model.remote = True
321
319
  self.scenarios = self.scenarios or [Scenario()]
322
320
  for agent, scenario, model in product(self.agents, self.scenarios, self.models):
323
321
  yield Interview(
@@ -331,7 +329,6 @@ class Jobs(Base):
331
329
  These buckets are used to track API calls and token usage.
332
330
 
333
331
  >>> from edsl.jobs import Jobs
334
- >>> from edsl import Model
335
332
  >>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
336
333
  >>> bc = j.create_bucket_collection()
337
334
  >>> bc
@@ -371,16 +368,14 @@ class Jobs(Base):
371
368
  if self.verbose:
372
369
  print(message)
373
370
 
374
- def _check_parameters(self, strict=False, warn=False) -> None:
371
+ def _check_parameters(self, strict=False, warn = True) -> None:
375
372
  """Check if the parameters in the survey and scenarios are consistent.
376
373
 
377
374
  >>> from edsl import QuestionFreeText
378
- >>> from edsl import Survey
379
- >>> from edsl import Scenario
380
375
  >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
381
376
  >>> j = Jobs(survey = Survey(questions=[q]))
382
377
  >>> with warnings.catch_warnings(record=True) as w:
383
- ... j._check_parameters(warn = True)
378
+ ... j._check_parameters()
384
379
  ... assert len(w) == 1
385
380
  ... assert issubclass(w[-1].category, UserWarning)
386
381
  ... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
@@ -418,13 +413,15 @@ class Jobs(Base):
418
413
  progress_bar: bool = False,
419
414
  stop_on_exception: bool = False,
420
415
  cache: Union[Cache, bool] = None,
416
+ remote: bool = (
417
+ False if os.getenv("DEFAULT_RUN_MODE", "local") == "local" else True
418
+ ),
421
419
  check_api_keys: bool = False,
422
420
  sidecar_model: Optional[LanguageModel] = None,
423
421
  batch_mode: Optional[bool] = None,
424
422
  verbose: bool = False,
425
423
  print_exceptions=True,
426
424
  remote_cache_description: Optional[str] = None,
427
- remote_inference_description: Optional[str] = None,
428
425
  ) -> Results:
429
426
  """
430
427
  Runs the Job: conducts Interviews and returns their results.
@@ -434,11 +431,11 @@ class Jobs(Base):
434
431
  :param progress_bar: shows a progress bar
435
432
  :param stop_on_exception: stops the job if an exception is raised
436
433
  :param cache: a cache object to store results
434
+ :param remote: run the job remotely
437
435
  :param check_api_keys: check if the API keys are valid
438
436
  :param batch_mode: run the job in batch mode i.e., no expecation of interaction with the user
439
437
  :param verbose: prints messages
440
438
  :param remote_cache_description: specifies a description for this group of entries in the remote cache
441
- :param remote_inference_description: specifies a description for the remote inference job
442
439
  """
443
440
  from edsl.coop.coop import Coop
444
441
 
@@ -449,50 +446,21 @@ class Jobs(Base):
449
446
  "Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
450
447
  )
451
448
 
449
+ self.remote = remote
452
450
  self.verbose = verbose
453
451
 
454
452
  try:
455
453
  coop = Coop()
456
- user_edsl_settings = coop.edsl_settings
457
- remote_cache = user_edsl_settings["remote_caching"]
458
- remote_inference = user_edsl_settings["remote_inference"]
454
+ remote_cache = coop.edsl_settings["remote_caching"]
459
455
  except Exception:
460
456
  remote_cache = False
461
- remote_inference = False
462
457
 
463
- if remote_inference:
464
- self._output("Remote inference activated. Sending job to server...")
465
- if remote_cache:
466
- self._output(
467
- "Remote caching activated. The remote cache will be used for this job."
468
- )
458
+ if self.remote:
459
+ ## TODO: This should be a coop check
460
+ if os.getenv("EXPECTED_PARROT_API_KEY", None) is None:
461
+ raise MissingRemoteInferenceError()
469
462
 
470
- remote_job_data = coop.remote_inference_create(
471
- self,
472
- description=remote_inference_description,
473
- status="queued",
474
- )
475
- self._output("Job sent!")
476
- # Create mock results object to store job data
477
- results = Results(
478
- survey=Survey(),
479
- data=[
480
- Result(
481
- agent=Agent.example(),
482
- scenario=Scenario.example(),
483
- model=Model(),
484
- iteration=1,
485
- answer={"info": "Remote job details"},
486
- )
487
- ],
488
- )
489
- results.add_columns_from_dict([remote_job_data])
490
- if self.verbose:
491
- results.select(["info", "uuid", "status", "version"]).print(
492
- format="rich"
493
- )
494
- return results
495
- else:
463
+ if not self.remote:
496
464
  if check_api_keys:
497
465
  for model in self.models + [Model()]:
498
466
  if not model.has_valid_api_key():
@@ -503,12 +471,8 @@ class Jobs(Base):
503
471
 
504
472
  # handle cache
505
473
  if cache is None:
506
- from edsl.data.CacheHandler import CacheHandler
507
-
508
474
  cache = CacheHandler().get_cache()
509
475
  if cache is False:
510
- from edsl.data.Cache import Cache
511
-
512
476
  cache = Cache()
513
477
 
514
478
  if not remote_cache:
@@ -660,11 +624,6 @@ class Jobs(Base):
660
624
  @remove_edsl_version
661
625
  def from_dict(cls, data: dict) -> Jobs:
662
626
  """Creates a Jobs instance from a dictionary."""
663
- from edsl import Survey
664
- from edsl.agents.Agent import Agent
665
- from edsl.language_models.LanguageModel import LanguageModel
666
- from edsl.scenarios.Scenario import Scenario
667
-
668
627
  return cls(
669
628
  survey=Survey.from_dict(data["survey"]),
670
629
  agents=[Agent.from_dict(agent) for agent in data["agents"]],
@@ -691,8 +650,7 @@ class Jobs(Base):
691
650
  """
692
651
  import random
693
652
  from edsl.questions import QuestionMultipleChoice
694
- from edsl.agents.Agent import Agent
695
- from edsl.scenarios.Scenario import Scenario
653
+ from edsl import Agent
696
654
 
697
655
  # (status, question, period)
698
656
  agent_answers = {
@@ -731,8 +689,6 @@ class Jobs(Base):
731
689
  question_options=["Good", "Great", "OK", "Terrible"],
732
690
  question_name="how_feeling_yesterday",
733
691
  )
734
- from edsl import Survey, ScenarioList
735
-
736
692
  base_survey = Survey(questions=[q1, q2])
737
693
 
738
694
  scenario_list = ScenarioList(
@@ -1,4 +1,4 @@
1
- # from edsl.jobs.buckets.TokenBucket import TokenBucket
1
+ from edsl.jobs.buckets.TokenBucket import TokenBucket
2
2
 
3
3
 
4
4
  class ModelBuckets:
@@ -8,7 +8,7 @@ class ModelBuckets:
8
8
  A request is one call to the service. The number of tokens required for a request depends on parameters.
9
9
  """
10
10
 
11
- def __init__(self, requests_bucket: "TokenBucket", tokens_bucket: "TokenBucket"):
11
+ def __init__(self, requests_bucket: TokenBucket, tokens_bucket: TokenBucket):
12
12
  """Initialize the model buckets.
13
13
 
14
14
  The requests bucket captures requests per unit of time.
@@ -28,8 +28,6 @@ class ModelBuckets:
28
28
  @classmethod
29
29
  def infinity_bucket(cls, model_name: str = "not_specified") -> "ModelBuckets":
30
30
  """Create a bucket with infinite capacity and refill rate."""
31
- from edsl.jobs.buckets.TokenBucket import TokenBucket
32
-
33
31
  return cls(
34
32
  requests_bucket=TokenBucket(
35
33
  bucket_name=model_name,
@@ -1,6 +1,8 @@
1
1
  from typing import Union, List, Any
2
2
  import asyncio
3
3
  import time
4
+ from collections import UserDict
5
+ from matplotlib import pyplot as plt
4
6
 
5
7
 
6
8
  class TokenBucket:
@@ -112,7 +114,6 @@ class TokenBucket:
112
114
  times, tokens = zip(*self.get_log())
113
115
  start_time = times[0]
114
116
  times = [t - start_time for t in times] # Normalize time to start from 0
115
- from matplotlib import pyplot as plt
116
117
 
117
118
  plt.figure(figsize=(10, 6))
118
119
  plt.plot(times, tokens, label="Tokens Available")
@@ -6,9 +6,15 @@ import asyncio
6
6
  import time
7
7
  from typing import Any, Type, List, Generator, Optional
8
8
 
9
+ from edsl.agents import Agent
10
+ from edsl.language_models import LanguageModel
11
+ from edsl.scenarios import Scenario
12
+ from edsl.surveys import Survey
13
+
9
14
  from edsl.jobs.Answers import Answers
10
15
  from edsl.surveys.base import EndOfSurvey
11
16
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
17
+
12
18
  from edsl.jobs.tasks.TaskCreators import TaskCreators
13
19
 
14
20
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
@@ -54,9 +60,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
54
60
  self.debug = debug
55
61
  self.iteration = iteration
56
62
  self.cache = cache
57
- self.answers: dict[str, str] = (
58
- Answers()
59
- ) # will get filled in as interview progresses
63
+ self.answers: dict[
64
+ str, str
65
+ ] = Answers() # will get filled in as interview progresses
60
66
  self.sidecar_model = sidecar_model
61
67
 
62
68
  # Trackers
@@ -17,9 +17,9 @@ class InterviewStatusMixin:
17
17
  The keys are the question names; the values are the lists of status log changes for each task.
18
18
  """
19
19
  for task_creator in self.task_creators.values():
20
- self._task_status_log_dict[task_creator.question.question_name] = (
21
- task_creator.status_log
22
- )
20
+ self._task_status_log_dict[
21
+ task_creator.question.question_name
22
+ ] = task_creator.status_log
23
23
  return self._task_status_log_dict
24
24
 
25
25
  @property
@@ -5,19 +5,17 @@ import asyncio
5
5
  import time
6
6
  import traceback
7
7
  from typing import Generator, Union
8
-
9
8
  from edsl import CONFIG
10
9
  from edsl.exceptions import InterviewTimeoutError
11
-
12
- # from edsl.questions.QuestionBase import QuestionBase
10
+ from edsl.data_transfer_models import AgentResponseDict
11
+ from edsl.questions.QuestionBase import QuestionBase
13
12
  from edsl.surveys.base import EndOfSurvey
14
13
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
15
14
  from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
16
15
  from edsl.jobs.interviews.retry_management import retry_strategy
17
16
  from edsl.jobs.tasks.task_status_enum import TaskStatus
18
17
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
19
-
20
- # from edsl.agents.InvigilatorBase import InvigilatorBase
18
+ from edsl.agents.InvigilatorBase import InvigilatorBase
21
19
 
22
20
  TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
23
21
 
@@ -46,7 +44,6 @@ class InterviewTaskBuildingMixin:
46
44
  scenario=self.scenario,
47
45
  model=self.model,
48
46
  debug=debug,
49
- survey=self.survey,
50
47
  memory_plan=self.survey.memory_plan,
51
48
  current_answers=self.answers,
52
49
  iteration=self.iteration,
@@ -152,17 +149,15 @@ class InterviewTaskBuildingMixin:
152
149
  async def _answer_question_and_record_task(
153
150
  self,
154
151
  *,
155
- question: "QuestionBase",
152
+ question: QuestionBase,
156
153
  debug: bool,
157
154
  task=None,
158
- ) -> "AgentResponseDict":
155
+ ) -> AgentResponseDict:
159
156
  """Answer a question and records the task.
160
157
 
161
158
  This in turn calls the the passed-in agent's async_answer_question method, which returns a response dictionary.
162
159
  Note that is updates answers dictionary with the response.
163
160
  """
164
- from edsl.data_transfer_models import AgentResponseDict
165
-
166
161
  try:
167
162
  invigilator = self._get_invigilator(question, debug=debug)
168
163
 
@@ -258,11 +253,11 @@ class InterviewTaskBuildingMixin:
258
253
  """
259
254
  current_question_index: int = self.to_index[current_question.question_name]
260
255
 
261
- next_question: Union[int, EndOfSurvey] = (
262
- self.survey.rule_collection.next_question(
263
- q_now=current_question_index,
264
- answers=self.answers | self.scenario | self.agent["traits"],
265
- )
256
+ next_question: Union[
257
+ int, EndOfSurvey
258
+ ] = self.survey.rule_collection.next_question(
259
+ q_now=current_question_index,
260
+ answers=self.answers | self.scenario | self.agent["traits"],
266
261
  )
267
262
 
268
263
  next_question_index = next_question.next_q