edsl 0.1.28__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76):
  1. edsl/Base.py +18 -18
  2. edsl/__init__.py +24 -24
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +77 -41
  5. edsl/agents/AgentList.py +35 -6
  6. edsl/agents/Invigilator.py +19 -1
  7. edsl/agents/InvigilatorBase.py +15 -10
  8. edsl/agents/PromptConstructionMixin.py +342 -100
  9. edsl/agents/descriptors.py +2 -1
  10. edsl/base/Base.py +289 -0
  11. edsl/config.py +2 -1
  12. edsl/conjure/InputData.py +39 -8
  13. edsl/coop/coop.py +188 -151
  14. edsl/coop/utils.py +43 -75
  15. edsl/data/Cache.py +19 -5
  16. edsl/data/SQLiteDict.py +11 -3
  17. edsl/jobs/Answers.py +15 -1
  18. edsl/jobs/Jobs.py +92 -47
  19. edsl/jobs/buckets/ModelBuckets.py +4 -2
  20. edsl/jobs/buckets/TokenBucket.py +1 -2
  21. edsl/jobs/interviews/Interview.py +3 -9
  22. edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
  23. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +15 -10
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +21 -25
  25. edsl/jobs/tasks/TaskHistory.py +4 -3
  26. edsl/language_models/LanguageModel.py +5 -11
  27. edsl/language_models/ModelList.py +3 -3
  28. edsl/language_models/repair.py +8 -7
  29. edsl/notebooks/Notebook.py +40 -3
  30. edsl/prompts/Prompt.py +31 -19
  31. edsl/questions/QuestionBase.py +38 -13
  32. edsl/questions/QuestionBudget.py +5 -6
  33. edsl/questions/QuestionCheckBox.py +7 -3
  34. edsl/questions/QuestionExtract.py +5 -3
  35. edsl/questions/QuestionFreeText.py +3 -3
  36. edsl/questions/QuestionFunctional.py +0 -3
  37. edsl/questions/QuestionList.py +3 -4
  38. edsl/questions/QuestionMultipleChoice.py +16 -8
  39. edsl/questions/QuestionNumerical.py +4 -3
  40. edsl/questions/QuestionRank.py +5 -3
  41. edsl/questions/__init__.py +4 -3
  42. edsl/questions/descriptors.py +4 -2
  43. edsl/questions/question_registry.py +20 -31
  44. edsl/questions/settings.py +1 -1
  45. edsl/results/Dataset.py +31 -0
  46. edsl/results/DatasetExportMixin.py +493 -0
  47. edsl/results/Result.py +22 -74
  48. edsl/results/Results.py +105 -67
  49. edsl/results/ResultsDBMixin.py +7 -3
  50. edsl/results/ResultsExportMixin.py +22 -537
  51. edsl/results/ResultsGGMixin.py +3 -3
  52. edsl/results/ResultsToolsMixin.py +5 -5
  53. edsl/scenarios/FileStore.py +140 -0
  54. edsl/scenarios/Scenario.py +5 -6
  55. edsl/scenarios/ScenarioList.py +44 -15
  56. edsl/scenarios/ScenarioListExportMixin.py +32 -0
  57. edsl/scenarios/ScenarioListPdfMixin.py +2 -1
  58. edsl/scenarios/__init__.py +1 -0
  59. edsl/study/ObjectEntry.py +89 -13
  60. edsl/study/ProofOfWork.py +5 -2
  61. edsl/study/SnapShot.py +4 -8
  62. edsl/study/Study.py +21 -14
  63. edsl/study/__init__.py +2 -0
  64. edsl/surveys/MemoryPlan.py +11 -4
  65. edsl/surveys/Survey.py +46 -7
  66. edsl/surveys/SurveyExportMixin.py +4 -2
  67. edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
  68. edsl/tools/plotting.py +4 -2
  69. edsl/utilities/__init__.py +21 -21
  70. edsl/utilities/interface.py +66 -45
  71. edsl/utilities/utilities.py +11 -13
  72. {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/METADATA +11 -10
  73. {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/RECORD +75 -72
  74. edsl-0.1.28.dist-info/entry_points.txt +0 -3
  75. {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
  76. {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
@@ -1,8 +1,6 @@
1
1
  from typing import Union, List, Any
2
2
  import asyncio
3
3
  import time
4
- from collections import UserDict
5
- from matplotlib import pyplot as plt
6
4
 
7
5
 
8
6
  class TokenBucket:
@@ -114,6 +112,7 @@ class TokenBucket:
114
112
  times, tokens = zip(*self.get_log())
115
113
  start_time = times[0]
116
114
  times = [t - start_time for t in times] # Normalize time to start from 0
115
+ from matplotlib import pyplot as plt
117
116
 
118
117
  plt.figure(figsize=(10, 6))
119
118
  plt.plot(times, tokens, label="Tokens Available")
@@ -6,15 +6,9 @@ import asyncio
6
6
  import time
7
7
  from typing import Any, Type, List, Generator, Optional
8
8
 
9
- from edsl.agents import Agent
10
- from edsl.language_models import LanguageModel
11
- from edsl.scenarios import Scenario
12
- from edsl.surveys import Survey
13
-
14
9
  from edsl.jobs.Answers import Answers
15
10
  from edsl.surveys.base import EndOfSurvey
16
11
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
17
-
18
12
  from edsl.jobs.tasks.TaskCreators import TaskCreators
19
13
 
20
14
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
@@ -60,9 +54,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
60
54
  self.debug = debug
61
55
  self.iteration = iteration
62
56
  self.cache = cache
63
- self.answers: dict[
64
- str, str
65
- ] = Answers() # will get filled in as interview progresses
57
+ self.answers: dict[str, str] = (
58
+ Answers()
59
+ ) # will get filled in as interview progresses
66
60
  self.sidecar_model = sidecar_model
67
61
 
68
62
  # Trackers
@@ -17,9 +17,9 @@ class InterviewStatusMixin:
17
17
  The keys are the question names; the values are the lists of status log changes for each task.
18
18
  """
19
19
  for task_creator in self.task_creators.values():
20
- self._task_status_log_dict[
21
- task_creator.question.question_name
22
- ] = task_creator.status_log
20
+ self._task_status_log_dict[task_creator.question.question_name] = (
21
+ task_creator.status_log
22
+ )
23
23
  return self._task_status_log_dict
24
24
 
25
25
  @property
@@ -5,17 +5,19 @@ import asyncio
5
5
  import time
6
6
  import traceback
7
7
  from typing import Generator, Union
8
+
8
9
  from edsl import CONFIG
9
10
  from edsl.exceptions import InterviewTimeoutError
10
- from edsl.data_transfer_models import AgentResponseDict
11
- from edsl.questions.QuestionBase import QuestionBase
11
+
12
+ # from edsl.questions.QuestionBase import QuestionBase
12
13
  from edsl.surveys.base import EndOfSurvey
13
14
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
14
15
  from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
15
16
  from edsl.jobs.interviews.retry_management import retry_strategy
16
17
  from edsl.jobs.tasks.task_status_enum import TaskStatus
17
18
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
18
- from edsl.agents.InvigilatorBase import InvigilatorBase
19
+
20
+ # from edsl.agents.InvigilatorBase import InvigilatorBase
19
21
 
20
22
  TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
21
23
 
@@ -44,6 +46,7 @@ class InterviewTaskBuildingMixin:
44
46
  scenario=self.scenario,
45
47
  model=self.model,
46
48
  debug=debug,
49
+ survey=self.survey,
47
50
  memory_plan=self.survey.memory_plan,
48
51
  current_answers=self.answers,
49
52
  iteration=self.iteration,
@@ -149,15 +152,17 @@ class InterviewTaskBuildingMixin:
149
152
  async def _answer_question_and_record_task(
150
153
  self,
151
154
  *,
152
- question: QuestionBase,
155
+ question: "QuestionBase",
153
156
  debug: bool,
154
157
  task=None,
155
- ) -> AgentResponseDict:
158
+ ) -> "AgentResponseDict":
156
159
  """Answer a question and records the task.
157
160
 
158
161
  This in turn calls the the passed-in agent's async_answer_question method, which returns a response dictionary.
159
162
  Note that is updates answers dictionary with the response.
160
163
  """
164
+ from edsl.data_transfer_models import AgentResponseDict
165
+
161
166
  try:
162
167
  invigilator = self._get_invigilator(question, debug=debug)
163
168
 
@@ -253,11 +258,11 @@ class InterviewTaskBuildingMixin:
253
258
  """
254
259
  current_question_index: int = self.to_index[current_question.question_name]
255
260
 
256
- next_question: Union[
257
- int, EndOfSurvey
258
- ] = self.survey.rule_collection.next_question(
259
- q_now=current_question_index,
260
- answers=self.answers | self.scenario | self.agent["traits"],
261
+ next_question: Union[int, EndOfSurvey] = (
262
+ self.survey.rule_collection.next_question(
263
+ q_now=current_question_index,
264
+ answers=self.answers | self.scenario | self.agent["traits"],
265
+ )
261
266
  )
262
267
 
263
268
  next_question_index = next_question.next_q
@@ -1,29 +1,17 @@
1
1
  from __future__ import annotations
2
2
  import time
3
3
  import asyncio
4
- import textwrap
4
+ import time
5
5
  from contextlib import contextmanager
6
6
 
7
7
  from typing import Coroutine, List, AsyncGenerator, Optional, Union
8
8
 
9
- from rich.live import Live
10
- from rich.console import Console
11
-
12
9
  from edsl import shared_globals
13
- from edsl.results import Results, Result
14
-
15
10
  from edsl.jobs.interviews.Interview import Interview
16
- from edsl.utilities.decorators import jupyter_nb_handler
17
-
18
- # from edsl.jobs.Jobs import Jobs
19
11
  from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
20
- from edsl.language_models import LanguageModel
21
- from edsl.data.Cache import Cache
22
-
23
12
  from edsl.jobs.tasks.TaskHistory import TaskHistory
24
13
  from edsl.jobs.buckets.BucketCollection import BucketCollection
25
-
26
- import time
14
+ from edsl.utilities.decorators import jupyter_nb_handler
27
15
 
28
16
 
29
17
  class JobsRunnerAsyncio(JobsRunnerStatusMixin):
@@ -42,13 +30,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
42
30
 
43
31
  async def run_async_generator(
44
32
  self,
45
- cache: Cache,
33
+ cache: "Cache",
46
34
  n: int = 1,
47
35
  debug: bool = False,
48
36
  stop_on_exception: bool = False,
49
37
  sidecar_model: "LanguageModel" = None,
50
38
  total_interviews: Optional[List["Interview"]] = None,
51
- ) -> AsyncGenerator[Result, None]:
39
+ ) -> AsyncGenerator["Result", None]:
52
40
  """Creates the tasks, runs them asynchronously, and returns the results as a Results object.
53
41
 
54
42
  Completed tasks are yielded as they are completed.
@@ -155,19 +143,21 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
155
143
 
156
144
  prompt_dictionary = {}
157
145
  for answer_key_name in answer_key_names:
158
- prompt_dictionary[
159
- answer_key_name + "_user_prompt"
160
- ] = question_name_to_prompts[answer_key_name]["user_prompt"]
161
- prompt_dictionary[
162
- answer_key_name + "_system_prompt"
163
- ] = question_name_to_prompts[answer_key_name]["system_prompt"]
146
+ prompt_dictionary[answer_key_name + "_user_prompt"] = (
147
+ question_name_to_prompts[answer_key_name]["user_prompt"]
148
+ )
149
+ prompt_dictionary[answer_key_name + "_system_prompt"] = (
150
+ question_name_to_prompts[answer_key_name]["system_prompt"]
151
+ )
164
152
 
165
153
  raw_model_results_dictionary = {}
166
154
  for result in valid_results:
167
155
  question_name = result["question_name"]
168
- raw_model_results_dictionary[
169
- question_name + "_raw_model_response"
170
- ] = result["raw_model_response"]
156
+ raw_model_results_dictionary[question_name + "_raw_model_response"] = (
157
+ result["raw_model_response"]
158
+ )
159
+
160
+ from edsl.results.Result import Result
171
161
 
172
162
  result = Result(
173
163
  agent=interview.agent,
@@ -197,6 +187,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
197
187
  print_exceptions: bool = True,
198
188
  ) -> "Coroutine":
199
189
  """Runs a collection of interviews, handling both async and sync contexts."""
190
+ from rich.console import Console
191
+
200
192
  console = Console()
201
193
  self.results = []
202
194
  self.start_time = time.monotonic()
@@ -204,6 +196,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
204
196
  self.cache = cache
205
197
  self.sidecar_model = sidecar_model
206
198
 
199
+ from edsl.results.Results import Results
200
+
207
201
  if not progress_bar:
208
202
  # print("Running without progress bar")
209
203
  with cache as c:
@@ -225,6 +219,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
225
219
  results = Results(survey=self.jobs.survey, data=self.results)
226
220
  else:
227
221
  # print("Running with progress bar")
222
+ from rich.live import Live
223
+ from rich.console import Console
228
224
 
229
225
  def generate_table():
230
226
  return self.status_table(self.results, self.elapsed_time)
@@ -1,8 +1,5 @@
1
1
  from edsl.jobs.tasks.task_status_enum import TaskStatus
2
- from matplotlib import pyplot as plt
3
2
  from typing import List, Optional
4
-
5
- import matplotlib.pyplot as plt
6
3
  from io import BytesIO
7
4
  import base64
8
5
 
@@ -75,6 +72,8 @@ class TaskHistory:
75
72
 
76
73
  def plot_completion_times(self):
77
74
  """Plot the completion times for each task."""
75
+ import matplotlib.pyplot as plt
76
+
78
77
  updates = self.get_updates()
79
78
 
80
79
  elapsed = [update.max_time - update.min_time for update in updates]
@@ -126,6 +125,8 @@ class TaskHistory:
126
125
  rows = int(len(TaskStatus) ** 0.5) + 1
127
126
  cols = (len(TaskStatus) + rows - 1) // rows # Ensure all plots fit
128
127
 
128
+ import matplotlib.pyplot as plt
129
+
129
130
  fig, axes = plt.subplots(rows, cols, figsize=(15, 10))
130
131
  axes = axes.flatten() # Flatten in case of a single row/column
131
132
 
@@ -7,26 +7,18 @@ import asyncio
7
7
  import json
8
8
  import time
9
9
  import os
10
-
11
10
  from typing import Coroutine, Any, Callable, Type, List, get_type_hints
12
-
13
- from abc import ABC, abstractmethod, ABCMeta
14
-
15
- from rich.table import Table
11
+ from abc import ABC, abstractmethod
16
12
 
17
13
  from edsl.config import CONFIG
18
14
 
19
- from edsl.utilities.utilities import clean_json
20
15
  from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
21
16
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
17
+
22
18
  from edsl.language_models.repair import repair
23
- from edsl.exceptions.language_models import LanguageModelAttributeTypeError
24
19
  from edsl.enums import InferenceServiceType
25
20
  from edsl.Base import RichPrintingMixin, PersistenceMixin
26
- from edsl.data.Cache import Cache
27
21
  from edsl.enums import service_to_api_keyname
28
-
29
-
30
22
  from edsl.exceptions import MissingAPIKeyError
31
23
  from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
32
24
 
@@ -291,7 +283,7 @@ class LanguageModel(
291
283
  self,
292
284
  user_prompt: str,
293
285
  system_prompt: str,
294
- cache,
286
+ cache: "Cache",
295
287
  iteration: int = 0,
296
288
  encoded_image=None,
297
289
  ) -> tuple[dict, bool, str]:
@@ -490,6 +482,8 @@ class LanguageModel(
490
482
 
491
483
  def rich_print(self):
492
484
  """Display an object as a table."""
485
+ from rich.table import Table
486
+
493
487
  table = Table(title="Language Model")
494
488
  table.add_column("Attribute", style="bold")
495
489
  table.add_column("Value")
@@ -5,7 +5,7 @@ from edsl import Model
5
5
  from edsl.language_models import LanguageModel
6
6
  from edsl.Base import Base
7
7
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
8
- from edsl.utilities import is_valid_variable_name
8
+ from edsl.utilities.utilities import is_valid_variable_name
9
9
  from edsl.utilities.utilities import dict_hash
10
10
 
11
11
 
@@ -56,11 +56,11 @@ class ModelList(Base, UserList):
56
56
  return {"models": [model._to_dict() for model in self]}
57
57
 
58
58
  @classmethod
59
- def from_names(self, *args):
59
+ def from_names(self, *args, **kwargs):
60
60
  """A a model list from a list of names"""
61
61
  if len(args) == 1 and isinstance(args[0], list):
62
62
  args = args[0]
63
- return ModelList([Model(model_name) for model_name in args])
63
+ return ModelList([Model(model_name, **kwargs) for model_name in args])
64
64
 
65
65
  @add_edsl_version
66
66
  def to_dict(self):
@@ -1,18 +1,13 @@
1
1
  import json
2
2
  import asyncio
3
3
  import warnings
4
- from rich import print
5
- from rich.console import Console
6
- from rich.syntax import Syntax
7
-
8
- from edsl.utilities.utilities import clean_json
9
-
10
- from edsl.utilities.repair_functions import extract_json_from_string
11
4
 
12
5
 
13
6
  async def async_repair(
14
7
  bad_json, error_message="", user_prompt=None, system_prompt=None, cache=None
15
8
  ):
9
+ from edsl.utilities.utilities import clean_json
10
+
16
11
  s = clean_json(bad_json)
17
12
 
18
13
  try:
@@ -27,6 +22,8 @@ async def async_repair(
27
22
  return valid_dict, success
28
23
 
29
24
  try:
25
+ from edsl.utilities.repair_functions import extract_json_from_string
26
+
30
27
  valid_dict = extract_json_from_string(s)
31
28
  success = True
32
29
  except ValueError:
@@ -98,6 +95,10 @@ async def async_repair(
98
95
  except json.JSONDecodeError:
99
96
  valid_dict = {}
100
97
  success = False
98
+ from rich import print
99
+ from rich.console import Console
100
+ from rich.syntax import Syntax
101
+
101
102
  console = Console()
102
103
  error_message = (
103
104
  f"All repairs. failed. LLM Model given [red]{str(bad_json)}[/red]"
@@ -1,10 +1,9 @@
1
1
  """A Notebook is a utility class that allows you to easily share/pull ipynbs from Coop."""
2
2
 
3
3
  import json
4
- import nbformat
5
- from nbconvert import HTMLExporter
6
4
  from typing import Dict, List, Optional
7
- from rich.table import Table
5
+
6
+
8
7
  from edsl.Base import Base
9
8
  from edsl.utilities.decorators import (
10
9
  add_edsl_version,
@@ -34,6 +33,8 @@ class Notebook(Base):
34
33
  If no path is provided, assume this code is run in a notebook and try to load the current notebook from file.
35
34
  :param name: A name for the Notebook.
36
35
  """
36
+ import nbformat
37
+
37
38
  # Load current notebook path as fallback (VS Code only)
38
39
  path = path or globals().get("__vsc_ipynb_file__")
39
40
  if data is not None:
@@ -56,6 +57,37 @@ class Notebook(Base):
56
57
 
57
58
  self.name = name or self.default_name
58
59
 
60
+ @classmethod
61
+ def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
62
+ # Read the script file
63
+ with open(path, "r") as script_file:
64
+ script_content = script_file.read()
65
+
66
+ # Create a new Jupyter notebook
67
+ nb = nbformat.v4.new_notebook()
68
+
69
+ # Add the script content to the first cell
70
+ first_cell = nbformat.v4.new_code_cell(script_content)
71
+ nb.cells.append(first_cell)
72
+
73
+ # Create a Notebook instance with the notebook data
74
+ notebook_instance = cls(nb)
75
+
76
+ return notebook_instance
77
+
78
+ @classmethod
79
+ def from_current_script(cls) -> "Notebook":
80
+ import inspect
81
+ import os
82
+
83
+ # Get the path to the current file
84
+ current_frame = inspect.currentframe()
85
+ caller_frame = inspect.getouterframes(current_frame, 2)
86
+ current_file_path = os.path.abspath(caller_frame[1].filename)
87
+
88
+ # Use from_script to create the notebook
89
+ return cls.from_script(current_file_path)
90
+
59
91
  def __eq__(self, other):
60
92
  """
61
93
  Check if two Notebooks are equal.
@@ -103,6 +135,9 @@ class Notebook(Base):
103
135
  """
104
136
  Return HTML representation of Notebook.
105
137
  """
138
+ from nbconvert import HTMLExporter
139
+ import nbformat
140
+
106
141
  notebook = nbformat.from_dict(self.data)
107
142
  html_exporter = HTMLExporter(template_name="basic")
108
143
  (body, _) = html_exporter.from_notebook_node(notebook)
@@ -143,6 +178,8 @@ class Notebook(Base):
143
178
  """
144
179
  Display a Notebook as a rich table.
145
180
  """
181
+ from rich.table import Table
182
+
146
183
  table_data, column_names = self._table()
147
184
  table = Table(title=f"{self.__class__.__name__} Attributes")
148
185
  for column in column_names:
edsl/prompts/Prompt.py CHANGED
@@ -1,12 +1,16 @@
1
- """Class for creating prompts to be used in a survey."""
2
-
3
1
  from __future__ import annotations
4
2
  from typing import Optional
5
3
  from abc import ABC
6
4
  from typing import Any, List
7
5
 
8
6
  from rich.table import Table
9
- from jinja2 import Template, Environment, meta, TemplateSyntaxError
7
+ from jinja2 import Template, Environment, meta, TemplateSyntaxError, Undefined
8
+
9
+
10
+ class PreserveUndefined(Undefined):
11
+ def __str__(self):
12
+ return "{{ " + self._undefined_name + " }}"
13
+
10
14
 
11
15
  from edsl.exceptions.prompts import TemplateRenderError
12
16
  from edsl.prompts.prompt_config import (
@@ -35,6 +39,10 @@ class PromptBase(
35
39
 
36
40
  return data_to_html(self.to_dict())
37
41
 
42
+ def __len__(self):
43
+ """Return the length of the prompt text."""
44
+ return len(self.text)
45
+
38
46
  @classmethod
39
47
  def prompt_attributes(cls) -> List[str]:
40
48
  """Return the prompt class attributes."""
@@ -75,10 +83,10 @@ class PromptBase(
75
83
  >>> p = Prompt("Hello, {{person}}")
76
84
  >>> p2 = Prompt("How are you?")
77
85
  >>> p + p2
78
- Prompt(text='Hello, {{person}}How are you?')
86
+ Prompt(text=\"""Hello, {{person}}How are you?\""")
79
87
 
80
88
  >>> p + "How are you?"
81
- Prompt(text='Hello, {{person}}How are you?')
89
+ Prompt(text=\"""Hello, {{person}}How are you?\""")
82
90
  """
83
91
  if isinstance(other_prompt, str):
84
92
  return self.__class__(self.text + other_prompt)
@@ -114,7 +122,7 @@ class PromptBase(
114
122
  Example:
115
123
  >>> p = Prompt("Hello, {{person}}")
116
124
  >>> p
117
- Prompt(text='Hello, {{person}}')
125
+ Prompt(text=\"""Hello, {{person}}\""")
118
126
  """
119
127
  return f'Prompt(text="""{self.text}""")'
120
128
 
@@ -137,7 +145,7 @@ class PromptBase(
137
145
  :param template: The template to find the variables in.
138
146
 
139
147
  """
140
- env = Environment()
148
+ env = Environment(undefined=PreserveUndefined)
141
149
  ast = env.parse(template)
142
150
  return list(meta.find_undeclared_variables(ast))
143
151
 
@@ -186,13 +194,16 @@ class PromptBase(
186
194
 
187
195
  >>> p = Prompt("Hello, {{person}}")
188
196
  >>> p.render({"person": "John"})
189
- 'Hello, John'
197
+ Prompt(text=\"""Hello, John\""")
190
198
 
191
199
  >>> p.render({"person": "Mr. {{last_name}}", "last_name": "Horton"})
192
- 'Hello, Mr. Horton'
200
+ Prompt(text=\"""Hello, Mr. Horton\""")
193
201
 
194
202
  >>> p.render({"person": "Mr. {{last_name}}", "last_name": "Ho{{letter}}ton"}, max_nesting = 1)
195
- 'Hello, Mr. Horton'
203
+ Prompt(text=\"""Hello, Mr. Ho{{ letter }}ton\""")
204
+
205
+ >>> p.render({"person": "Mr. {{last_name}}"})
206
+ Prompt(text=\"""Hello, Mr. {{ last_name }}\""")
196
207
  """
197
208
  new_text = self._render(
198
209
  self.text, primary_replacement, **additional_replacements
@@ -216,12 +227,13 @@ class PromptBase(
216
227
  >>> codebook = {"age": "Age"}
217
228
  >>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
218
229
  >>> p.render({"name": "John", "age": 44}, codebook=codebook)
219
- 'You are an agent named John. Age: 44'
230
+ Prompt(text=\"""You are an agent named John. Age: 44\""")
220
231
  """
232
+ env = Environment(undefined=PreserveUndefined)
221
233
  try:
222
234
  previous_text = None
223
235
  for _ in range(MAX_NESTING):
224
- rendered_text = Template(text).render(
236
+ rendered_text = env.from_string(text).render(
225
237
  primary_replacement, **additional_replacements
226
238
  )
227
239
  if rendered_text == previous_text:
@@ -258,7 +270,7 @@ class PromptBase(
258
270
  >>> p = Prompt("Hello, {{person}}")
259
271
  >>> p2 = Prompt.from_dict(p.to_dict())
260
272
  >>> p2
261
- Prompt(text='Hello, {{person}}')
273
+ Prompt(text=\"""Hello, {{person}}\""")
262
274
 
263
275
  """
264
276
  class_name = data["class_name"]
@@ -290,6 +302,12 @@ class Prompt(PromptBase):
290
302
  component_type = ComponentTypes.GENERIC
291
303
 
292
304
 
305
+ if __name__ == "__main__":
306
+ print("Running doctests...")
307
+ import doctest
308
+
309
+ doctest.testmod()
310
+
293
311
  from edsl.prompts.library.question_multiple_choice import *
294
312
  from edsl.prompts.library.agent_instructions import *
295
313
  from edsl.prompts.library.agent_persona import *
@@ -302,9 +320,3 @@ from edsl.prompts.library.question_numerical import *
302
320
  from edsl.prompts.library.question_rank import *
303
321
  from edsl.prompts.library.question_extract import *
304
322
  from edsl.prompts.library.question_list import *
305
-
306
-
307
- if __name__ == "__main__":
308
- import doctest
309
-
310
- doctest.testmod()