edsl 0.1.38.dev1__py3-none-any.whl → 0.1.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -3
- edsl/BaseDiff.py +7 -7
- edsl/__init__.py +2 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +17 -14
- edsl/agents/AgentList.py +29 -17
- edsl/auto/SurveyCreatorPipeline.py +1 -1
- edsl/auto/utilities.py +1 -1
- edsl/base/Base.py +3 -13
- edsl/coop/coop.py +3 -0
- edsl/data/Cache.py +18 -15
- edsl/exceptions/agents.py +4 -0
- edsl/exceptions/cache.py +5 -0
- edsl/jobs/Jobs.py +22 -11
- edsl/jobs/buckets/TokenBucket.py +3 -0
- edsl/jobs/interviews/Interview.py +18 -18
- edsl/jobs/runners/JobsRunnerAsyncio.py +38 -15
- edsl/jobs/runners/JobsRunnerStatus.py +196 -196
- edsl/jobs/tasks/TaskHistory.py +12 -3
- edsl/language_models/LanguageModel.py +9 -7
- edsl/language_models/ModelList.py +20 -13
- edsl/notebooks/Notebook.py +7 -8
- edsl/questions/QuestionBase.py +21 -17
- edsl/questions/QuestionBaseGenMixin.py +1 -1
- edsl/questions/QuestionBasePromptsMixin.py +0 -17
- edsl/questions/QuestionFunctional.py +10 -3
- edsl/questions/derived/QuestionTopK.py +2 -0
- edsl/results/Result.py +31 -25
- edsl/results/Results.py +22 -22
- edsl/scenarios/Scenario.py +12 -14
- edsl/scenarios/ScenarioList.py +16 -16
- edsl/surveys/MemoryPlan.py +1 -1
- edsl/surveys/Rule.py +1 -5
- edsl/surveys/RuleCollection.py +1 -1
- edsl/surveys/Survey.py +9 -17
- edsl/surveys/instructions/ChangeInstruction.py +9 -7
- edsl/surveys/instructions/Instruction.py +9 -7
- edsl/{conjure → utilities}/naming_utilities.py +1 -1
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev2.dist-info}/RECORD +42 -56
- edsl/conjure/AgentConstructionMixin.py +0 -160
- edsl/conjure/Conjure.py +0 -62
- edsl/conjure/InputData.py +0 -659
- edsl/conjure/InputDataCSV.py +0 -48
- edsl/conjure/InputDataMixinQuestionStats.py +0 -182
- edsl/conjure/InputDataPyRead.py +0 -91
- edsl/conjure/InputDataSPSS.py +0 -8
- edsl/conjure/InputDataStata.py +0 -8
- edsl/conjure/QuestionOptionMixin.py +0 -76
- edsl/conjure/QuestionTypeMixin.py +0 -23
- edsl/conjure/RawQuestion.py +0 -65
- edsl/conjure/SurveyResponses.py +0 -7
- edsl/conjure/__init__.py +0 -9
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/utilities.py +0 -201
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev2.dist-info}/WHEEL +0 -0
@@ -1,33 +1,21 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import os
|
3
4
|
import time
|
4
|
-
|
5
|
-
|
6
|
-
from
|
7
|
-
from
|
8
|
-
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
|
13
|
-
from rich.panel import Panel
|
14
|
-
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn
|
15
|
-
from rich.layout import Layout
|
16
|
-
from rich.console import Group
|
17
|
-
from rich import box
|
18
|
-
|
19
|
-
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
5
|
+
import requests
|
6
|
+
import warnings
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from dataclasses import dataclass
|
9
|
+
|
10
|
+
from typing import Any, List, DefaultDict, Optional, Dict
|
11
|
+
from collections import defaultdict
|
12
|
+
from uuid import UUID
|
13
|
+
|
20
14
|
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
21
|
-
from edsl.jobs.tokens.TokenUsage import TokenUsage
|
22
|
-
from edsl.enums import get_token_pricing
|
23
|
-
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
24
15
|
|
25
16
|
InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
|
26
17
|
|
27
18
|
from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
|
28
|
-
from edsl.jobs.interviews.InterviewStatisticsCollection import (
|
29
|
-
InterviewStatisticsCollection,
|
30
|
-
)
|
31
19
|
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
32
20
|
|
33
21
|
|
@@ -47,16 +35,23 @@ class ModelTokenUsageStats:
|
|
47
35
|
cost: str
|
48
36
|
|
49
37
|
|
50
|
-
class
|
51
|
-
def elapsed_time(self):
|
52
|
-
InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
|
53
|
-
|
54
|
-
|
55
|
-
class JobsRunnerStatus:
|
38
|
+
class JobsRunnerStatusBase(ABC):
|
56
39
|
def __init__(
|
57
|
-
self,
|
40
|
+
self,
|
41
|
+
jobs_runner: "JobsRunnerAsyncio",
|
42
|
+
n: int,
|
43
|
+
refresh_rate: float = 1,
|
44
|
+
endpoint_url: Optional[str] = "http://localhost:8000",
|
45
|
+
job_uuid: Optional[UUID] = None,
|
46
|
+
api_key: str = None,
|
58
47
|
):
|
59
48
|
self.jobs_runner = jobs_runner
|
49
|
+
|
50
|
+
# The uuid of the job on Coop
|
51
|
+
self.job_uuid = job_uuid
|
52
|
+
|
53
|
+
self.base_url = f"{endpoint_url}"
|
54
|
+
|
60
55
|
self.start_time = time.time()
|
61
56
|
self.completed_interviews = []
|
62
57
|
self.refresh_rate = refresh_rate
|
@@ -80,6 +75,99 @@ class JobsRunnerStatus:
|
|
80
75
|
|
81
76
|
self.completed_interview_by_model = defaultdict(list)
|
82
77
|
|
78
|
+
self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
|
79
|
+
|
80
|
+
@abstractmethod
|
81
|
+
def has_ep_api_key(self):
|
82
|
+
"""
|
83
|
+
Checks if the user has an Expected Parrot API key.
|
84
|
+
"""
|
85
|
+
pass
|
86
|
+
|
87
|
+
def get_status_dict(self) -> Dict[str, Any]:
|
88
|
+
"""
|
89
|
+
Converts current status into a JSON-serializable dictionary.
|
90
|
+
"""
|
91
|
+
# Get all statistics
|
92
|
+
stats = {}
|
93
|
+
for stat_name in self.statistics:
|
94
|
+
stat = self._compute_statistic(stat_name)
|
95
|
+
name, value = list(stat.items())[0]
|
96
|
+
stats[name] = value
|
97
|
+
|
98
|
+
# Calculate overall progress
|
99
|
+
total_interviews = len(self.jobs_runner.total_interviews)
|
100
|
+
completed = len(self.completed_interviews)
|
101
|
+
|
102
|
+
# Get model-specific progress
|
103
|
+
model_progress = {}
|
104
|
+
for model in self.distinct_models:
|
105
|
+
completed_for_model = len(self.completed_interview_by_model[model])
|
106
|
+
target_for_model = int(
|
107
|
+
self.num_total_interviews / len(self.distinct_models)
|
108
|
+
)
|
109
|
+
model_progress[model] = {
|
110
|
+
"completed": completed_for_model,
|
111
|
+
"total": target_for_model,
|
112
|
+
"percent": (
|
113
|
+
(completed_for_model / target_for_model * 100)
|
114
|
+
if target_for_model > 0
|
115
|
+
else 0
|
116
|
+
),
|
117
|
+
}
|
118
|
+
|
119
|
+
status_dict = {
|
120
|
+
"overall_progress": {
|
121
|
+
"completed": completed,
|
122
|
+
"total": total_interviews,
|
123
|
+
"percent": (
|
124
|
+
(completed / total_interviews * 100) if total_interviews > 0 else 0
|
125
|
+
),
|
126
|
+
},
|
127
|
+
"language_model_progress": model_progress,
|
128
|
+
"statistics": stats,
|
129
|
+
"status": "completed" if completed >= total_interviews else "running",
|
130
|
+
}
|
131
|
+
|
132
|
+
model_queues = {}
|
133
|
+
for model, bucket in self.jobs_runner.bucket_collection.items():
|
134
|
+
model_name = model.model
|
135
|
+
model_queues[model_name] = {
|
136
|
+
"language_model_name": model_name,
|
137
|
+
"requests_bucket": {
|
138
|
+
"completed": bucket.requests_bucket.num_released,
|
139
|
+
"requested": bucket.requests_bucket.num_requests,
|
140
|
+
"tokens_returned": bucket.requests_bucket.tokens_returned,
|
141
|
+
"target_rate": round(bucket.requests_bucket.target_rate, 1),
|
142
|
+
"current_rate": round(bucket.requests_bucket.get_throughput(), 1),
|
143
|
+
},
|
144
|
+
"tokens_bucket": {
|
145
|
+
"completed": bucket.tokens_bucket.num_released,
|
146
|
+
"requested": bucket.tokens_bucket.num_requests,
|
147
|
+
"tokens_returned": bucket.tokens_bucket.tokens_returned,
|
148
|
+
"target_rate": round(bucket.tokens_bucket.target_rate, 1),
|
149
|
+
"current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
|
150
|
+
},
|
151
|
+
}
|
152
|
+
status_dict["language_model_queues"] = model_queues
|
153
|
+
return status_dict
|
154
|
+
|
155
|
+
@abstractmethod
|
156
|
+
def setup(self):
|
157
|
+
"""
|
158
|
+
Conducts any setup that needs to happen prior to sending status updates.
|
159
|
+
|
160
|
+
Ex. For a local job, creates a job in the Coop database.
|
161
|
+
"""
|
162
|
+
pass
|
163
|
+
|
164
|
+
@abstractmethod
|
165
|
+
def send_status_update(self):
|
166
|
+
"""
|
167
|
+
Updates the current status of the job.
|
168
|
+
"""
|
169
|
+
pass
|
170
|
+
|
83
171
|
def add_completed_interview(self, result):
|
84
172
|
self.completed_interviews.append(result.interview_hash)
|
85
173
|
|
@@ -150,180 +238,92 @@ class JobsRunnerStatus:
|
|
150
238
|
}
|
151
239
|
return stat_definitions[stat_name]()
|
152
240
|
|
153
|
-
def
|
154
|
-
return Progress(
|
155
|
-
TextColumn("[progress.description]{task.description}"),
|
156
|
-
BarColumn(),
|
157
|
-
TaskProgressColumn(),
|
158
|
-
TextColumn("{task.completed}/{task.total}"),
|
159
|
-
)
|
241
|
+
def update_progress(self, stop_event):
|
160
242
|
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
table.add_column("Value", style="magenta")
|
165
|
-
# table.add_row("Bucket collection", str(self.jobs_runner.bucket_collection))
|
166
|
-
for model, bucket in self.jobs_runner.bucket_collection.items():
|
167
|
-
table.add_row(Text(model.model, style="bold blue"), "")
|
168
|
-
bucket_types = ["requests_bucket", "tokens_bucket"]
|
169
|
-
for bucket_type in bucket_types:
|
170
|
-
table.add_row(Text(" " + bucket_type, style="green"), "")
|
171
|
-
# table.add_row(
|
172
|
-
# f" Current level (capacity = {round(getattr(bucket, bucket_type).capacity, 3)})",
|
173
|
-
# str(round(getattr(bucket, bucket_type).tokens, 3)),
|
174
|
-
# )
|
175
|
-
num_requests = getattr(bucket, bucket_type).num_requests
|
176
|
-
num_released = getattr(bucket, bucket_type).num_released
|
177
|
-
tokens_returned = getattr(bucket, bucket_type).tokens_returned
|
178
|
-
# table.add_row(
|
179
|
-
# f" Requested",
|
180
|
-
# str(num_requests),
|
181
|
-
# )
|
182
|
-
# table.add_row(
|
183
|
-
# f" Completed",
|
184
|
-
# str(num_released),
|
185
|
-
# )
|
186
|
-
table.add_row(
|
187
|
-
" Completed vs. Requested", f"{num_released} vs. {num_requests}"
|
188
|
-
)
|
189
|
-
table.add_row(
|
190
|
-
" Added tokens (from cache)",
|
191
|
-
str(tokens_returned),
|
192
|
-
)
|
193
|
-
if bucket_type == "tokens_bucket":
|
194
|
-
rate_name = "TPM"
|
195
|
-
else:
|
196
|
-
rate_name = "RPM"
|
197
|
-
target_rate = round(getattr(bucket, bucket_type).target_rate, 1)
|
198
|
-
table.add_row(
|
199
|
-
f" Empirical {rate_name} (target = {target_rate})",
|
200
|
-
str(round(getattr(bucket, bucket_type).get_throughput(), 0)),
|
201
|
-
)
|
202
|
-
|
203
|
-
return table
|
204
|
-
|
205
|
-
def generate_layout(self):
|
206
|
-
progress = self.create_progress_bar()
|
207
|
-
task_ids = []
|
208
|
-
for model in self.distinct_models:
|
209
|
-
task_id = progress.add_task(
|
210
|
-
f"[cyan]{model}...",
|
211
|
-
total=int(self.num_total_interviews / len(self.distinct_models)),
|
212
|
-
)
|
213
|
-
task_ids.append((model, task_id))
|
214
|
-
|
215
|
-
progress_height = min(5, 2 + len(self.distinct_models))
|
216
|
-
layout = Layout()
|
217
|
-
|
218
|
-
# Create the top row with only the progress panel
|
219
|
-
layout.split_column(
|
220
|
-
Layout(
|
221
|
-
Panel(
|
222
|
-
progress,
|
223
|
-
title="Interview Progress",
|
224
|
-
border_style="cyan",
|
225
|
-
box=box.ROUNDED,
|
226
|
-
),
|
227
|
-
name="progress",
|
228
|
-
size=progress_height, # Adjusted size
|
229
|
-
),
|
230
|
-
Layout(name="bottom_row"), # Adjusted size
|
231
|
-
)
|
243
|
+
while not stop_event.is_set():
|
244
|
+
self.send_status_update()
|
245
|
+
time.sleep(self.refresh_rate)
|
232
246
|
|
233
|
-
|
234
|
-
layout["bottom_row"].split_row(
|
235
|
-
Layout(
|
236
|
-
Panel(
|
237
|
-
self.generate_metrics_table(),
|
238
|
-
title="Metrics",
|
239
|
-
border_style="magenta",
|
240
|
-
box=box.ROUNDED,
|
241
|
-
),
|
242
|
-
name="metrics",
|
243
|
-
),
|
244
|
-
Layout(
|
245
|
-
Panel(
|
246
|
-
self.generate_model_queues_table(),
|
247
|
-
title="Model Queues",
|
248
|
-
border_style="yellow",
|
249
|
-
box=box.ROUNDED,
|
250
|
-
),
|
251
|
-
name="model_queues",
|
252
|
-
),
|
253
|
-
)
|
247
|
+
self.send_status_update()
|
254
248
|
|
255
|
-
return layout, progress, task_ids
|
256
249
|
|
257
|
-
|
258
|
-
table = Table(show_header=True, header_style="bold magenta", box=box.SIMPLE)
|
259
|
-
table.add_column("Metric", style="cyan", no_wrap=True)
|
260
|
-
table.add_column("Value", justify="right")
|
250
|
+
class JobsRunnerStatus(JobsRunnerStatusBase):
|
261
251
|
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
table.add_row(pretty_name, value)
|
266
|
-
return table
|
252
|
+
@property
|
253
|
+
def create_url(self) -> str:
|
254
|
+
return f"{self.base_url}/api/v0/local-job"
|
267
255
|
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
256
|
+
@property
|
257
|
+
def viewing_url(self) -> str:
|
258
|
+
return f"{self.base_url}/home/local-job-progress/{str(self.job_uuid)}"
|
259
|
+
|
260
|
+
@property
|
261
|
+
def update_url(self) -> str:
|
262
|
+
return f"{self.base_url}/api/v0/local-job/{str(self.job_uuid)}"
|
263
|
+
|
264
|
+
def setup(self) -> None:
|
265
|
+
"""
|
266
|
+
Creates a local job on Coop if one does not already exist.
|
267
|
+
"""
|
268
|
+
|
269
|
+
headers = {"Content-Type": "application/json"}
|
270
|
+
|
271
|
+
if self.api_key:
|
272
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
273
|
+
else:
|
274
|
+
headers["Authorization"] = f"Bearer None"
|
275
|
+
|
276
|
+
if self.job_uuid is None:
|
277
|
+
# Create a new local job
|
278
|
+
response = requests.post(
|
279
|
+
self.create_url,
|
280
|
+
headers=headers,
|
281
|
+
timeout=1,
|
282
|
+
)
|
283
|
+
response.raise_for_status()
|
284
|
+
data = response.json()
|
285
|
+
self.job_uuid = data.get("job_uuid")
|
286
|
+
|
287
|
+
print(f"Running with progress bar. View progress at {self.viewing_url}")
|
288
|
+
|
289
|
+
def send_status_update(self) -> None:
|
290
|
+
"""
|
291
|
+
Sends current status to the web endpoint using the instance's job_uuid.
|
292
|
+
"""
|
293
|
+
try:
|
294
|
+
# Get the status dictionary and add the job_id
|
295
|
+
status_dict = self.get_status_dict()
|
296
|
+
|
297
|
+
# Make the UUID JSON serializable
|
298
|
+
status_dict["job_id"] = str(self.job_uuid)
|
299
|
+
|
300
|
+
headers = {"Content-Type": "application/json"}
|
301
|
+
|
302
|
+
if self.api_key:
|
303
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
304
|
+
else:
|
305
|
+
headers["Authorization"] = f"Bearer None"
|
306
|
+
|
307
|
+
# Send the update
|
308
|
+
response = requests.patch(
|
309
|
+
self.update_url,
|
310
|
+
json=status_dict,
|
311
|
+
headers=headers,
|
312
|
+
timeout=1,
|
324
313
|
)
|
325
|
-
|
326
|
-
|
314
|
+
response.raise_for_status()
|
315
|
+
except requests.exceptions.RequestException as e:
|
316
|
+
print(f"Failed to send status update for job {self.job_uuid}: {e}")
|
317
|
+
|
318
|
+
def has_ep_api_key(self) -> bool:
|
319
|
+
"""
|
320
|
+
Returns True if the user has an Expected Parrot API key. Otherwise, returns False.
|
321
|
+
"""
|
322
|
+
|
323
|
+
if self.api_key is not None:
|
324
|
+
return True
|
325
|
+
else:
|
326
|
+
return False
|
327
327
|
|
328
328
|
|
329
329
|
if __name__ == "__main__":
|
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -73,7 +73,7 @@ class TaskHistory:
|
|
73
73
|
"""Return a string representation of the TaskHistory."""
|
74
74
|
return f"TaskHistory(interviews={self.total_interviews})."
|
75
75
|
|
76
|
-
def to_dict(self):
|
76
|
+
def to_dict(self, add_edsl_version=True):
|
77
77
|
"""Return the TaskHistory as a dictionary."""
|
78
78
|
# return {
|
79
79
|
# "exceptions": [
|
@@ -82,10 +82,19 @@ class TaskHistory:
|
|
82
82
|
# ],
|
83
83
|
# "indices": self.indices,
|
84
84
|
# }
|
85
|
-
|
86
|
-
"interviews": [
|
85
|
+
d = {
|
86
|
+
"interviews": [
|
87
|
+
i.to_dict(add_edsl_version=add_edsl_version)
|
88
|
+
for i in self.total_interviews
|
89
|
+
],
|
87
90
|
"include_traceback": self.include_traceback,
|
88
91
|
}
|
92
|
+
if add_edsl_version:
|
93
|
+
from edsl import __version__
|
94
|
+
|
95
|
+
d["edsl_version"] = __version__
|
96
|
+
d["edsl_class_name"] = "TaskHistory"
|
97
|
+
return d
|
89
98
|
|
90
99
|
@classmethod
|
91
100
|
def from_dict(cls, data: dict):
|
@@ -607,18 +607,20 @@ class LanguageModel(
|
|
607
607
|
#######################
|
608
608
|
# SERIALIZATION METHODS
|
609
609
|
#######################
|
610
|
-
def
|
611
|
-
|
612
|
-
|
613
|
-
@add_edsl_version
|
614
|
-
def to_dict(self) -> dict[str, Any]:
|
615
|
-
"""Convert instance to a dictionary.
|
610
|
+
def to_dict(self, add_edsl_version=True) -> dict[str, Any]:
|
611
|
+
"""Convert instance to a dictionary
|
616
612
|
|
617
613
|
>>> m = LanguageModel.example()
|
618
614
|
>>> m.to_dict()
|
619
615
|
{'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
|
620
616
|
"""
|
621
|
-
|
617
|
+
d = {"model": self.model, "parameters": self.parameters}
|
618
|
+
if add_edsl_version:
|
619
|
+
from edsl import __version__
|
620
|
+
|
621
|
+
d["edsl_version"] = __version__
|
622
|
+
d["edsl_class_name"] = self.__class__.__name__
|
623
|
+
return d
|
622
624
|
|
623
625
|
@classmethod
|
624
626
|
@remove_edsl_version
|
@@ -46,14 +46,30 @@ class ModelList(Base, UserList):
|
|
46
46
|
"""
|
47
47
|
from edsl.utilities.utilities import dict_hash
|
48
48
|
|
49
|
-
return dict_hash(self.
|
49
|
+
return dict_hash(self.to_dict(sort=True, add_edsl_version=False))
|
50
50
|
|
51
|
-
def
|
51
|
+
def to_dict(self, sort=False, add_edsl_version=True):
|
52
52
|
if sort:
|
53
53
|
model_list = sorted([model for model in self], key=lambda x: hash(x))
|
54
|
-
|
54
|
+
d = {
|
55
|
+
"models": [
|
56
|
+
model.to_dict(add_edsl_version=add_edsl_version)
|
57
|
+
for model in model_list
|
58
|
+
]
|
59
|
+
}
|
55
60
|
else:
|
56
|
-
|
61
|
+
d = {
|
62
|
+
"models": [
|
63
|
+
model.to_dict(add_edsl_version=add_edsl_version) for model in self
|
64
|
+
]
|
65
|
+
}
|
66
|
+
if add_edsl_version:
|
67
|
+
from edsl import __version__
|
68
|
+
|
69
|
+
d["edsl_version"] = __version__
|
70
|
+
d["edsl_class_name"] = "ModelList"
|
71
|
+
|
72
|
+
return d
|
57
73
|
|
58
74
|
@classmethod
|
59
75
|
def from_names(self, *args, **kwargs):
|
@@ -62,15 +78,6 @@ class ModelList(Base, UserList):
|
|
62
78
|
args = args[0]
|
63
79
|
return ModelList([Model(model_name, **kwargs) for model_name in args])
|
64
80
|
|
65
|
-
@add_edsl_version
|
66
|
-
def to_dict(self):
|
67
|
-
"""
|
68
|
-
Convert the ModelList to a dictionary.
|
69
|
-
>>> ModelList.example().to_dict()
|
70
|
-
{'models': [...], 'edsl_version': '...', 'edsl_class_name': 'ModelList'}
|
71
|
-
"""
|
72
|
-
return self._to_dict()
|
73
|
-
|
74
81
|
@classmethod
|
75
82
|
@remove_edsl_version
|
76
83
|
def from_dict(cls, data):
|
edsl/notebooks/Notebook.py
CHANGED
@@ -102,18 +102,17 @@ class Notebook(Base):
|
|
102
102
|
|
103
103
|
return dict_hash(self.data["cells"])
|
104
104
|
|
105
|
-
def
|
105
|
+
def to_dict(self, add_edsl_version=False) -> dict:
|
106
106
|
"""
|
107
107
|
Serialize to a dictionary.
|
108
108
|
"""
|
109
|
-
|
109
|
+
d = {"name": self.name, "data": self.data}
|
110
|
+
if add_edsl_version:
|
111
|
+
from edsl import __version__
|
110
112
|
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
Convert a Notebook to a dictionary.
|
115
|
-
"""
|
116
|
-
return self._to_dict()
|
113
|
+
d["edsl_version"] = __version__
|
114
|
+
d["edsl_class_name"] = self.__class__.__name__
|
115
|
+
return d
|
117
116
|
|
118
117
|
@classmethod
|
119
118
|
@remove_edsl_version
|
edsl/questions/QuestionBase.py
CHANGED
@@ -82,8 +82,7 @@ class QuestionBase(
|
|
82
82
|
if not hasattr(self, "_fake_data_factory"):
|
83
83
|
from polyfactory.factories.pydantic_factory import ModelFactory
|
84
84
|
|
85
|
-
class FakeData(ModelFactory[self.response_model]):
|
86
|
-
...
|
85
|
+
class FakeData(ModelFactory[self.response_model]): ...
|
87
86
|
|
88
87
|
self._fake_data_factory = FakeData
|
89
88
|
return self._fake_data_factory
|
@@ -135,7 +134,7 @@ class QuestionBase(
|
|
135
134
|
"""
|
136
135
|
from edsl.utilities.utilities import dict_hash
|
137
136
|
|
138
|
-
return dict_hash(self.
|
137
|
+
return dict_hash(self.to_dict(add_edsl_version=False))
|
139
138
|
|
140
139
|
@property
|
141
140
|
def data(self) -> dict:
|
@@ -147,13 +146,15 @@ class QuestionBase(
|
|
147
146
|
"""
|
148
147
|
exclude_list = [
|
149
148
|
"question_type",
|
150
|
-
"_include_comment",
|
149
|
+
# "_include_comment",
|
151
150
|
"_fake_data_factory",
|
152
|
-
"_use_code",
|
151
|
+
# "_use_code",
|
153
152
|
"_model_instructions",
|
154
153
|
]
|
155
154
|
only_if_not_na_list = ["_answering_instructions", "_question_presentation"]
|
156
155
|
|
156
|
+
only_if_not_default_list = {"_include_comment": True, "_use_code": False}
|
157
|
+
|
157
158
|
def ok(key, value):
|
158
159
|
if not key.startswith("_"):
|
159
160
|
return False
|
@@ -161,6 +162,12 @@ class QuestionBase(
|
|
161
162
|
return False
|
162
163
|
if key in only_if_not_na_list and value is None:
|
163
164
|
return False
|
165
|
+
if (
|
166
|
+
key in only_if_not_default_list
|
167
|
+
and value == only_if_not_default_list[key]
|
168
|
+
):
|
169
|
+
return False
|
170
|
+
|
164
171
|
return True
|
165
172
|
|
166
173
|
candidate_data = {
|
@@ -175,25 +182,22 @@ class QuestionBase(
|
|
175
182
|
|
176
183
|
return candidate_data
|
177
184
|
|
178
|
-
def
|
185
|
+
def to_dict(self, add_edsl_version=True):
|
179
186
|
"""Convert the question to a dictionary that includes the question type (used in deserialization).
|
180
187
|
|
181
|
-
>>> from edsl import QuestionFreeText as Q; Q.example().
|
188
|
+
>>> from edsl import QuestionFreeText as Q; Q.example().to_dict(add_edsl_version = False)
|
182
189
|
{'question_name': 'how_are_you', 'question_text': 'How are you?', 'question_type': 'free_text'}
|
183
190
|
"""
|
184
191
|
candidate_data = self.data.copy()
|
185
192
|
candidate_data["question_type"] = self.question_type
|
186
|
-
|
187
|
-
|
188
|
-
|
193
|
+
d = {key: value for key, value in candidate_data.items() if value is not None}
|
194
|
+
if add_edsl_version:
|
195
|
+
from edsl import __version__
|
189
196
|
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
{'question_name': 'how_are_you', 'question_text': 'How are you?', 'question_type': 'free_text', 'edsl_version': '...'}
|
195
|
-
"""
|
196
|
-
return self._to_dict()
|
197
|
+
d["edsl_version"] = __version__
|
198
|
+
d["edsl_class_name"] = "QuestionBase"
|
199
|
+
|
200
|
+
return d
|
197
201
|
|
198
202
|
@classmethod
|
199
203
|
@remove_edsl_version
|
@@ -138,7 +138,7 @@ class QuestionBaseGenMixin:
|
|
138
138
|
if exclude_components is None:
|
139
139
|
exclude_components = ["question_name", "question_type"]
|
140
140
|
|
141
|
-
d = copy.deepcopy(self.
|
141
|
+
d = copy.deepcopy(self.to_dict(add_edsl_version=False))
|
142
142
|
for key, value in d.items():
|
143
143
|
if key in exclude_components:
|
144
144
|
continue
|