themefinder 0.5.3__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themefinder/__init__.py +2 -2
- themefinder/core.py +86 -75
- themefinder/llm_batch_processor.py +303 -123
- themefinder/models.py +138 -0
- themefinder/prompts/sentiment_analysis.txt +8 -5
- themefinder/prompts/theme_condensation.txt +19 -6
- themefinder/prompts/theme_mapping.txt +3 -3
- themefinder/prompts/theme_refinement.txt +15 -28
- {themefinder-0.5.3.dist-info → themefinder-0.6.2.dist-info}/METADATA +3 -2
- themefinder-0.6.2.dist-info/RECORD +16 -0
- {themefinder-0.5.3.dist-info → themefinder-0.6.2.dist-info}/WHEEL +1 -1
- themefinder-0.5.3.dist-info/RECORD +0 -15
- {themefinder-0.5.3.dist-info → themefinder-0.6.2.dist-info}/LICENCE +0 -0
themefinder/llm_batch_processor.py

@@ -1,14 +1,24 @@
 import asyncio
 import json
 import logging
+import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any
+from typing import Any, Optional, Type
 
+import openai
 import pandas as pd
+import tiktoken
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import Runnable
-from
+from pydantic import BaseModel, ValidationError
+from tenacity import (
+    before,
+    before_sleep_log,
+    retry,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 
 from .themefinder_logging import logger
 
@@ -16,63 +26,82 @@ from .themefinder_logging import logger
 @dataclass
 class BatchPrompt:
     prompt_string: str
-    response_ids: list[
+    response_ids: list[int]
 
 
 async def batch_and_run(
-
+    input_df: pd.DataFrame,
     prompt_template: str | Path | PromptTemplate,
     llm: Runnable,
     batch_size: int = 10,
     partition_key: str | None = None,
-
+    validation_check: bool = False,
+    task_validation_model: Type[BaseModel] = None,
     **kwargs: Any,
-) -> pd.DataFrame:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Process a DataFrame of responses in batches using an LLM.
 
     Args:
-
+        input_df (pd.DataFrame): DataFrame containing input to be processed.
             Must include a 'response_id' column.
         prompt_template (Union[str, Path, PromptTemplate]): Template for LLM prompts.
             Can be a string (file path), Path object, or PromptTemplate.
         llm (Runnable): LangChain Runnable instance that will process the prompts.
-        batch_size (int, optional): Number of
+        batch_size (int, optional): Number of input rows to process in each batch.
             Defaults to 10.
-        partition_key (str | None, optional): Optional column name to group
+        partition_key (str | None, optional): Optional column name to group input rows
             before batching. Defaults to None.
-
-            response IDs are present in LLM output and
+        validation_check (bool, optional): If True, verifies that all input
+            response IDs are present in LLM output and validates the rows against the validation model,
+            failed rows are retried individually.
             If False, no integrity checking or retrying occurs. Defaults to False.
+        task_validation_model (Type[BaseModel]): the pydanctic model to validate each row against
         **kwargs (Any): Additional keyword arguments to pass to the prompt template.
 
     Returns:
         pd.DataFrame: DataFrame containing the original responses merged with the
             LLM-processed results.
+    Returns:
+        tuple[pd.DataFrame, pd.DataFrame]:
+            A tuple containing two DataFrames:
+            - The first DataFrame contains the rows that were successfully processes by the LLM
+            - The second DataFrame contains the rows that could not be processed by the LLM
     """
+
     logger.info(f"Running batch and run with batch size {batch_size}")
     prompt_template = convert_to_prompt_template(prompt_template)
-
-
+    batch_prompts = generate_prompts(
+        prompt_template,
+        input_df,
+        batch_size=batch_size,
+        partition_key=partition_key,
+        **kwargs,
     )
-
-    llm_responses, failed_ids = await call_llm(
+    processed_rows, failed_ids = await call_llm(
         batch_prompts=batch_prompts,
         llm=llm,
-
+        validation_check=validation_check,
+        task_validation_model=task_validation_model,
     )
-
+    processed_results = process_llm_responses(processed_rows, input_df)
+
     if failed_ids:
-
-
-
-
+        retry_df = input_df[input_df["response_id"].isin(failed_ids)]
+        retry_prompts = generate_prompts(
+            prompt_template, retry_df, batch_size=1, **kwargs
+        )
+        retry_results, unprocessable_ids = await call_llm(
+            batch_prompts=retry_prompts,
             llm=llm,
-
-
-            **kwargs,
+            validation_check=validation_check,
+            task_validation_model=task_validation_model,
        )
-
-
+        retry_processed_results = process_llm_responses(retry_results, retry_df)
+        unprocessable_df = retry_df.loc[retry_df["response_id"].isin(unprocessable_ids)]
+        processed_results = pd.concat([processed_results, retry_processed_results])
+    else:
+        unprocessable_df = pd.DataFrame()
+    return processed_results, unprocessable_df
 
 
 def load_prompt_from_file(file_path: str | Path) -> str:
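For orientation (not part of the diff): a minimal sketch of how the reworked batch_and_run might be called now that it accepts a validation model and returns a (processed, unprocessable) pair. ChatOpenAI, the prompt text, and the SentimentRow model are illustrative assumptions; the release's real per-task models appear to be the ones added in themefinder/models.py.

# Sketch only -- not part of the published code.
import asyncio

import pandas as pd
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

from themefinder.llm_batch_processor import batch_and_run


class SentimentRow(BaseModel):  # hypothetical per-row schema
    response_id: int
    sentiment: str


async def main() -> None:
    input_df = pd.DataFrame(
        {"response_id": [1, 2], "response": ["Strongly agree", "Too expensive"]}
    )
    prompt = PromptTemplate.from_template(
        "Label the sentiment of each response. Reply as JSON with a 'responses' "
        "list echoing each response_id.\n{responses}"
    )
    llm = ChatOpenAI(model="gpt-4o", temperature=0)  # any LangChain Runnable chat model
    processed_df, unprocessable_df = await batch_and_run(
        input_df,
        prompt,
        llm,
        batch_size=10,
        validation_check=True,
        task_validation_model=SentimentRow,
    )
    print(f"{len(processed_df)} rows processed, {len(unprocessable_df)} unprocessable")


if __name__ == "__main__":
    asyncio.run(main())
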
@@ -117,81 +146,150 @@ def convert_to_prompt_template(prompt_template: str | Path | PromptTemplate):
     return template
 
 
-def
-
+def partition_dataframe(
+    df: pd.DataFrame, partition_key: Optional[str]
+) -> list[pd.DataFrame]:
+    """Splits the DataFrame into partitions based on the partition_key if provided."""
+    if partition_key:
+        return [group.reset_index(drop=True) for _, group in df.groupby(partition_key)]
+    return [df]
+
+
+def split_overflowing_batch(
+    batch: pd.DataFrame, allowed_tokens: int
 ) -> list[pd.DataFrame]:
-    """
+    """
+    Splits a DataFrame batch into smaller sub-batches such that each sub-batch's total token count
+    does not exceed the allowed token limit.
 
     Args:
-
-
-        partition_key (str | None, optional): Column name to group by before batching.
-            If provided, ensures rows with the same partition key value stay together
-            and each group is batched separately. Defaults to None.
+        batch (pd.DataFrame): The input DataFrame to split.
+        allowed_tokens (int): The maximum allowed number of tokens per sub-batch.
 
     Returns:
-        list[pd.DataFrame]:
-            at most batch_size rows. If partition_key is used, rows within each
-            partition are kept together and batched separately.
+        list[pd.DataFrame]: A list of sub-batches, each within the token limit.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    sub_batches = []
+    current_indices = []
+    current_token_sum = 0
+    token_counts = batch.apply(
+        lambda row: calculate_string_token_length(row.to_json()), axis=1
+    ).tolist()
+
+    for i, token_count in enumerate(token_counts):
+        if token_count > allowed_tokens:
+            logging.warning(
+                f"Row at index {batch.index[i]} exceeds allowed token limit ({token_count} > {allowed_tokens}). Skipping row."
+            )
+            continue
+
+        if current_token_sum + token_count > allowed_tokens:
+            if current_indices:
+                sub_batch = batch.iloc[current_indices].reset_index(drop=True)
+                if not sub_batch.empty:
+                    sub_batches.append(sub_batch)
+            current_indices = [i]
+            current_token_sum = token_count
+        else:
+            current_indices.append(i)
+            current_token_sum += token_count
+
+    if current_indices:
+        sub_batch = batch.iloc[current_indices].reset_index(drop=True)
+        if not sub_batch.empty:
+            sub_batches.append(sub_batch)
+    return sub_batches
+
+
+def batch_task_input_df(
+    df: pd.DataFrame,
+    allowed_tokens: int,
+    batch_size: int,
+    partition_key: Optional[str] = None,
+) -> list[pd.DataFrame]:
+    """
+    Partitions and batches a DataFrame according to a token limit and batch size, optionally using a partition key. Batches that exceed the token limit are further split.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame to batch.
+        allowed_tokens (int): Maximum allowed tokens per batch.
+        batch_size (int): Maximum number of rows per batch before token filtering.
+        partition_key (Optional[str], optional): Column name to partition the DataFrame by.
+            Defaults to None.
+
+    Returns:
+        list[pd.DataFrame]: A list of batches, each within the specified token and size limits.
+    """
+    batches = []
+    partitions = partition_dataframe(df, partition_key)
+
+    for partition in partitions:
+        partition_batches = [
+            partition.iloc[i : i + batch_size].reset_index(drop=True)
+            for i in range(0, len(partition), batch_size)
+        ]
+        for batch in partition_batches:
+            batch_length = calculate_string_token_length(batch.to_json())
+            if batch_length <= allowed_tokens:
+                batches.append(batch)
+            else:
+                sub_batches = split_overflowing_batch(batch, allowed_tokens)
+                batches.extend(sub_batches)
+    return batches
 
 
 def generate_prompts(
-
+    prompt_template: PromptTemplate,
+    input_data: pd.DataFrame,
+    batch_size: int = 50,
+    max_prompt_length: int = 50_000,
+    partition_key: str | None = None,
+    **kwargs,
 ) -> list[BatchPrompt]:
-    """
+    """
+    Generate a list of BatchPrompt objects by splitting the input DataFrame into batches
+    and formatting each batch using a prompt template.
+
+    The function first calculates the token length of the prompt template to determine
+    the allowed tokens available for the input data. It then splits the input data into batches,
+    optionally partitioning by a specified key. Each batch is then formatted into a prompt string
+    using the provided prompt template, and a BatchPrompt is created containing the prompt string
+    and a list of response IDs from the batch.
 
     Args:
-
-
-
-
-
-
+        prompt_template (PromptTemplate): An object with a 'template' attribute and a 'format' method
+            used to create a prompt string from a list of response dictionaries.
+        input_data (pd.DataFrame): A DataFrame containing the input responses, with at least a
+            'response_id' column.
+        batch_size (int, optional): Maximum number of rows to include in each batch. Defaults to 50.
+        max_prompt_length (int, optional): The maximum total token length allowed for the prompt,
+            including both the prompt template and the input data. Defaults to 50,000.
+        partition_key (str | None, optional): Column name used to partition the DataFrame before batching.
+            If provided, the DataFrame will be grouped by this key so that rows with the same value
+            remain in the same batch. Defaults to None.
+        **kwargs: Additional keyword arguments to pass to the prompt template's format method.
 
     Returns:
-        list[BatchPrompt]:
-            - prompt_string:
-            - response_ids:
-
-    Note:
-        The function converts each DataFrame to a list of dictionaries and passes it
-        to the prompt template as the 'responses' variable.
+        list[BatchPrompt]: A list of BatchPrompt objects where each object contains:
+            - prompt_string: The formatted prompt string for a batch.
+            - response_ids: A list of response IDs corresponding to the rows in that batch.
     """
-
-
-
-
-
-
-
-            BatchPrompt(prompt_string=prompt, response_ids=response_ids)
-        )
-
-    return batched_prompts
+    prompt_token_length = calculate_string_token_length(prompt_template.template)
+    allowed_tokens_for_data = max_prompt_length - prompt_token_length
+    batches = batch_task_input_df(
+        input_data, allowed_tokens_for_data, batch_size, partition_key
+    )
+    prompts = [build_prompt(prompt_template, batch, **kwargs) for batch in batches]
+    return prompts
 
 
 async def call_llm(
     batch_prompts: list[BatchPrompt],
     llm: Runnable,
     concurrency: int = 10,
-
-
+    validation_check: bool = False,
+    task_validation_model: Optional[Type[BaseModel]] = None,
+) -> tuple[list[dict], list[int]]:
     """Process multiple batches of prompts concurrently through an LLM with retry logic.
 
     Args:
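For orientation (not part of the diff): a simplified, standalone sketch of the token budgeting that generate_prompts and split_overflowing_batch now apply. The template's token count is subtracted from max_prompt_length, and rows are packed greedily until the remaining budget would be exceeded. The template text and row strings here are invented; the real helpers operate on DataFrame rows serialised with to_json().

# Sketch only -- not part of the published code.
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-4o")

template_text = "Condense the following themes into a shorter list:\n{responses}"
max_prompt_length = 50_000  # default in generate_prompts
allowed_tokens_for_data = max_prompt_length - len(encoding.encode(template_text))

rows = ["short answer", "a much longer free-text answer " * 200, "another reply"]
batches: list[list[str]] = []
current: list[str] = []
used = 0
for row in rows:
    n = len(encoding.encode(row))
    if n > allowed_tokens_for_data:
        continue  # the real helper logs a warning and skips oversized rows
    if used + n > allowed_tokens_for_data and current:
        batches.append(current)  # close the current sub-batch, start a new one
        current, used = [], 0
    current.append(row)
    used += n
if current:
    batches.append(current)
print(f"{allowed_tokens_for_data} tokens available for data, packed into {len(batches)} batch(es)")
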
@@ -200,9 +298,10 @@ async def call_llm(
         llm (Runnable): LangChain Runnable instance that will process the prompts.
         concurrency (int, optional): Maximum number of simultaneous LLM calls allowed.
             Defaults to 10.
-
+        validation_check (bool, optional): If True, verifies that all input
             response IDs are present in the LLM output. Failed batches are discarded and
             their IDs are returned for retry. Defaults to False.
+        task_validation_model (Type[BaseModel]): The Pydantic model to check the LLM outputs against
 
     Returns:
         tuple[list[dict[str, Any]], set[str]]: A tuple containing:
@@ -215,69 +314,76 @@ async def call_llm(
         - Concurrency is managed via asyncio.Semaphore to prevent overwhelming the LLM
     """
     semaphore = asyncio.Semaphore(concurrency)
-    failed_ids: set = set()
 
     @retry(
         wait=wait_random_exponential(min=1, max=20),
         stop=stop_after_attempt(6),
         before=before.before_log(logger=logger, log_level=logging.DEBUG),
+        before_sleep=before_sleep_log(logger, logging.ERROR),
         reraise=True,
     )
-    async def async_llm_call(batch_prompt):
+    async def async_llm_call(batch_prompt) -> tuple[list[dict], list[int]]:
         async with semaphore:
-
-
-
-
-                batch_prompt.response_ids
-
-
-
-
-
-
+            try:
+                llm_response = await llm.ainvoke(batch_prompt.prompt_string)
+                all_results = json.loads(llm_response.content)
+            except (openai.BadRequestError, json.JSONDecodeError) as e:
+                failed_ids = batch_prompt.response_ids
+                logger.warning(e)
+                return [], failed_ids
+
+            if validation_check:
+                failed_ids = get_missing_response_ids(
+                    batch_prompt.response_ids, all_results
+                )
+                validated_results, invalid_rows = validate_task_data(
+                    all_results["responses"], task_validation_model
+                )
+                failed_ids.extend([r["response_id"] for r in invalid_rows])
+                return validated_results, failed_ids
+            else:
+                # Flatten the list to align with valid output format
+                return [r for r in all_results["responses"]], []
 
     results = await asyncio.gather(
         *[async_llm_call(batch_prompt) for batch_prompt in batch_prompts]
     )
-
-
-
-
+    valid_inputs = [row for result, _ in results for row in result]
+    failed_response_ids = [
+        failed_response_id
+        for _, batch_failures in results
+        for failed_response_id in batch_failures
+    ]
 
+    return valid_inputs, failed_response_ids
 
-
-
-
-
+
+def get_missing_response_ids(
+    input_response_ids: list[int], parsed_response: dict
+) -> list[int]:
+    """Identify which response IDs are missing from the LLM's parsed response.
 
     Args:
         input_response_ids (set[str]): Set of response IDs that were included in the
-            original prompt
+            original prompt.
         parsed_response (dict): Parsed response from the LLM containing a 'responses' key
             with a list of dictionaries, each containing a 'response_id' field.
 
     Returns:
-
-            no additional IDs are present, False otherwise.
+        set[str]: Set of response IDs that are missing from the parsed response.
     """
-    response_ids_set = set(input_response_ids)
 
+    response_ids_set = {int(response_id) for response_id in input_response_ids}
     returned_ids_set = {
-
-        element["response_id"]
-    ) # treat ids as strings to match response_ids_in_each_prompt
+        int(element["response_id"])
         for element in parsed_response["responses"]
         if element.get("response_id", False)
     }
-
-
-
-        logger.info(
-
-        )
-        return False
-    return True
+
+    missing_ids = list(response_ids_set - returned_ids_set)
+    if missing_ids:
+        logger.info(f"Missing response IDs from LLM output: {missing_ids}")
+    return missing_ids
 
 
 def process_llm_responses(
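For orientation (not part of the diff): a toy illustration of the ID reconciliation the reworked get_missing_response_ids performs. Both sides are coerced to int before the set difference, so IDs that come back from the LLM as strings still match. The "sentiment" values are invented.

# Sketch only -- not part of the published code.
from themefinder.llm_batch_processor import get_missing_response_ids

parsed_response = {
    "responses": [
        {"response_id": "101", "sentiment": "agree"},  # string IDs are coerced to int
        {"response_id": 103, "sentiment": "disagree"},
    ]
}
missing = get_missing_response_ids([101, 102, 103], parsed_response)
print(missing)  # [102] -> batch_and_run retries these IDs in batches of one
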
@@ -298,13 +404,87 @@ def process_llm_responses(
         - If no response_id in LLM output: DataFrame containing only the LLM results
     """
     responses.loc[:, "response_id"] = responses["response_id"].astype(int)
-
-        response
-        for batch_response in llm_responses
-        for response in batch_response.get("responses", [])
-    ]
-    task_responses = pd.DataFrame(unpacked_responses)
+    task_responses = pd.DataFrame(llm_responses)
     if "response_id" in task_responses.columns:
         task_responses["response_id"] = task_responses["response_id"].astype(int)
         return responses.merge(task_responses, how="inner", on="response_id")
     return task_responses
+
+
+def calculate_string_token_length(input_text: str, model: str = None) -> int:
+    """
+    Calculates the number of tokens in a given string using the specified model's tokenizer.
+
+    Args:
+        input_text (str): The input string to tokenize.
+        model (str, optional): The model name used for tokenization. If not provided,
+            uses the MODEL_NAME environment variable or defaults to "gpt-4o".
+
+    Returns:
+        int: The number of tokens in the input string.
+    """
+    # Use the MODEL_NAME env var if no model is provided; otherwise default to "gpt-4o"
+    model = model or os.environ.get("MODEL_NAME", "gpt-4o")
+    tokenizer_encoding = tiktoken.encoding_for_model(model)
+    number_of_tokens = len(tokenizer_encoding.encode(input_text))
+    return number_of_tokens
+
+
+def build_prompt(
+    prompt_template: PromptTemplate, input_batch: pd.DataFrame, **kwargs
+) -> BatchPrompt:
+    """
+    Constructs a BatchPrompt by formatting a prompt template with a batch of responses.
+
+    The function converts the input DataFrame batch into a list of dictionaries (one per row) and passes
+    this list to the prompt template's format method under the key 'responses', along with any additional
+    keyword arguments. It also extracts the 'response_id' column from the batch,
+    and uses these to create the BatchPrompt.
+
+    Args:
+        prompt_template (PromptTemplate): An object with a 'template' attribute and a 'format' method that is used
+            to generate the prompt string.
+        input_batch (pd.DataFrame): A DataFrame containing the batch of responses, which must include a 'response_id'
+            column.
+        **kwargs: Additional keyword arguments to pass to the prompt template's format method.
+
+    Returns:
+        BatchPrompt: An object containing:
+            - prompt_string: The formatted prompt string for the batch.
+            - response_ids: A list of response IDs (as strings) corresponding to the responses in the batch.
+    """
+    prompt = prompt_template.format(
+        responses=input_batch.to_dict(orient="records"), **kwargs
+    )
+    response_ids = input_batch["response_id"].astype(int).to_list()
+    return BatchPrompt(prompt_string=prompt, response_ids=response_ids)
+
+
+def validate_task_data(
+    task_data: pd.DataFrame | list[dict], task_validation_model: Type[BaseModel] = None
+) -> tuple[list[dict], list[dict]]:
+    """
+    Validate each row in task_output against the provided Pydantic model.
+
+    Returns:
+        valid: a list of validated records (dicts).
+        invalid: a list of records (dicts) that failed validation.
+    """
+
+    records = (
+        task_data.to_dict(orient="records")
+        if isinstance(task_data, pd.DataFrame)
+        else task_data
+    )
+
+    if task_validation_model:
+        valid_records, invalid_records = [], []
+        for record in records:
+            try:
+                task_validation_model(**record)
+                valid_records.append(record)
+            except ValidationError as e:
+                invalid_records.append(record)
+                logger.info(f"Failed Validation: {e}")
+        return valid_records, invalid_records
+    return records, []