edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. edsl/Base.py +169 -116
  2. edsl/__init__.py +14 -6
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +358 -146
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +88 -36
  7. edsl/agents/InvigilatorBase.py +59 -70
  8. edsl/agents/PromptConstructor.py +117 -219
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionOptionProcessor.py +172 -0
  11. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  12. edsl/agents/__init__.py +0 -1
  13. edsl/agents/prompt_helpers.py +3 -3
  14. edsl/config.py +22 -2
  15. edsl/conversation/car_buying.py +2 -1
  16. edsl/coop/CoopFunctionsMixin.py +15 -0
  17. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  18. edsl/coop/PriceFetcher.py +1 -1
  19. edsl/coop/coop.py +104 -42
  20. edsl/coop/utils.py +14 -14
  21. edsl/data/Cache.py +21 -14
  22. edsl/data/CacheEntry.py +12 -15
  23. edsl/data/CacheHandler.py +33 -12
  24. edsl/data/__init__.py +4 -3
  25. edsl/data_transfer_models.py +2 -1
  26. edsl/enums.py +20 -0
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +12 -0
  29. edsl/exceptions/inference_services.py +5 -0
  30. edsl/exceptions/questions.py +24 -6
  31. edsl/exceptions/scenarios.py +7 -0
  32. edsl/inference_services/AnthropicService.py +0 -3
  33. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  34. edsl/inference_services/AvailableModelFetcher.py +209 -0
  35. edsl/inference_services/AwsBedrock.py +0 -2
  36. edsl/inference_services/AzureAI.py +0 -2
  37. edsl/inference_services/GoogleService.py +2 -11
  38. edsl/inference_services/InferenceServiceABC.py +18 -85
  39. edsl/inference_services/InferenceServicesCollection.py +105 -80
  40. edsl/inference_services/MistralAIService.py +0 -3
  41. edsl/inference_services/OpenAIService.py +1 -4
  42. edsl/inference_services/PerplexityService.py +0 -3
  43. edsl/inference_services/ServiceAvailability.py +135 -0
  44. edsl/inference_services/TestService.py +11 -8
  45. edsl/inference_services/data_structures.py +62 -0
  46. edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
  47. edsl/jobs/Answers.py +1 -14
  48. edsl/jobs/FetchInvigilator.py +40 -0
  49. edsl/jobs/InterviewTaskManager.py +98 -0
  50. edsl/jobs/InterviewsConstructor.py +48 -0
  51. edsl/jobs/Jobs.py +102 -243
  52. edsl/jobs/JobsChecks.py +35 -10
  53. edsl/jobs/JobsComponentConstructor.py +189 -0
  54. edsl/jobs/JobsPrompts.py +5 -3
  55. edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
  56. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  57. edsl/jobs/RequestTokenEstimator.py +30 -0
  58. edsl/jobs/buckets/BucketCollection.py +44 -3
  59. edsl/jobs/buckets/TokenBucket.py +53 -21
  60. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  61. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  62. edsl/jobs/decorators.py +35 -0
  63. edsl/jobs/interviews/Interview.py +77 -380
  64. edsl/jobs/jobs_status_enums.py +9 -0
  65. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  66. edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
  67. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  68. edsl/jobs/tasks/TaskHistory.py +14 -15
  69. edsl/jobs/tasks/task_status_enum.py +0 -2
  70. edsl/language_models/ComputeCost.py +63 -0
  71. edsl/language_models/LanguageModel.py +137 -234
  72. edsl/language_models/ModelList.py +11 -13
  73. edsl/language_models/PriceManager.py +127 -0
  74. edsl/language_models/RawResponseHandler.py +106 -0
  75. edsl/language_models/ServiceDataSources.py +0 -0
  76. edsl/language_models/__init__.py +0 -1
  77. edsl/language_models/key_management/KeyLookup.py +63 -0
  78. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  79. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  80. edsl/language_models/key_management/__init__.py +0 -0
  81. edsl/language_models/key_management/models.py +131 -0
  82. edsl/language_models/registry.py +49 -59
  83. edsl/language_models/repair.py +2 -2
  84. edsl/language_models/utilities.py +5 -4
  85. edsl/notebooks/Notebook.py +19 -14
  86. edsl/notebooks/NotebookToLaTeX.py +142 -0
  87. edsl/prompts/Prompt.py +29 -39
  88. edsl/questions/AnswerValidatorMixin.py +47 -2
  89. edsl/questions/ExceptionExplainer.py +77 -0
  90. edsl/questions/HTMLQuestion.py +103 -0
  91. edsl/questions/LoopProcessor.py +149 -0
  92. edsl/questions/QuestionBase.py +37 -192
  93. edsl/questions/QuestionBaseGenMixin.py +52 -48
  94. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  95. edsl/questions/QuestionCheckBox.py +1 -1
  96. edsl/questions/QuestionExtract.py +1 -1
  97. edsl/questions/QuestionFreeText.py +1 -2
  98. edsl/questions/QuestionList.py +3 -5
  99. edsl/questions/QuestionMatrix.py +265 -0
  100. edsl/questions/QuestionMultipleChoice.py +66 -22
  101. edsl/questions/QuestionNumerical.py +1 -3
  102. edsl/questions/QuestionRank.py +6 -16
  103. edsl/questions/ResponseValidatorABC.py +37 -11
  104. edsl/questions/ResponseValidatorFactory.py +28 -0
  105. edsl/questions/SimpleAskMixin.py +4 -3
  106. edsl/questions/__init__.py +1 -0
  107. edsl/questions/derived/QuestionLinearScale.py +6 -3
  108. edsl/questions/derived/QuestionTopK.py +1 -1
  109. edsl/questions/descriptors.py +17 -3
  110. edsl/questions/question_registry.py +1 -1
  111. edsl/questions/templates/matrix/__init__.py +1 -0
  112. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  113. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  114. edsl/results/CSSParameterizer.py +1 -1
  115. edsl/results/Dataset.py +170 -7
  116. edsl/results/DatasetExportMixin.py +224 -302
  117. edsl/results/DatasetTree.py +28 -8
  118. edsl/results/MarkdownToDocx.py +122 -0
  119. edsl/results/MarkdownToPDF.py +111 -0
  120. edsl/results/Result.py +192 -206
  121. edsl/results/Results.py +120 -113
  122. edsl/results/ResultsExportMixin.py +2 -0
  123. edsl/results/Selector.py +23 -13
  124. edsl/results/TableDisplay.py +98 -171
  125. edsl/results/TextEditor.py +50 -0
  126. edsl/results/__init__.py +1 -1
  127. edsl/results/smart_objects.py +96 -0
  128. edsl/results/table_data_class.py +12 -0
  129. edsl/results/table_renderers.py +118 -0
  130. edsl/scenarios/ConstructDownloadLink.py +109 -0
  131. edsl/scenarios/DirectoryScanner.py +96 -0
  132. edsl/scenarios/DocumentChunker.py +102 -0
  133. edsl/scenarios/DocxScenario.py +16 -0
  134. edsl/scenarios/FileStore.py +118 -239
  135. edsl/scenarios/PdfExtractor.py +40 -0
  136. edsl/scenarios/Scenario.py +90 -193
  137. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  138. edsl/scenarios/ScenarioJoin.py +10 -6
  139. edsl/scenarios/ScenarioList.py +383 -240
  140. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  141. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  142. edsl/scenarios/ScenarioSelector.py +156 -0
  143. edsl/scenarios/__init__.py +1 -2
  144. edsl/scenarios/file_methods.py +85 -0
  145. edsl/scenarios/handlers/__init__.py +13 -0
  146. edsl/scenarios/handlers/csv.py +38 -0
  147. edsl/scenarios/handlers/docx.py +76 -0
  148. edsl/scenarios/handlers/html.py +37 -0
  149. edsl/scenarios/handlers/json.py +111 -0
  150. edsl/scenarios/handlers/latex.py +5 -0
  151. edsl/scenarios/handlers/md.py +51 -0
  152. edsl/scenarios/handlers/pdf.py +68 -0
  153. edsl/scenarios/handlers/png.py +39 -0
  154. edsl/scenarios/handlers/pptx.py +105 -0
  155. edsl/scenarios/handlers/py.py +294 -0
  156. edsl/scenarios/handlers/sql.py +313 -0
  157. edsl/scenarios/handlers/sqlite.py +149 -0
  158. edsl/scenarios/handlers/txt.py +33 -0
  159. edsl/study/ObjectEntry.py +1 -1
  160. edsl/study/SnapShot.py +1 -1
  161. edsl/study/Study.py +5 -12
  162. edsl/surveys/ConstructDAG.py +92 -0
  163. edsl/surveys/EditSurvey.py +221 -0
  164. edsl/surveys/InstructionHandler.py +100 -0
  165. edsl/surveys/MemoryManagement.py +72 -0
  166. edsl/surveys/Rule.py +5 -4
  167. edsl/surveys/RuleCollection.py +25 -27
  168. edsl/surveys/RuleManager.py +172 -0
  169. edsl/surveys/Simulator.py +75 -0
  170. edsl/surveys/Survey.py +199 -771
  171. edsl/surveys/SurveyCSS.py +20 -8
  172. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  173. edsl/surveys/SurveyToApp.py +141 -0
  174. edsl/surveys/__init__.py +4 -2
  175. edsl/surveys/descriptors.py +6 -2
  176. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  177. edsl/surveys/instructions/Instruction.py +4 -13
  178. edsl/surveys/instructions/InstructionCollection.py +11 -6
  179. edsl/templates/error_reporting/interview_details.html +1 -1
  180. edsl/templates/error_reporting/report.html +1 -1
  181. edsl/tools/plotting.py +1 -1
  182. edsl/utilities/PrettyList.py +56 -0
  183. edsl/utilities/is_notebook.py +18 -0
  184. edsl/utilities/is_valid_variable_name.py +11 -0
  185. edsl/utilities/remove_edsl_version.py +24 -0
  186. edsl/utilities/utilities.py +35 -23
  187. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
  188. edsl-0.1.39.dev2.dist-info/RECORD +352 -0
  189. edsl/language_models/KeyLookup.py +0 -30
  190. edsl/language_models/unused/ReplicateBase.py +0 -83
  191. edsl/results/ResultsDBMixin.py +0 -238
  192. edsl-0.1.39.dev1.dist-info/RECORD +0 -277
  193. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
  194. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,131 @@
1
+ from dataclasses import dataclass, asdict
2
+ from typing import Optional
3
+
4
+
5
@dataclass
class APIKeyEntry:
    """A single API key for a service, together with where it was found.

    >>> entry = APIKeyEntry.example()
    >>> entry.service
    'openai'
    >>> entry.name
    'OPENAI_API_KEY'
    >>> entry.value
    'sk-abcd1234'
    >>> entry.source
    'env'
    """

    # Service the key belongs to, e.g. "openai".
    service: str
    # Name under which the key is stored, e.g. "OPENAI_API_KEY"
    # (presumably the environment-variable name; verify against callers).
    name: str
    # The key material itself. NOTE(review): the generated dataclass repr
    # includes this field, so avoid logging instances.
    value: str
    # Where the key was obtained, e.g. "env"; None when not recorded.
    source: Optional[str] = None

    @classmethod
    def example(cls) -> "APIKeyEntry":
        """Return a sample entry with placeholder values.

        Built through ``cls`` so that subclasses get an instance of their own
        type rather than a plain ``APIKeyEntry``.
        """
        return cls(
            service="openai", name="OPENAI_API_KEY", value="sk-abcd1234", source="env"
        )
30
+
31
+
32
@dataclass
class LimitEntry:
    """Rate limits configured for a service, together with their origin.

    >>> limit = LimitEntry.example()
    >>> limit.service
    'openai'
    >>> limit.rpm
    60
    >>> limit.tpm
    100000
    >>> limit.source
    'config'
    """

    # Service the limits apply to, e.g. "openai".
    service: str
    # Presumably "requests per minute" — TODO confirm against the consumer.
    rpm: int
    # Presumably "tokens per minute" — TODO confirm against the consumer.
    tpm: int
    # Where the limits came from, e.g. "config"; None when not recorded.
    source: Optional[str] = None

    @classmethod
    def example(cls) -> "LimitEntry":
        """Return a sample entry with placeholder limits.

        Built through ``cls`` so that subclasses get an instance of their own
        type rather than a plain ``LimitEntry``.
        """
        return cls(service="openai", rpm=60, tpm=100000, source="config")
55
+
56
+
57
@dataclass
class APIIDEntry:
    """A single API ID for a service, together with where it was found.

    >>> id_entry = APIIDEntry.example()
    >>> id_entry.service
    'bedrock'
    >>> id_entry.name
    'AWS_ACCESS_KEY_ID'
    >>> id_entry.value
    'AKIA1234'
    >>> id_entry.source
    'env'
    """

    # Service the ID belongs to, e.g. "bedrock".
    service: str
    # Name under which the ID is stored, e.g. "AWS_ACCESS_KEY_ID"
    # (presumably the environment-variable name; verify against callers).
    name: str
    # The ID value itself.
    value: str
    # Where the ID was obtained, e.g. "env"; None when not recorded.
    source: Optional[str] = None

    @classmethod
    def example(cls) -> "APIIDEntry":
        """Return a sample entry with placeholder values.

        Built through ``cls`` so that subclasses get an instance of their own
        type rather than a plain ``APIIDEntry``.
        """
        return cls(
            service="bedrock", name="AWS_ACCESS_KEY_ID", value="AKIA1234", source="env"
        )
82
+
83
+
84
@dataclass
class LanguageModelInput:
    """Credentials and rate limits bundled for one language model service.

    >>> lm_input = LanguageModelInput.example()
    >>> lm_input.api_token
    'sk-abcd123'
    >>> lm_input.rpm
    60
    >>> lm_input.tpm
    100000
    >>> lm_input.api_id is None
    True

    Test dictionary conversion:
    >>> d = lm_input.to_dict()
    >>> isinstance(d, dict)
    True
    >>> LanguageModelInput.from_dict(d).api_token == lm_input.api_token
    True
    """

    # Secret token for the service. NOTE(review): the generated dataclass repr
    # and to_dict() both include it, so avoid logging either.
    api_token: str
    # Presumably "requests per minute" — TODO confirm against the consumer.
    rpm: int
    # Presumably "tokens per minute" — TODO confirm against the consumer.
    tpm: int
    # Optional API ID that accompanies the token; None when not provided.
    api_id: Optional[str] = None
    # Provenance of the token, the limits and the ID respectively; each is
    # None when not recorded.
    token_source: Optional[str] = None
    limit_source: Optional[str] = None
    id_source: Optional[str] = None

    def to_dict(self) -> dict:
        """Return all fields as a plain dictionary (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "LanguageModelInput":
        """Build an instance from a dictionary whose keys are the field names.

        :raises TypeError: If ``d`` contains a key that is not a field, or
            lacks a required field (raised by the dataclass constructor).
        """
        return cls(**d)

    @classmethod
    def example(cls) -> "LanguageModelInput":
        """Return a sample instance with placeholder values.

        Built through ``cls``, consistent with ``from_dict``, so that
        subclasses get an instance of their own type.
        """
        return cls(api_token="sk-abcd123", tpm=100000, rpm=60, api_id=None)
126
+
127
+
128
if __name__ == "__main__":
    # Running the module as a script executes the doctests embedded in the
    # docstrings of this module.
    from doctest import testmod

    testmod()
@@ -1,54 +1,9 @@
1
1
  import textwrap
2
2
  from random import random
3
3
  from edsl.config import CONFIG
4
-
5
- # if "EDSL_DEFAULT_MODEL" not in CONFIG:
6
- # default_model = "test"
7
- # else:
8
- # default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
9
-
10
- from collections import UserList
11
-
12
-
13
- class PrettyList(UserList):
14
- def __init__(self, data=None, columns=None):
15
- super().__init__(data)
16
- self.columns = columns
17
-
18
- def _repr_html_(self):
19
- if isinstance(self[0], list) or isinstance(self[0], tuple):
20
- num_cols = len(self[0])
21
- else:
22
- num_cols = 1
23
-
24
- if self.columns:
25
- columns = self.columns
26
- else:
27
- columns = list(range(num_cols))
28
-
29
- if num_cols > 1:
30
- return (
31
- "<pre><table>"
32
- + "".join(["<th>" + str(column) + "</th>" for column in columns])
33
- + "".join(
34
- [
35
- "<tr>"
36
- + "".join(["<td>" + str(x) + "</td>" for x in row])
37
- + "</tr>"
38
- for row in self
39
- ]
40
- )
41
- + "</table></pre>"
42
- )
43
- else:
44
- return (
45
- "<pre><table>"
46
- + "".join(["<th>" + str(index) + "</th>" for index in columns])
47
- + "".join(
48
- ["<tr>" + "<td>" + str(row) + "</td>" + "</tr>" for row in self]
49
- )
50
- + "</table></pre>"
51
- )
4
+ from functools import lru_cache
5
+ from edsl.utilities.PrettyList import PrettyList
6
+ from typing import Optional
52
7
 
53
8
 
54
9
  def get_model_class(model_name, registry=None):
@@ -103,48 +58,83 @@ class Model(metaclass=Meta):
103
58
  registry = default
104
59
  registry.add_model(service_name, model_name)
105
60
 
61
@classmethod
def service_classes(cls, registry=None):
    """Return the entries of ``registry.services`` as a list.

    :param registry: Registry to read from; when falsy, the package-wide
        ``default`` registry from ``edsl.inference_services.registry`` is used.
    """
    from edsl.inference_services.registry import default

    if not registry:
        registry = default
    return list(registry.services)
67
+
106
68
@classmethod
def services(cls, registry=None):
    """Return the ``_inference_service_`` name of every registered service.

    :param registry: Registry to read from; when falsy, the package-wide
        ``default`` registry from ``edsl.inference_services.registry`` is used.
    :return: A ``PrettyList`` with the single column "Service Name".
    """
    from edsl.inference_services.registry import default

    if not registry:
        registry = default

    service_names = []
    for service_cls in registry.services:
        service_names.append(service_cls._inference_service_)
    return PrettyList(service_names, columns=["Service Name"])
112
76
 
113
77
@classmethod
def key_info(cls):
    """Return a dataset with one row per service describing its key lookup.

    A ``KeyLookupCollection`` is populated with ``fetch_order=None`` and the
    first lookup it holds is flattened into scenarios of the form
    ``{"service": <service>, **entry.to_dict()}``.

    NOTE(review): the rows contain whatever ``entry.to_dict()`` returns; if
    the entries carry raw tokens, the result should not be logged or shared.

    :return: The result of ``ScenarioList.to_dataset()``.
    """
    # Import the class from its defining module. The ``key_management``
    # package ships an empty ``__init__.py``, so importing the name from the
    # package would bind the *submodule* of the same name, and calling it
    # would raise ``TypeError: 'module' object is not callable``.
    # NOTE(review): assumes the class is defined in
    # key_management/KeyLookupCollection.py — confirm.
    from edsl.language_models.key_management.KeyLookupCollection import (
        KeyLookupCollection,
    )
    from edsl.scenarios import Scenario, ScenarioList

    klc = KeyLookupCollection()
    klc.add_key_lookup(fetch_order=None)

    # Only the first lookup stored in the collection is reported.
    first_lookup = list(klc.data.values())[0]

    sl = ScenarioList()
    for service, entry in first_lookup.items():
        sl.append(Scenario({"service": service} | entry.to_dict()))
    return sl.to_dataset()
88
+
89
@classmethod
def available(
    cls,
    search_term: Optional[str] = None,
    name_only: bool = False,
    registry=None,
    service: Optional[str] = None,
):
    """List the models offered by the registered inference services.

    :param search_term: When given, keep only models whose model name or
        service name contains this substring.
    :param name_only: When True, return just the model names; otherwise
        return ``[model_name, service_name]`` rows.
    :param registry: Registry to query; when falsy, the package-wide
        ``default`` registry from ``edsl.inference_services.registry`` is used.
    :param service: When given, restrict the listing to this service.
    :raises ValueError: If ``service`` is not one of ``cls.services(...)``.
    :return: A ``PrettyList`` with columns ``["Model Name"]`` or
        ``["Model Name", "Service Name"]``.
    """
    from edsl.inference_services.registry import default

    registry = registry or default

    if service is not None and service not in cls.services(registry=registry):
        raise ValueError(f"Service {service} not found in available services.")

    models = registry.available(service=service)

    # Apply the search filter once, up front, so that both output shapes below
    # use the filtered list. Previously the two-column branch was built from
    # the unfiltered list, which silently ignored ``search_term``.
    if search_term is not None:
        models = [
            m
            for m in models
            if search_term in m.model_name or search_term in m.service_name
        ]

    if name_only:
        return PrettyList(
            [m.model_name for m in models],
            columns=["Model Name"],
        )
    return PrettyList(
        [[m.model_name, m.service_name] for m in models],
        columns=["Model Name", "Service Name"],
    )
149
139
 
150
140
  @classmethod
@@ -32,11 +32,11 @@ async def async_repair(
32
32
  else:
33
33
  return valid_dict, success
34
34
 
35
- from edsl import Model
35
+ from edsl.language_models.registry import Model
36
36
 
37
37
  m = Model()
38
38
 
39
- from edsl import QuestionExtract
39
+ from edsl.questions.QuestionExtract import QuestionExtract
40
40
 
41
41
  with warnings.catch_warnings():
42
42
  warnings.simplefilter("ignore", UserWarning)
@@ -1,13 +1,12 @@
1
1
  import asyncio
2
2
  from typing import Any, Optional, List
3
- from edsl import Survey
4
- from edsl.config import CONFIG
5
3
  from edsl.enums import InferenceServiceType
6
- from edsl.language_models.LanguageModel import LanguageModel
7
- from edsl.questions import QuestionFreeText
8
4
 
9
5
 
10
6
  def create_survey(num_questions: int, chained: bool = True, take_scenario=False):
7
+ from edsl.surveys.Survey import Survey
8
+ from edsl.questions.QuestionFreeText import QuestionFreeText
9
+
11
10
  survey = Survey()
12
11
  for i in range(num_questions):
13
12
  if take_scenario:
@@ -28,6 +27,8 @@ def create_survey(num_questions: int, chained: bool = True, take_scenario=False)
28
27
  def create_language_model(
29
28
  exception: Exception, fail_at_number: int, never_ending=False
30
29
  ):
30
+ from edsl.language_models.LanguageModel import LanguageModel
31
+
31
32
  class LanguageModelFromUtilities(LanguageModel):
32
33
  _model_ = "test"
33
34
  _parameters_ = {"temperature": 0.5}
@@ -17,8 +17,8 @@ class Notebook(Base):
17
17
 
18
18
  def __init__(
19
19
  self,
20
- data: Optional[Dict] = None,
21
20
  path: Optional[str] = None,
21
+ data: Optional[Dict] = None,
22
22
  name: Optional[str] = None,
23
23
  ):
24
24
  """
@@ -33,12 +33,16 @@ class Notebook(Base):
33
33
  import nbformat
34
34
 
35
35
  # Load current notebook path as fallback (VS Code only)
36
- path = path or globals().get("__vsc_ipynb_file__")
37
- if data is not None:
36
+ current_notebook_path = globals().get("__vsc_ipynb_file__")
37
+ if path is not None:
38
+ with open(path, mode="r", encoding="utf-8") as f:
39
+ data = nbformat.read(f, as_version=4)
40
+ self.data = json.loads(json.dumps(data))
41
+ elif data is not None:
38
42
  nbformat.validate(data)
39
43
  self.data = data
40
- elif path is not None:
41
- with open(path, mode="r", encoding="utf-8") as f:
44
+ elif current_notebook_path is not None:
45
+ with open(current_notebook_path, mode="r", encoding="utf-8") as f:
42
46
  data = nbformat.read(f, as_version=4)
43
47
  self.data = json.loads(json.dumps(data))
44
48
  else:
@@ -130,15 +134,6 @@ class Notebook(Base):
130
134
 
131
135
  nbformat.write(nbformat.from_dict(self.data), fp=path)
132
136
 
133
- def print(self):
134
- """
135
- Print the notebook.
136
- """
137
- from rich import print_json
138
- import json
139
-
140
- print_json(json.dumps(self.to_dict()))
141
-
142
137
  def __repr__(self):
143
138
  """
144
139
  Return representation of Notebook.
@@ -250,6 +245,16 @@ class Notebook(Base):
250
245
  lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""")')
251
246
  return lines
252
247
 
248
def to_latex(self, filename: str):
    """
    Export this notebook as a LaTeX project folder.

    The work is delegated to ``NotebookToLaTeX.convert``.

    :param filename: Name of the output folder and main tex file (without extension)
    """
    from edsl.notebooks.NotebookToLaTeX import NotebookToLaTeX

    latex_converter = NotebookToLaTeX(self)
    latex_converter.convert(filename)
257
+
253
258
 
254
259
  if __name__ == "__main__":
255
260
  from edsl import Notebook
@@ -0,0 +1,142 @@
1
+ from typing import Optional, Dict
2
+ import os
3
+ import nbformat
4
+ from nbconvert.exporters import LatexExporter
5
+ from nbconvert.writers import FilesWriter
6
+
7
+
8
class NotebookToLaTeX:
    """
    Convert a notebook into a LaTeX project directory.

    The directory receives the exported ``.tex`` file, any output resources
    the exporter extracted (images, etc.), a ``Makefile`` and a ``README.md``.
    """

    def __init__(self, notebook):
        """
        Initialize with a Notebook instance.

        :param notebook: Object whose ``data`` attribute holds the notebook as
            a dictionary (it is passed to ``nbformat.from_dict``)
        """
        self.notebook = notebook
        self.latex_exporter = LatexExporter()
        self._configure_exporter()

    def _configure_exporter(self):
        """Configure the LaTeX exporter with default settings."""
        self.latex_exporter.exclude_input_prompt = True
        self.latex_exporter.exclude_output_prompt = True
        # NOTE(review): confirm the installed nbconvert provides a LaTeX
        # template called "classic"; nbconvert's stock LaTeX template is
        # named "latex".
        self.latex_exporter.template_name = "classic"

    def _create_makefile(self, filename: str, output_dir: str):
        """
        Write a Makefile that builds ``<filename>.pdf`` from ``<filename>.tex``.

        :param filename: Base name of the tex file (without extension)
        :param output_dir: Existing directory in which to write ``Makefile``
        """
        makefile_lines = [
            f"# Makefile for {filename}",
            ".PHONY: all pdf clean",
            "",
            "all: pdf",
            "",
            f"pdf: {filename}.pdf",
            "",
            f"{filename}.pdf: {filename}.tex",
            f"\tpdflatex {filename}.tex",
            # Leading "-" makes make ignore a failing bibtex, which is the
            # normal outcome for a document without citations.
            f"\t-bibtex {filename}",
            # Two passes after bibtex so citations and references resolve.
            f"\tpdflatex {filename}.tex",
            f"\tpdflatex {filename}.tex",
            "",
            "clean:",
            # Remove only the built PDF: a blanket *.pdf would also delete
            # PDF figures that were extracted from the notebook outputs.
            f"\trm -f *.aux *.log *.out *.toc *.bbl *.blg {filename}.pdf",
            "",
        ]
        makefile_path = os.path.join(output_dir, "Makefile")
        with open(makefile_path, "w", encoding="utf-8") as f:
            f.write("\n".join(makefile_lines))

    def _create_readme(self, filename: str, output_dir: str):
        """
        Write a README with build instructions for the generated project.

        :param filename: Base name of the tex file (without extension)
        :param output_dir: Existing directory in which to write ``README.md``
        """
        readme_lines = [
            f"# {filename}",
            "",
            "This folder contains the LaTeX version of your Jupyter notebook.",
            "",
            "Files:",
            f"- {filename}.tex: Main LaTeX file",
            "- Makefile: Build automation",
            "",
            "To compile the PDF:",
            "1. Make sure you have a LaTeX distribution installed (e.g., TexLive)",
            "2. Run `make` in this directory",
            f"3. The output will be {filename}.pdf",
            "",
            "To clean up build files:",
            "- Run `make clean`",
            "",
        ]
        readme_path = os.path.join(output_dir, "README.md")
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write("\n".join(readme_lines))

    def convert(self, filename: str, output_dir: Optional[str] = None):
        """
        Convert the notebook to LaTeX and create a project directory.

        :param filename: Name for the output files (without extension). It may
            include leading directories; only its last component is used for
            the file names.
        :param output_dir: Optional directory path. If None, uses filename as directory
        :raises ValueError: If ``filename`` has no usable final path component
        """
        # Only the last path component names the files, so that e.g.
        # "build/report" yields "build/report/report.tex" instead of a path
        # inside a directory that was never created.
        stem = os.path.basename(os.path.normpath(filename))
        if stem in ("", ".", ".."):
            raise ValueError(f"Cannot derive an output file name from {filename!r}")

        # Use filename as directory if no output_dir specified
        output_dir = output_dir or filename
        os.makedirs(output_dir, exist_ok=True)

        # Convert notebook dict -> notebook node -> LaTeX source
        notebook_node = nbformat.from_dict(self.notebook.data)
        body, resources = self.latex_exporter.from_notebook_node(notebook_node)

        # Write the main tex file
        output_file_path = os.path.join(output_dir, f"{stem}.tex")
        with open(output_file_path, "w", encoding="utf-8") as f:
            f.write(body)

        # Write additional resources (images, etc.); values are written as
        # bytes. Resource names may carry a sub-directory, so create it first.
        for fname, data in (resources.get("outputs") or {}).items():
            resource_path = os.path.join(output_dir, fname)
            os.makedirs(os.path.dirname(resource_path) or ".", exist_ok=True)
            with open(resource_path, "wb") as f:
                f.write(data)

        # Create supporting files
        self._create_makefile(stem, output_dir)
        self._create_readme(stem, output_dir)

    def set_template(self, template_name: str):
        """
        Set the LaTeX template to use.

        :param template_name: Name of the template (e.g., 'classic', 'article')
        """
        self.latex_exporter.template_name = template_name

    def set_template_options(self, options: Dict):
        """
        Set additional template options.

        Each key is assigned as an attribute on the exporter with ``setattr``;
        no check is made that the exporter recognises the name.

        :param options: Dictionary of template options
        """
        for key, value in options.items():
            setattr(self.latex_exporter, key, value)
122
+
123
+
124
# Example usage:
if __name__ == "__main__":
    from edsl import Notebook

    # Wrap the sample notebook in a converter
    example_converter = NotebookToLaTeX(Notebook.example())

    # First export: exporter defaults
    example_converter.convert("example_output")

    # Second export: hide input cells, keep output cells
    custom_options = {"exclude_input": True, "exclude_output": False}
    example_converter.set_template_options(custom_options)
    example_converter.convert("example_output_custom")
edsl/prompts/Prompt.py CHANGED
@@ -1,43 +1,21 @@
1
1
  from __future__ import annotations
2
- from typing import Optional
3
- from abc import ABC
4
- from typing import Any, List
5
-
6
- from jinja2 import Environment, FileSystemLoader
7
- from typing import Union, Dict
2
+ from typing import Any, List, Union, Dict, Optional
8
3
  from pathlib import Path
9
4
 
10
- from rich.table import Table
11
- from jinja2 import Template, Environment, meta, TemplateSyntaxError, Undefined
12
-
13
-
14
- class PreserveUndefined(Undefined):
15
- def __str__(self):
16
- return "{{ " + str(self._undefined_name) + " }}"
5
+ # from jinja2 import Undefined
17
6
 
18
7
 
19
8
  from edsl.exceptions.prompts import TemplateRenderError
20
- from edsl.Base import PersistenceMixin, RichPrintingMixin
9
+ from edsl.Base import PersistenceMixin, RepresentationMixin
21
10
 
22
11
  MAX_NESTING = 100
23
12
 
24
13
 
25
- class Prompt(PersistenceMixin, RichPrintingMixin):
14
+ class Prompt(PersistenceMixin, RepresentationMixin):
26
15
  """Class for creating a prompt to be used in a survey."""
27
16
 
28
17
  default_instructions: Optional[str] = "Do good things, friendly LLM!"
29
18
 
30
- def _repr_html_(self):
31
- """Return an HTML representation of the Prompt."""
32
- # from edsl.utilities.utilities import data_to_html
33
- # return data_to_html(self.to_dict())
34
- d = self.to_dict()
35
- data = [[k, v] for k, v in d.items()]
36
- from tabulate import tabulate
37
-
38
- table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
39
- return f"<pre>{table}</pre>"
40
-
41
19
  def __len__(self):
42
20
  """Return the length of the prompt text."""
43
21
  return len(self.text)
@@ -185,6 +163,12 @@ class Prompt(PersistenceMixin, RichPrintingMixin):
185
163
  :param template: The template to find the variables in.
186
164
 
187
165
  """
166
+ from jinja2 import Environment, meta, Undefined
167
+
168
+ class PreserveUndefined(Undefined):
169
+ def __str__(self):
170
+ return "{{ " + str(self._undefined_name) + " }}"
171
+
188
172
  env = Environment(undefined=PreserveUndefined)
189
173
  ast = env.parse(template)
190
174
  return list(meta.find_undeclared_variables(ast))
@@ -273,6 +257,12 @@ class Prompt(PersistenceMixin, RichPrintingMixin):
273
257
  >>> p.render({"name": "John", "age": 44}, codebook=codebook)
274
258
  Prompt(text=\"""You are an agent named John. Age: 44\""")
275
259
  """
260
+ from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
261
+
262
+ class PreserveUndefined(Undefined):
263
+ def __str__(self):
264
+ return "{{ " + str(self._undefined_name) + " }}"
265
+
276
266
  env = Environment(undefined=PreserveUndefined)
277
267
  try:
278
268
  previous_text = None
@@ -296,7 +286,7 @@ class Prompt(PersistenceMixin, RichPrintingMixin):
296
286
  f"Template syntax error: {e}. Bad template: {text}"
297
287
  )
298
288
 
299
- def to_dict(self) -> dict[str, Any]:
289
+ def to_dict(self, add_edsl_version=False) -> dict[str, Any]:
300
290
  """Return the `Prompt` as a dictionary.
301
291
 
302
292
  Example:
@@ -323,18 +313,18 @@ class Prompt(PersistenceMixin, RichPrintingMixin):
323
313
  # class_name = data["class_name"]
324
314
  return Prompt(text=data["text"])
325
315
 
326
- def rich_print(self):
327
- """Display an object as a table."""
328
- table = Table(title="Prompt")
329
- table.add_column("Attribute", style="bold")
330
- table.add_column("Value")
331
-
332
- to_display = self.__dict__.copy()
333
- for attr_name, attr_value in to_display.items():
334
- table.add_row(attr_name, repr(attr_value))
335
- table.add_row("Component type", str(self.component_type))
336
- table.add_row("Model", str(getattr(self, "model", "Not specified")))
337
- return table
316
+ # def rich_print(self):
317
+ # """Display an object as a table."""
318
+ # table = Table(title="Prompt")
319
+ # table.add_column("Attribute", style="bold")
320
+ # table.add_column("Value")
321
+
322
+ # to_display = self.__dict__.copy()
323
+ # for attr_name, attr_value in to_display.items():
324
+ # table.add_row(attr_name, repr(attr_value))
325
+ # table.add_row("Component type", str(self.component_type))
326
+ # table.add_row("Model", str(getattr(self, "model", "Not specified")))
327
+ # return table
338
328
 
339
329
  @classmethod
340
330
  def example(cls):