themefinder 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themefinder/__init__.py +24 -0
- themefinder/advanced_tasks/__init__.py +0 -0
- themefinder/advanced_tasks/cross_cutting_themes_agent.py +404 -0
- themefinder/advanced_tasks/theme_clustering_agent.py +356 -0
- themefinder/llm_batch_processor.py +442 -0
- themefinder/models.py +438 -0
- themefinder/prompts/agentic_theme_clustering.txt +34 -0
- themefinder/prompts/consultation_system_prompt.txt +1 -0
- themefinder/prompts/cross_cutting_identification.txt +16 -0
- themefinder/prompts/cross_cutting_mapping.txt +19 -0
- themefinder/prompts/cross_cutting_refinement.txt +15 -0
- themefinder/prompts/detail_detection.txt +31 -0
- themefinder/prompts/sentiment_analysis.txt +41 -0
- themefinder/prompts/theme_condensation.txt +34 -0
- themefinder/prompts/theme_generation.txt +38 -0
- themefinder/prompts/theme_mapping.txt +36 -0
- themefinder/prompts/theme_refinement.txt +54 -0
- themefinder/prompts/theme_target_alignment.txt +18 -0
- themefinder/tasks.py +656 -0
- themefinder/themefinder_logging.py +12 -0
- themefinder-0.7.4.dist-info/METADATA +174 -0
- themefinder-0.7.4.dist-info/RECORD +24 -0
- themefinder-0.7.4.dist-info/WHEEL +4 -0
- themefinder-0.7.4.dist-info/licenses/LICENCE +21 -0
|
@@ -0,0 +1,442 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
import openai
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import tiktoken
|
|
11
|
+
from langchain_core.prompts import PromptTemplate
|
|
12
|
+
from langchain_core.runnables import Runnable
|
|
13
|
+
from pydantic import ValidationError
|
|
14
|
+
from tenacity import (
|
|
15
|
+
before,
|
|
16
|
+
before_sleep_log,
|
|
17
|
+
retry,
|
|
18
|
+
stop_after_attempt,
|
|
19
|
+
wait_random_exponential,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
from themefinder.themefinder_logging import logger
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class BatchPrompt:
    """A rendered LLM prompt together with the IDs of the rows it contains.

    Instances are created by ``build_prompt``; ``response_ids`` is what the
    batching code uses to report which rows failed when a call goes wrong.
    """

    # The fully formatted prompt text that is sent to the LLM.
    prompt_string: str
    # Values of the 'response_id' column for every row placed in this prompt.
    response_ids: list[int]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _rows_missing_from_prompts(
    df: pd.DataFrame, prompts: list[BatchPrompt]
) -> pd.DataFrame:
    """Return the rows of ``df`` whose response_id appears in none of ``prompts``."""
    prompted_ids = {
        response_id for prompt in prompts for response_id in prompt.response_ids
    }
    return df[~df["response_id"].astype(int).isin(prompted_ids)]


async def batch_and_run(
    input_df: pd.DataFrame,
    prompt_template: str | Path | PromptTemplate,
    llm: Runnable,
    batch_size: int = 10,
    partition_key: str | None = None,
    integrity_check: bool = False,
    concurrency: int = 10,
    **kwargs: Any,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Process a DataFrame of responses in batches using an LLM.

    Rows are grouped into prompts (see ``generate_prompts``), sent to the LLM
    concurrently via ``call_llm`` and merged back onto ``input_df`` by
    ``response_id``. Any response IDs reported as failed are retried once, with
    one row per prompt.

    Args:
        input_df (pd.DataFrame): DataFrame containing input to be processed.
            Must include a 'response_id' column with integer-convertible values.
        prompt_template (str | Path | PromptTemplate): Template for LLM prompts.
            A str is the name of a file in the package's prompts directory
            (without the .txt extension); a Path or a PromptTemplate may also be
            given (see ``convert_to_prompt_template``).
        llm (Runnable): LangChain Runnable instance that will process the prompts.
        batch_size (int, optional): Number of input rows to process in each batch.
            Defaults to 10.
        partition_key (str | None, optional): Optional column name to group input rows
            before batching. Defaults to None.
        integrity_check (bool, optional): If True, verifies that all input
            response IDs are present in LLM output and retries the missing ones.
            If False, missing IDs are not checked. Defaults to False.
        concurrency (int, optional): Maximum number of simultaneous LLM calls allowed.
            Defaults to 10.
        **kwargs (Any): Additional keyword arguments to pass to the prompt template.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]:
            - The rows that were successfully processed by the LLM, merged with
              the LLM output (fresh 0-based index).
            - The input rows that could not be processed: rows that still failed
              after the single-row retry, plus rows that could not be placed in any
              prompt (e.g. a row that exceeds the token limit on its own). Empty,
              with ``input_df``'s columns, when every row was processed.
    """
    logger.info(f"Running batch and run with batch size {batch_size}")
    prompt_template = convert_to_prompt_template(prompt_template)
    batch_prompts = generate_prompts(
        prompt_template,
        input_df,
        batch_size=batch_size,
        partition_key=partition_key,
        **kwargs,
    )
    processed_rows, failed_ids = await call_llm(
        batch_prompts=batch_prompts,
        llm=llm,
        integrity_check=integrity_check,
        concurrency=concurrency,
    )
    processed_results = process_llm_responses(processed_rows, input_df)

    # Rows skipped while batching never reach the LLM; report them instead of
    # letting them disappear from both outputs.
    unprocessable_parts = [_rows_missing_from_prompts(input_df, batch_prompts)]

    if failed_ids:
        # Compare as ints: failed IDs are ints, input IDs may not be.
        retry_mask = input_df["response_id"].astype(int).isin(failed_ids)
        # Copy so process_llm_responses' in-place cast does not touch a view.
        retry_df = input_df[retry_mask].copy()
        retry_prompts = generate_prompts(
            prompt_template, retry_df, batch_size=1, **kwargs
        )
        retry_results, unprocessable_ids = await call_llm(
            batch_prompts=retry_prompts,
            llm=llm,
            integrity_check=integrity_check,
            concurrency=concurrency,
        )
        retry_processed_results = process_llm_responses(retry_results, retry_df)
        processed_results = pd.concat(
            [processed_results, retry_processed_results], ignore_index=True
        )
        unprocessable_parts.append(
            retry_df[retry_df["response_id"].astype(int).isin(unprocessable_ids)]
        )
        unprocessable_parts.append(
            _rows_missing_from_prompts(retry_df, retry_prompts)
        )

    non_empty_parts = [part for part in unprocessable_parts if not part.empty]
    unprocessable_df = (
        pd.concat(non_empty_parts) if non_empty_parts else input_df.iloc[0:0]
    )
    return processed_results, unprocessable_df
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def load_prompt_from_file(file_path: str | Path) -> str:
|
|
107
|
+
"""Load a prompt template from a text file in the prompts directory.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
file_path (str | Path): Name of the prompt file (without .txt extension)
|
|
111
|
+
or Path object pointing to the file.
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
str: Content of the prompt template file.
|
|
115
|
+
"""
|
|
116
|
+
parent_dir = Path(__file__).parent
|
|
117
|
+
with Path.open(parent_dir / "prompts" / f"{file_path}.txt") as file:
|
|
118
|
+
return file.read()
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def convert_to_prompt_template(prompt_template: str | Path | PromptTemplate):
    """Normalise the accepted template inputs into a LangChain PromptTemplate.

    Args:
        prompt_template (str | Path | PromptTemplate): One of
            - str: name of a prompt file in the prompts directory (no .txt suffix)
            - Path: path handed to ``load_prompt_from_file``
            - PromptTemplate: returned unchanged

    Returns:
        PromptTemplate: Ready-to-use LangChain PromptTemplate.

    Raises:
        TypeError: If prompt_template is none of the accepted types.
        FileNotFoundError: If a str/Path input does not resolve to an existing file.
    """
    if isinstance(prompt_template, PromptTemplate):
        return prompt_template
    if isinstance(prompt_template, (str, Path)):
        file_text = load_prompt_from_file(prompt_template)
        return PromptTemplate.from_template(template=file_text)
    raise TypeError("Invalid prompt_template type. Expected str, Path, or PromptTemplate.")
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def partition_dataframe(
    df: pd.DataFrame, partition_key: Optional[str]
) -> list[pd.DataFrame]:
    """Split the DataFrame into partitions based on ``partition_key`` if provided.

    Args:
        df (pd.DataFrame): Input rows.
        partition_key (Optional[str]): Column to group by. If falsy (None or ""),
            ``df`` is returned unsplit as the only partition.

    Returns:
        list[pd.DataFrame]: One DataFrame per distinct key value (in groupby's
            key order, each with a fresh 0-based index), or ``[df]``. Rows whose
            key is missing (NaN/None) form their own partition rather than being
            dropped.
    """
    if not partition_key:
        return [df]
    # dropna=False: by default groupby silently discards rows with a NaN key,
    # and those rows would then never be sent to the LLM.
    return [
        group.reset_index(drop=True)
        for _, group in df.groupby(partition_key, dropna=False)
    ]
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def split_overflowing_batch(
    batch: pd.DataFrame, allowed_tokens: int
) -> list[pd.DataFrame]:
    """
    Splits a DataFrame batch into smaller sub-batches such that each sub-batch's total token count
    does not exceed the allowed token limit.

    Rows are packed greedily in their original order. A row whose own token
    count exceeds ``allowed_tokens`` cannot fit anywhere; it is logged and left
    out of every sub-batch.

    Args:
        batch (pd.DataFrame): The input DataFrame to split.
        allowed_tokens (int): The maximum allowed number of tokens per sub-batch.

    Returns:
        list[pd.DataFrame]: A list of sub-batches (each with a fresh 0-based
            index), each within the token limit.
    """
    sub_batches: list[pd.DataFrame] = []
    current_positions: list[int] = []
    current_token_sum = 0
    # Token cost of each row, measured on the row's own JSON serialisation.
    token_counts = [
        calculate_string_token_length(row.to_json()) for _, row in batch.iterrows()
    ]

    for position, token_count in enumerate(token_counts):
        if token_count > allowed_tokens:
            # Use the module logger (not the root `logging` module) so this goes
            # through the same logging configuration as the rest of the file.
            logger.warning(
                f"Row at index {batch.index[position]} exceeds allowed token limit ({token_count} > {allowed_tokens}). Skipping row."
            )
            continue

        if current_positions and current_token_sum + token_count > allowed_tokens:
            # Adding this row would overflow: close the current sub-batch first.
            sub_batches.append(batch.iloc[current_positions].reset_index(drop=True))
            current_positions = []
            current_token_sum = 0

        current_positions.append(position)
        current_token_sum += token_count

    if current_positions:
        sub_batches.append(batch.iloc[current_positions].reset_index(drop=True))
    return sub_batches
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def batch_task_input_df(
    df: pd.DataFrame,
    allowed_tokens: int,
    batch_size: int,
    partition_key: Optional[str] = None,
) -> list[pd.DataFrame]:
    """
    Cut a DataFrame into batches bounded by both row count and token budget.

    The frame is first partitioned (by ``partition_key`` when given). Each
    partition is sliced into consecutive chunks of at most ``batch_size`` rows.
    Any chunk whose JSON form exceeds ``allowed_tokens`` is handed to
    ``split_overflowing_batch`` for further splitting.

    Args:
        df (pd.DataFrame): The input DataFrame to batch.
        allowed_tokens (int): Maximum allowed tokens per batch.
        batch_size (int): Maximum number of rows per batch before token filtering.
        partition_key (Optional[str], optional): Column name to partition the DataFrame by.
            Defaults to None.

    Returns:
        list[pd.DataFrame]: Batches in partition order, each with a fresh index and
            within the specified token and size limits.
    """
    result: list[pd.DataFrame] = []
    for partition in partition_dataframe(df, partition_key):
        for start in range(0, len(partition), batch_size):
            chunk = partition.iloc[start : start + batch_size].reset_index(drop=True)
            if calculate_string_token_length(chunk.to_json()) > allowed_tokens:
                # Too large as a whole: break it into token-bounded pieces.
                result.extend(split_overflowing_batch(chunk, allowed_tokens))
            else:
                result.append(chunk)
    return result
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def generate_prompts(
    prompt_template: PromptTemplate,
    input_data: pd.DataFrame,
    batch_size: int = 50,
    max_prompt_length: int = 50_000,
    partition_key: str | None = None,
    **kwargs,
) -> list[BatchPrompt]:
    """
    Turn an input DataFrame into a list of ready-to-send BatchPrompt objects.

    The token length of the bare template text is subtracted from
    ``max_prompt_length``; the remainder is the token budget for row data. The
    rows are then batched by ``batch_task_input_df`` (optionally per partition),
    and every batch is rendered by ``build_prompt``.

    Args:
        prompt_template (PromptTemplate): Object with a ``template`` attribute
            (used for the token budget) and a ``format`` method (used for rendering).
        input_data (pd.DataFrame): Rows to send; must contain a 'response_id' column.
        batch_size (int, optional): Maximum rows per batch. Defaults to 50.
        max_prompt_length (int, optional): Token ceiling for template plus data.
            Defaults to 50,000.
        partition_key (str | None, optional): Column whose equal values must stay in
            the same batch. Defaults to None.
        **kwargs: Extra values forwarded to the template's ``format`` method.

    Returns:
        list[BatchPrompt]: One entry per batch, holding the rendered prompt text and
            the response IDs of that batch's rows.
    """
    # Reserve room for the template's own text; what is left is the data budget.
    template_tokens = calculate_string_token_length(prompt_template.template)
    data_token_budget = max_prompt_length - template_tokens
    return [
        build_prompt(prompt_template, input_batch, **kwargs)
        for input_batch in batch_task_input_df(
            input_data, data_token_budget, batch_size, partition_key
        )
    ]
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
async def call_llm(
    batch_prompts: list[BatchPrompt],
    llm: Runnable,
    concurrency: int = 10,
    integrity_check: bool = False,
) -> tuple[list[dict], list[int]]:
    """Process multiple batches of prompts concurrently through an LLM with retry logic.

    Each prompt is sent with ``llm.ainvoke``. An ``asyncio.Semaphore`` caps the
    number of calls in flight at ``concurrency``. Exceptions other than the ones
    handled below are retried by tenacity (random exponential backoff, at most 6
    attempts). The last such exception is re-raised, which aborts the whole gather.

    Args:
        batch_prompts (list[BatchPrompt]): Prompts to send, one LLM call each.
        llm (Runnable): LangChain Runnable. Its ``ainvoke`` result must be a dict,
            or an object with ``model_dump()``/``dict()`` or a ``responses``
            attribute, holding a 'responses' list.
        concurrency (int, optional): Maximum simultaneous LLM calls. Defaults to 10.
        integrity_check (bool, optional): If True, response IDs that were sent but
            are absent from the LLM output are reported as failed. Defaults to False.

    Returns:
        tuple[list[dict], list[int]]: All rows from every batch's 'responses' list,
            and the response IDs that failed. A batch whose call (or output
            validation) raises ``openai.BadRequestError`` or ``ValueError`` — which
            includes pydantic's ``ValidationError`` — contributes all of its
            response IDs to the failed list.
    """
    semaphore = asyncio.Semaphore(concurrency)

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
        before=before.before_log(logger=logger, log_level=logging.DEBUG),
        before_sleep=before_sleep_log(logger, logging.ERROR),
        reraise=True,
    )
    async def async_llm_call(batch_prompt: BatchPrompt) -> tuple[list[dict], list[int]]:
        async with semaphore:
            try:
                llm_response = await llm.ainvoke(batch_prompt.prompt_string)
                # Prefer pydantic v2's model_dump(); .dict() is deprecated there.
                if hasattr(llm_response, "model_dump"):
                    all_results = llm_response.model_dump()
                elif hasattr(llm_response, "dict"):
                    all_results = llm_response.dict()
                else:
                    all_results = llm_response
                responses = (
                    all_results["responses"]
                    if isinstance(all_results, dict)
                    else all_results.responses
                )
                # Pass a dict explicitly: all_results is not always subscriptable.
                # Kept inside the try so a malformed response_id (ValueError) marks
                # this batch as failed instead of exhausting retries and aborting.
                failed_ids = (
                    get_missing_response_ids(
                        batch_prompt.response_ids, {"responses": responses}
                    )
                    if integrity_check
                    else []
                )
            except (openai.BadRequestError, ValueError, ValidationError) as e:
                logger.warning(e)
                return [], batch_prompt.response_ids
        return responses, failed_ids

    results = await asyncio.gather(
        *[async_llm_call(batch_prompt) for batch_prompt in batch_prompts]
    )
    valid_inputs = [row for batch_rows, _ in results for row in batch_rows]
    failed_response_ids = [
        failed_response_id
        for _, batch_failures in results
        for failed_response_id in batch_failures
    ]

    return valid_inputs, failed_response_ids
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def get_missing_response_ids(
    input_response_ids: list[int], parsed_response: dict
) -> list[int]:
    """Identify which response IDs are missing from the LLM's parsed response.

    Args:
        input_response_ids (list[int]): Response IDs that were included in the
            original prompt. Values are coerced with ``int()``.
        parsed_response (dict): Parsed response from the LLM containing a 'responses' key
            with a list of dictionaries, each expected to contain a 'response_id' field.
            Entries whose 'response_id' is absent, None or "" are ignored.

    Returns:
        list[int]: Sorted response IDs from ``input_response_ids`` that do not
            appear in the parsed response.

    Raises:
        ValueError: If an ID cannot be converted with ``int()``.
    """
    expected_ids = {int(response_id) for response_id in input_response_ids}
    returned_ids = {
        int(element["response_id"])
        for element in parsed_response["responses"]
        # Explicit check rather than truthiness: response_id 0 is a valid ID.
        if element.get("response_id") not in (None, "")
    }

    # Sorted for a deterministic order in logs and retries.
    missing_ids = sorted(expected_ids - returned_ids)
    if missing_ids:
        logger.info(f"Missing response IDs from LLM output: {missing_ids}")
    return missing_ids
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def process_llm_responses(
    llm_responses: list[dict[str, Any]], responses: pd.DataFrame
) -> pd.DataFrame:
    """Combine per-row LLM output with the original input rows.

    Note: the 'response_id' column of ``responses`` is cast to int through a
    ``.loc`` assignment on the frame that was passed in, so the caller's
    DataFrame is modified.

    Args:
        llm_responses (list[dict[str, Any]]): Row-level results returned by the
            LLM (the contents of each batch's 'responses' list).
        responses (pd.DataFrame): The input rows; must have a 'response_id' column.

    Returns:
        pd.DataFrame: If the LLM rows carry a 'response_id', the inner join of
            ``responses`` with them on that (int) column. Otherwise the LLM rows on
            their own, as a DataFrame.
    """
    responses.loc[:, "response_id"] = responses.loc[:, "response_id"].astype(int)
    llm_frame = pd.DataFrame(llm_responses)
    if "response_id" not in llm_frame.columns:
        return llm_frame
    llm_frame["response_id"] = llm_frame["response_id"].astype(int)
    return responses.merge(llm_frame, how="inner", on="response_id")
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def calculate_string_token_length(input_text: str, model: str | None = None) -> int:
    """
    Calculates the number of tokens in a given string using the specified model's tokenizer.

    Args:
        input_text (str): The input string to tokenize. Special-token markers
            such as "<|endoftext|>" are counted as plain text rather than raising.
        model (str | None, optional): The model name used for tokenization. If not
            provided, uses the MODEL_NAME environment variable or defaults to "gpt-4o".
            If tiktoken does not recognise the name, the "o200k_base" encoding is
            used, so the count is an approximation.

    Returns:
        int: The number of tokens in the input string.
    """
    # Use the MODEL_NAME env var if no model is provided; otherwise default to "gpt-4o"
    model = model or os.environ.get("MODEL_NAME", "gpt-4o")
    try:
        tokenizer_encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken only maps known OpenAI model names. Deployment aliases or
        # non-OpenAI models would otherwise crash batching, so fall back to the
        # gpt-4o encoding for an approximate count.
        tokenizer_encoding = tiktoken.get_encoding("o200k_base")
    # disallowed_special=(): user text containing special-token strings must not
    # make encode() raise.
    return len(tokenizer_encoding.encode(input_text, disallowed_special=()))
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def build_prompt(
    prompt_template: PromptTemplate, input_batch: pd.DataFrame, **kwargs
) -> BatchPrompt:
    """
    Constructs a BatchPrompt by formatting a prompt template with a batch of responses.

    The function converts the input DataFrame batch into a list of dictionaries
    (one per row) and passes this list to the prompt template's format method
    under the key 'responses', along with any additional keyword arguments. It
    also extracts the 'response_id' column from the batch, cast to int, and uses
    these values to create the BatchPrompt.

    Args:
        prompt_template (PromptTemplate): An object with a 'format' method that is used
            to generate the prompt string.
        input_batch (pd.DataFrame): A DataFrame containing the batch of responses, which must
            include an integer-convertible 'response_id' column.
        **kwargs: Additional keyword arguments to pass to the prompt template's format method.

    Returns:
        BatchPrompt: An object containing:
            - prompt_string: The formatted prompt string for the batch.
            - response_ids: A list of response IDs (as ints) corresponding to the
              responses in the batch, in row order.
    """
    prompt = prompt_template.format(
        responses=input_batch.to_dict(orient="records"), **kwargs
    )
    response_ids = input_batch["response_id"].astype(int).to_list()
    return BatchPrompt(prompt_string=prompt, response_ids=response_ids)
|