edsl 0.1.61__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +66 -0
- edsl/__version__.py +1 -1
- edsl/base/base_class.py +53 -0
- edsl/cli.py +93 -27
- edsl/config/config_class.py +4 -0
- edsl/coop/coop.py +403 -28
- edsl/coop/coop_jobs_objects.py +2 -2
- edsl/coop/coop_regular_objects.py +3 -1
- edsl/dataset/dataset.py +47 -41
- edsl/dataset/dataset_operations_mixin.py +138 -15
- edsl/dataset/report_from_template.py +509 -0
- edsl/inference_services/services/azure_ai.py +8 -2
- edsl/inference_services/services/open_ai_service.py +7 -5
- edsl/jobs/jobs.py +5 -4
- edsl/jobs/jobs_checks.py +11 -6
- edsl/jobs/remote_inference.py +17 -10
- edsl/prompts/prompt.py +7 -2
- edsl/questions/question_registry.py +4 -1
- edsl/results/result.py +93 -38
- edsl/results/results.py +24 -15
- edsl/scenarios/file_store.py +69 -0
- edsl/scenarios/scenario.py +233 -0
- edsl/scenarios/scenario_list.py +294 -130
- edsl/scenarios/scenario_source.py +1 -2
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/METADATA +1 -1
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/RECORD +29 -28
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/LICENSE +0 -0
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/WHEEL +0 -0
- {edsl-0.1.61.dist-info → edsl-1.0.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,509 @@
|
|
1
|
+
"""
|
2
|
+
Template-based report generation for EDSL datasets.
|
3
|
+
|
4
|
+
This module provides the TemplateReportGenerator class that handles Jinja2-based
|
5
|
+
report generation with support for various output formats including text and DOCX,
|
6
|
+
with optional markdown formatting.
|
7
|
+
"""
|
8
|
+
|
9
|
+
from typing import Optional, Union, List, TYPE_CHECKING
|
10
|
+
import warnings
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from docx import Document
|
14
|
+
|
15
|
+
|
16
|
+
class TemplateReportGenerator:
    """
    Handles template-based report generation for EDSL datasets.

    This class encapsulates the logic for generating per-row reports using
    Jinja2 templates, with support for "text" and "docx" output formats,
    optional markdown-to-DOCX conversion (via pandoc or a pure-Python
    fallback), and an "explode" mode that emits one file/document per row.
    """

    def __init__(self, dataset):
        """
        Initialize the report generator with a dataset.

        Args:
            dataset: The dataset object to generate reports from. It must
                provide ``relevant_columns()`` and ``to_dicts(remove_prefix=...)``.
        """
        self.dataset = dataset

    @staticmethod
    def _is_pandoc_available() -> bool:
        """Return True if a ``pandoc`` executable can be invoked on this system."""
        import subprocess

        try:
            subprocess.run(["pandoc", "--version"], capture_output=True, check=True)
            return True
        except (subprocess.CalledProcessError, FileNotFoundError):
            return False

    def generate_report(
        self,
        template: str,
        *fields: Optional[str],
        top_n: Optional[int] = None,
        remove_prefix: bool = True,
        return_string: bool = False,
        format: str = "text",
        filename: Optional[str] = None,
        separator: str = "\n\n",
        observation_title_template: Optional[str] = None,
        explode: bool = False,
        markdown_to_docx: bool = True,
        use_pandoc: bool = True,
    ) -> Optional[Union[str, "Document", List]]:
        """Generates a report using a Jinja2 template for each row in the dataset.

        Args:
            template: Jinja2 template string, rendered once per row.
            *fields: Optional subset of columns to expose to the template;
                defaults to all columns.
            top_n: If given, only the first ``top_n`` rows are used.
            remove_prefix: Strip the ``prefix.`` part from column names.
            return_string: Always return the rendered text (text format only).
            format: "text" or "docx".
            filename: Output path; in explode mode it is treated as a
                ``str.format`` template (e.g. "report_{index}.docx").
            separator: Separator between observations in text format.
            observation_title_template: Jinja2 template for each observation's
                heading; defaults to "Observation {{ index }}".
            explode: If True, produce one file/document per observation.
            markdown_to_docx: Treat rendered content as markdown when
                producing DOCX output.
            use_pandoc: Prefer pandoc over the pure-Python converter.

        Returns:
            A string, a python-docx Document, a list (explode mode), or None,
            depending on the options above.

        Raises:
            DatasetImportError: If jinja2 is not installed.
            DatasetKeyError: If a requested field is not in the dataset.
        """
        try:
            from jinja2 import Template
        except ImportError:
            from .exceptions import DatasetImportError

            raise DatasetImportError(
                "The jinja2 package is required for template-based reports. Install it with 'pip install jinja2'."
            )

        # Hoist the column lookup: it is needed for both defaulting and
        # validating the requested fields.
        available_columns = self.dataset.relevant_columns()

        # If no fields specified, use all columns
        if not fields:
            fields = available_columns

        # Validate all fields exist
        for field in fields:
            if field not in available_columns:
                from .exceptions import DatasetKeyError

                raise DatasetKeyError(f"Field '{field}' not found in dataset")

        # Get data as list of dictionaries
        list_of_dicts = self.dataset.to_dicts(remove_prefix=remove_prefix)

        # Apply top_n limit if specified
        if top_n is not None:
            list_of_dicts = list_of_dicts[:top_n]

        # Filter to only include requested fields if specified
        list_of_dicts = self._filter_fields(list_of_dicts, fields, remove_prefix)

        # Compile the body template once; it is rendered per row below.
        jinja_template = Template(template)

        # Render template for each row
        rendered_reports = self._render_templates(jinja_template, list_of_dicts)

        # Set up observation title template
        if observation_title_template is None:
            observation_title_template = "Observation {{ index }}"

        title_template = Template(observation_title_template)

        # Explode mode - create separate files/documents per observation.
        if explode:
            return self._handle_explode_mode(
                rendered_reports,
                list_of_dicts,
                title_template,
                format,
                filename,
                markdown_to_docx,
                use_pandoc,
            )

        # Non-explode mode (original combined behavior).
        return self._handle_combined_mode(
            rendered_reports,
            list_of_dicts,
            title_template,
            format,
            filename,
            separator,
            return_string,
            markdown_to_docx,
            use_pandoc,
        )

    def _filter_fields(
        self, list_of_dicts: List[dict], fields: tuple, remove_prefix: bool
    ) -> List[dict]:
        """Filter the list of dictionaries to only include requested fields.

        When ``remove_prefix`` is True the row keys have already been
        de-prefixed by ``to_dicts``, so the requested field names are
        de-prefixed before matching.
        """
        if fields and remove_prefix:
            # Remove "prefix." from field names so they match the row keys.
            filter_fields = {
                field.split(".")[-1] if "." in field else field for field in fields
            }
            return [
                {k: v for k, v in row.items() if k in filter_fields}
                for row in list_of_dicts
            ]
        elif fields:
            # Use exact field names for filtering.
            return [
                {k: v for k, v in row.items() if k in fields}
                for row in list_of_dicts
            ]
        return list_of_dicts

    def _render_templates(self, jinja_template, list_of_dicts: List[dict]) -> List[str]:
        """Render the Jinja2 template once per row of data.

        Each row's values are exposed as template variables, plus ``index``
        (1-based) and ``index0`` (0-based) for the row position.

        Raises:
            DatasetValueError: If rendering fails for any row.
        """
        rendered_reports = []
        for i, row_data in enumerate(list_of_dicts):
            template_data = dict(row_data, index=i + 1, index0=i)
            try:
                rendered_reports.append(jinja_template.render(**template_data))
            except Exception as e:
                from .exceptions import DatasetValueError

                raise DatasetValueError(
                    f"Error rendering template with data {row_data}: {e}"
                ) from e
        return rendered_reports

    def _convert_markdown_to_docx(
        self, markdown_content: str, use_pandoc: bool, temp_dir: str = None
    ) -> "Document":
        """Convert markdown content to a DOCX document.

        Dispatches to pandoc or the pure-Python converter based on *use_pandoc*.
        """
        if use_pandoc:
            return self._convert_with_pandoc(markdown_content, temp_dir)
        else:
            return self._convert_with_python(markdown_content)

    def _convert_with_pandoc(self, markdown_content: str, temp_dir: str = None) -> "Document":
        """Convert markdown to DOCX by shelling out to pandoc.

        Args:
            markdown_content: The markdown source text.
            temp_dir: Optional directory for the intermediate temp files.

        Raises:
            DatasetExportError: If pandoc is missing or the conversion fails.
        """
        import os
        import subprocess
        import tempfile

        # Check if pandoc is available before creating any temp files.
        if not self._is_pandoc_available():
            from .exceptions import DatasetExportError

            raise DatasetExportError(
                "Pandoc is not installed or not available in PATH. "
                "To fix this: 1) Install pandoc (https://pandoc.org/installing.html), "
                "2) Set use_pandoc=False to use Python-based conversion, or "
                "3) Use format='text' instead of 'docx' for simple text output."
            )

        # Write the markdown to a temp file so pandoc can read it.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False, dir=temp_dir
        ) as md_file:
            md_file.write(markdown_content)
            md_filename = md_file.name

        # Replace only the trailing ".md": str.replace would also rewrite an
        # ".md" occurring earlier in the path.
        docx_filename = md_filename[:-3] + ".docx"

        try:
            subprocess.run(
                [
                    "pandoc",
                    md_filename,
                    "-o", docx_filename,
                    "--from", "markdown",
                    "--to", "docx",
                ],
                check=True,
            )

            # python-docx loads the whole file into memory, so the temp copy
            # can be deleted afterwards (see finally).
            from docx import Document

            return Document(docx_filename)
        except subprocess.CalledProcessError as e:
            from .exceptions import DatasetExportError

            raise DatasetExportError(f"Pandoc conversion failed: {e}") from e
        finally:
            # Always clean up the temporary files, on success or failure.
            for path in (md_filename, docx_filename):
                if os.path.exists(path):
                    os.unlink(path)

    def _convert_with_python(self, markdown_content: str) -> "Document":
        """Convert markdown to DOCX using the 'markdown' and 'python-docx' packages.

        Handles a small HTML subset (h1-h3, paragraphs with inline
        bold/italic/code spans, code blocks); anything else is added as a
        plain paragraph.

        Raises:
            DatasetImportError: If the required packages are missing.
        """
        try:
            import html
            import re

            import markdown
            from docx import Document
            from docx.shared import Pt
        except ImportError:
            from .exceptions import DatasetImportError

            raise DatasetImportError(
                "Python-based markdown conversion requires 'markdown' and 'python-docx' packages. "
                "Install with 'pip install markdown python-docx' or set use_pandoc=True to use pandoc."
            )

        def strip_tags(fragment: str) -> str:
            # Drop HTML tags and unescape entities in one place.
            return html.unescape(re.sub(r"<[^>]+>", "", fragment))

        inline_pattern = re.compile(r"(<strong>.*?</strong>|<em>.*?</em>|<code>.*?</code>)")

        # Convert markdown to HTML first.
        md = markdown.Markdown(extensions=["codehilite", "fenced_code", "tables", "nl2br"])
        html_content = md.convert(markdown_content)

        # Create a new document.
        doc = Document()

        # The markdown package emits block-level HTML roughly one element per
        # line, so a line-by-line scan is sufficient here.
        for line in html_content.split("\n"):
            line = line.strip()
            if not line:
                continue

            # Handle headers.
            if line.startswith("<h1>") and line.endswith("</h1>"):
                doc.add_heading(strip_tags(line), level=1)
            elif line.startswith("<h2>") and line.endswith("</h2>"):
                doc.add_heading(strip_tags(line), level=2)
            elif line.startswith("<h3>") and line.endswith("</h3>"):
                doc.add_heading(strip_tags(line), level=3)
            # Handle paragraphs.
            elif line.startswith("<p>") and line.endswith("</p>"):
                # Strip only the outer <p> wrapper: the inline <strong>/<em>/
                # <code> tags must survive until after the split below, or the
                # formatting branches can never match.
                inner = line[len("<p>"):-len("</p>")]
                if strip_tags(inner).strip():
                    p = doc.add_paragraph()
                    # Split out inline spans so each gets its own styled run.
                    for part in inline_pattern.split(inner):
                        if part.startswith("<strong>") and part.endswith("</strong>"):
                            run = p.add_run(strip_tags(part))
                            run.bold = True
                        elif part.startswith("<em>") and part.endswith("</em>"):
                            run = p.add_run(strip_tags(part))
                            run.italic = True
                        elif part.startswith("<code>") and part.endswith("</code>"):
                            run = p.add_run(strip_tags(part))
                            run.font.name = "Courier New"
                            run.font.size = Pt(10)
                        else:
                            clean = strip_tags(part)
                            if clean.strip():
                                p.add_run(clean)
            # Handle code blocks.
            elif "<pre>" in line or "<code>" in line:
                text = strip_tags(line)
                if text.strip():
                    p = doc.add_paragraph()
                    run = p.add_run(text)
                    run.font.name = "Courier New"
                    run.font.size = Pt(10)
            # Handle other content as plain paragraphs.
            else:
                text = strip_tags(line)
                if text.strip():
                    doc.add_paragraph(text)

        return doc

    def _handle_explode_mode(
        self,
        rendered_reports: List[str],
        list_of_dicts: List[dict],
        title_template,
        format: str,
        filename: Optional[str],
        markdown_to_docx: bool,
        use_pandoc: bool,
    ) -> List:
        """Handle explode mode - create separate files/documents per observation.

        Returns a list of written filenames (when *filename* is given) or of
        in-memory documents/strings otherwise.
        """
        # Warn when the filename template contains no per-row placeholder:
        # every iteration would overwrite the same file. Guard against an
        # empty dataset before peeking at the first row's keys, and check for
        # the braced "{key}" form that str.format actually consumes.
        if filename and list_of_dicts:
            candidate_vars = ["{index}", "{index0}"] + [
                "{" + key + "}" for key in list_of_dicts[0].keys()
            ]
            if not any(var in filename for var in candidate_vars):
                warnings.warn(
                    "When explode=True, filename should contain template variables like {index} "
                    "to avoid overwriting files. Example: 'report_{index}.docx'"
                )

        results = []

        for i, (rendered_content, row_data) in enumerate(zip(rendered_reports, list_of_dicts)):
            # Expose index variables to the title template alongside row data.
            title_data = dict(row_data, index=i + 1, index0=i)
            observation_title = title_template.render(**title_data)

            if format.lower() == "docx":
                doc = self._create_docx_document(
                    observation_title, rendered_content, markdown_to_docx, use_pandoc
                )

                if filename:
                    # Generate the per-observation filename from the template.
                    individual_filename = filename.format(index=i + 1, index0=i, **row_data)
                    doc.save(individual_filename)
                    results.append(individual_filename)
                else:
                    results.append(doc)

            elif format.lower() == "text":
                # Generate individual text content with a markdown-style title.
                individual_content = f"# {observation_title}\n\n{rendered_content}"

                if filename:
                    individual_filename = filename.format(index=i + 1, index0=i, **row_data)
                    with open(individual_filename, "w", encoding="utf-8") as f:
                        f.write(individual_content)
                    results.append(individual_filename)
                else:
                    results.append(individual_content)

        if filename:
            print(f"Created {len(results)} individual files")

        return results

    def _handle_combined_mode(
        self,
        rendered_reports: List[str],
        list_of_dicts: List[dict],
        title_template,
        format: str,
        filename: Optional[str],
        separator: str,
        return_string: bool,
        markdown_to_docx: bool,
        use_pandoc: bool,
    ) -> Optional[Union[str, "Document"]]:
        """Handle non-explode mode (original combined behavior).

        Raises:
            DatasetExportError: If *format* is neither "text" nor "docx".
        """
        if format.lower() == "docx":
            return self._create_combined_docx(
                rendered_reports, list_of_dicts, title_template,
                filename, markdown_to_docx, use_pandoc,
            )
        elif format.lower() == "text":
            return self._create_combined_text(
                rendered_reports, list_of_dicts, title_template,
                filename, separator, return_string,
            )
        else:
            from .exceptions import DatasetExportError

            raise DatasetExportError(
                f"Unsupported format: {format}. Use 'text' or 'docx'."
            )

    def _create_docx_document(
        self,
        observation_title: str,
        rendered_content: str,
        markdown_to_docx: bool,
        use_pandoc: bool,
    ) -> "Document":
        """Create a single DOCX document for one observation.

        Raises:
            DatasetImportError: If python-docx is not installed.
        """
        try:
            from docx import Document
        except ImportError:
            from .exceptions import DatasetImportError

            raise DatasetImportError(
                "The python-docx package is required for DOCX export. Install it with 'pip install python-docx'."
            )

        if markdown_to_docx:
            # Convert markdown content to DOCX with proper formatting.
            full_markdown = f"# {observation_title}\n\n{rendered_content}"
            return self._convert_markdown_to_docx(full_markdown, use_pandoc)

        # Plain-text approach (original behavior): title heading followed by
        # one paragraph per line of content.
        doc = Document()
        doc.add_heading(observation_title, level=1)
        for line in rendered_content.split("\n"):
            if line.strip():
                doc.add_paragraph(line)
            else:
                doc.add_paragraph()
        return doc

    def _create_combined_docx(
        self,
        rendered_reports: List[str],
        list_of_dicts: List[dict],
        title_template,
        filename: Optional[str],
        markdown_to_docx: bool,
        use_pandoc: bool,
    ) -> Optional["Document"]:
        """Create a combined DOCX document covering all observations.

        Returns None after saving when *filename* is given; otherwise returns
        the in-memory Document.

        Raises:
            DatasetImportError: If python-docx is not installed.
        """
        try:
            from docx import Document
        except ImportError:
            from .exceptions import DatasetImportError

            raise DatasetImportError(
                "The python-docx package is required for DOCX export. Install it with 'pip install python-docx'."
            )

        if markdown_to_docx:
            # Build one markdown document with a \pagebreak between sections,
            # then convert it in a single pass.
            markdown_parts = []
            for i, (rendered_content, row_data) in enumerate(zip(rendered_reports, list_of_dicts)):
                title_data = dict(row_data, index=i + 1, index0=i)
                observation_title = title_template.render(**title_data)
                markdown_parts.append(f"# {observation_title}\n\n{rendered_content}")

            full_markdown = "\n\n\\pagebreak\n\n".join(markdown_parts)
            doc = self._convert_markdown_to_docx(full_markdown, use_pandoc)
        else:
            # Plain-text approach (original behavior).
            doc = Document()

            for i, (rendered_content, row_data) in enumerate(zip(rendered_reports, list_of_dicts)):
                title_data = dict(row_data, index=i + 1, index0=i)
                observation_title = title_template.render(**title_data)

                # Heading for each observation, then its content line by line.
                doc.add_heading(observation_title, level=1)
                for line in rendered_content.split("\n"):
                    if line.strip():
                        doc.add_paragraph(line)
                    else:
                        doc.add_paragraph()

                # Page break between observations except after the last one.
                if i < len(rendered_reports) - 1:
                    doc.add_page_break()

        # Save to file if filename is provided.
        if filename:
            doc.save(filename)
            print(f"Report saved to {filename}")
            return None

        return doc

    def _create_combined_text(
        self,
        rendered_reports: List[str],
        list_of_dicts: List[dict],
        title_template,
        filename: Optional[str],
        separator: str,
        return_string: bool,
    ) -> Optional[str]:
        """Create a combined text report covering all observations.

        Returns the report string, or None when it was written to a file or
        displayed in a notebook and *return_string* is False.
        """
        final_report_parts = []

        for i, (rendered_content, row_data) in enumerate(zip(rendered_reports, list_of_dicts)):
            # Expose index variables to the title template alongside row data.
            title_data = dict(row_data, index=i + 1, index0=i)
            observation_title = title_template.render(**title_data)
            final_report_parts.append(f"# {observation_title}\n\n{rendered_content}")

        final_report = separator.join(final_report_parts)

        # Save to file if filename is provided.
        if filename:
            with open(filename, "w", encoding="utf-8") as f:
                f.write(final_report)
            print(f"Report saved to {filename}")
            if not return_string:
                return None

        # When a string was explicitly requested the notebook check (and its
        # import) can be skipped entirely; the display branch would be dead.
        if not return_string:
            from ..utilities.utilities import is_notebook

            if is_notebook():
                from IPython.display import display, HTML

                # Use HTML display to preserve formatting.
                display(HTML(f"<pre>{final_report}</pre>"))
                return None

        # Return the string if requested or when not in a notebook.
        return final_report
|
@@ -2,6 +2,7 @@ import os
|
|
2
2
|
from typing import Any, Optional, List, TYPE_CHECKING
|
3
3
|
from openai import AsyncAzureOpenAI
|
4
4
|
from ..inference_service_abc import InferenceServiceABC
|
5
|
+
|
5
6
|
# Use TYPE_CHECKING to avoid circular imports at runtime
|
6
7
|
if TYPE_CHECKING:
|
7
8
|
from ...language_models import LanguageModel
|
@@ -49,7 +50,9 @@ class AzureAIService(InferenceServiceABC):
|
|
49
50
|
azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
|
50
51
|
if not azure_endpoints:
|
51
52
|
from ..exceptions import InferenceServiceEnvironmentError
|
52
|
-
|
53
|
+
|
54
|
+
return []
|
55
|
+
# raise InferenceServiceEnvironmentError("AZURE_ENDPOINT_URL_AND_KEY is not defined")
|
53
56
|
azure_endpoints = azure_endpoints.split(",")
|
54
57
|
for data in azure_endpoints:
|
55
58
|
try:
|
@@ -100,12 +103,13 @@ class AzureAIService(InferenceServiceABC):
|
|
100
103
|
@classmethod
|
101
104
|
def create_model(
|
102
105
|
cls, model_name: str = "azureai", model_class_name=None
|
103
|
-
) ->
|
106
|
+
) -> "LanguageModel":
|
104
107
|
if model_class_name is None:
|
105
108
|
model_class_name = cls.to_class_name(model_name)
|
106
109
|
|
107
110
|
# Import LanguageModel only when actually creating a model
|
108
111
|
from ...language_models import LanguageModel
|
112
|
+
|
109
113
|
class LLM(LanguageModel):
|
110
114
|
"""
|
111
115
|
Child class of LanguageModel for interacting with Azure OpenAI models.
|
@@ -140,6 +144,7 @@ class AzureAIService(InferenceServiceABC):
|
|
140
144
|
|
141
145
|
if not api_key:
|
142
146
|
from ..exceptions import InferenceServiceEnvironmentError
|
147
|
+
|
143
148
|
raise InferenceServiceEnvironmentError(
|
144
149
|
f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
|
145
150
|
)
|
@@ -151,6 +156,7 @@ class AzureAIService(InferenceServiceABC):
|
|
151
156
|
|
152
157
|
if not endpoint:
|
153
158
|
from ..exceptions import InferenceServiceEnvironmentError
|
159
|
+
|
154
160
|
raise InferenceServiceEnvironmentError(
|
155
161
|
f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
|
156
162
|
)
|
@@ -5,6 +5,7 @@ import os
|
|
5
5
|
import openai
|
6
6
|
|
7
7
|
from ..inference_service_abc import InferenceServiceABC
|
8
|
+
|
8
9
|
# Use TYPE_CHECKING to avoid circular imports at runtime
|
9
10
|
if TYPE_CHECKING:
|
10
11
|
from ...language_models import LanguageModel
|
@@ -112,12 +113,13 @@ class OpenAIService(InferenceServiceABC):
|
|
112
113
|
return cls._models_list_cache
|
113
114
|
|
114
115
|
@classmethod
|
115
|
-
def create_model(cls, model_name, model_class_name=None) ->
|
116
|
+
def create_model(cls, model_name, model_class_name=None) -> "LanguageModel":
|
116
117
|
if model_class_name is None:
|
117
118
|
model_class_name = cls.to_class_name(model_name)
|
118
119
|
|
119
120
|
# Import LanguageModel only when actually creating a model
|
120
121
|
from ...language_models import LanguageModel
|
122
|
+
|
121
123
|
class LLM(LanguageModel):
|
122
124
|
"""
|
123
125
|
Child class of LanguageModel for interacting with OpenAI models
|
@@ -236,10 +238,10 @@ class OpenAIService(InferenceServiceABC):
|
|
236
238
|
try:
|
237
239
|
response = await client.chat.completions.create(**params)
|
238
240
|
except Exception as e:
|
239
|
-
#breakpoint()
|
240
|
-
#print(e)
|
241
|
-
#raise e
|
242
|
-
return {
|
241
|
+
# breakpoint()
|
242
|
+
# print(e)
|
243
|
+
# raise e
|
244
|
+
return {"message": str(e)}
|
243
245
|
return response.model_dump()
|
244
246
|
|
245
247
|
LLM.__name__ = "LanguageModel"
|
edsl/jobs/jobs.py
CHANGED
@@ -602,8 +602,8 @@ class Jobs(Base):
|
|
602
602
|
|
603
603
|
def _check_if_remote_keys_ok(self):
|
604
604
|
jc = JobsChecks(self)
|
605
|
-
if jc.
|
606
|
-
jc.key_process()
|
605
|
+
if not jc.user_has_ep_api_key():
|
606
|
+
jc.key_process(remote_inference=True)
|
607
607
|
|
608
608
|
def _check_if_local_keys_ok(self):
|
609
609
|
jc = JobsChecks(self)
|
@@ -758,7 +758,8 @@ class Jobs(Base):
|
|
758
758
|
# Make sure all required objects exist
|
759
759
|
self.replace_missing_objects()
|
760
760
|
self._prepare_to_run()
|
761
|
-
self.
|
761
|
+
if not self.run_config.parameters.disable_remote_inference:
|
762
|
+
self._check_if_remote_keys_ok()
|
762
763
|
|
763
764
|
# Setup caching
|
764
765
|
from ..caching import CacheHandler, Cache
|
@@ -1120,7 +1121,7 @@ class Jobs(Base):
|
|
1120
1121
|
raise CoopValueError(
|
1121
1122
|
"You must specify both a scenario list and a scenario list method to use scenarios with your survey."
|
1122
1123
|
)
|
1123
|
-
elif scenario_list_method
|
1124
|
+
elif scenario_list_method == "loop":
|
1124
1125
|
questions, long_scenario_list = self.survey.to_long_format(self.scenarios)
|
1125
1126
|
|
1126
1127
|
# Replace the questions with new ones from the loop method
|
edsl/jobs/jobs_checks.py
CHANGED
@@ -28,7 +28,7 @@ class JobsChecks:
|
|
28
28
|
raise MissingAPIKeyError(
|
29
29
|
model_name=str(model.model),
|
30
30
|
inference_service=model._inference_service_,
|
31
|
-
silent=False
|
31
|
+
silent=False,
|
32
32
|
)
|
33
33
|
|
34
34
|
def get_missing_api_keys(self) -> set:
|
@@ -39,7 +39,7 @@ class JobsChecks:
|
|
39
39
|
|
40
40
|
from ..enums import service_to_api_keyname
|
41
41
|
|
42
|
-
for model in self.jobs.models:
|
42
|
+
for model in self.jobs.models: # + [Model()]:
|
43
43
|
if not model.has_valid_api_key():
|
44
44
|
key_name = service_to_api_keyname.get(
|
45
45
|
model._inference_service_, "NOT FOUND"
|
@@ -131,13 +131,16 @@ class JobsChecks:
|
|
131
131
|
"needs_key_process": self.needs_key_process(),
|
132
132
|
}
|
133
133
|
|
134
|
-
def key_process(self):
|
134
|
+
def key_process(self, remote_inference: bool = False) -> None:
|
135
135
|
import secrets
|
136
136
|
from dotenv import load_dotenv
|
137
137
|
from ..coop.coop import Coop
|
138
138
|
from ..utilities.utilities import write_api_key_to_env
|
139
139
|
|
140
|
-
|
140
|
+
if remote_inference:
|
141
|
+
missing_api_keys = ["EXPECTED_PARROT_API_KEY"]
|
142
|
+
else:
|
143
|
+
missing_api_keys = self.get_missing_api_keys()
|
141
144
|
|
142
145
|
edsl_auth_token = secrets.token_urlsafe(16)
|
143
146
|
|
@@ -150,7 +153,7 @@ class JobsChecks:
|
|
150
153
|
\nClick the link below to create an account and run your survey with your Expected Parrot key:
|
151
154
|
"""
|
152
155
|
)
|
153
|
-
|
156
|
+
|
154
157
|
coop = Coop()
|
155
158
|
coop._display_login_url(
|
156
159
|
edsl_auth_token=edsl_auth_token,
|
@@ -164,7 +167,9 @@ class JobsChecks:
|
|
164
167
|
return
|
165
168
|
|
166
169
|
path_to_env = write_api_key_to_env(api_key)
|
167
|
-
print(
|
170
|
+
print(
|
171
|
+
f"\n✨ Your Expected Parrot key has been stored at the following path: {path_to_env}\n"
|
172
|
+
)
|
168
173
|
|
169
174
|
# Retrieve API key so we can continue running the job
|
170
175
|
load_dotenv()
|