themefinder 0.5.4__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -1,14 +1,23 @@
  import asyncio
- import json
  import logging
+ import os
  from dataclasses import dataclass
  from pathlib import Path
- from typing import Any
+ from typing import Any, Optional

+ import openai
  import pandas as pd
+ import tiktoken
  from langchain_core.prompts import PromptTemplate
  from langchain_core.runnables import Runnable
- from tenacity import before, retry, stop_after_attempt, wait_random_exponential
+ from pydantic import ValidationError
+ from tenacity import (
+     before,
+     before_sleep_log,
+     retry,
+     stop_after_attempt,
+     wait_random_exponential,
+ )

  from .themefinder_logging import logger
 
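The new imports above bring in openai and pydantic's ValidationError for error handling, tiktoken for token counting, and tenacity's before_sleep_log alongside the existing retry helpers. For orientation, here is a minimal, hypothetical sketch (not part of the package) of the retry configuration these imports support, as used further down in this diff: exponential backoff, at most six attempts, a DEBUG log before each attempt and an ERROR log before each sleep between attempts.

import logging

from tenacity import (
    before,
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

logger = logging.getLogger("retry_sketch")  # illustrative logger name


@retry(
    wait=wait_random_exponential(min=1, max=20),
    stop=stop_after_attempt(6),
    before=before.before_log(logger=logger, log_level=logging.DEBUG),
    before_sleep=before_sleep_log(logger, logging.ERROR),
    reraise=True,
)
def flaky_call() -> str:
    # Stand-in for an LLM call that may fail transiently; with reraise=True the
    # final failure propagates after the sixth attempt.
    raise RuntimeError("simulated transient failure")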
@@ -16,63 +25,82 @@ from .themefinder_logging import logger
  @dataclass
  class BatchPrompt:
      prompt_string: str
-     response_ids: list[str]
+     response_ids: list[int]


  async def batch_and_run(
-     responses_df: pd.DataFrame,
+     input_df: pd.DataFrame,
      prompt_template: str | Path | PromptTemplate,
      llm: Runnable,
      batch_size: int = 10,
      partition_key: str | None = None,
-     response_id_integrity_check: bool = False,
+     integrity_check: bool = False,
+     concurrency: int = 10,
      **kwargs: Any,
- ) -> pd.DataFrame:
+ ) -> tuple[pd.DataFrame, pd.DataFrame]:
      """Process a DataFrame of responses in batches using an LLM.

      Args:
-         responses_df (pd.DataFrame): DataFrame containing responses to be processed.
+         input_df (pd.DataFrame): DataFrame containing input to be processed.
              Must include a 'response_id' column.
          prompt_template (Union[str, Path, PromptTemplate]): Template for LLM prompts.
              Can be a string (file path), Path object, or PromptTemplate.
          llm (Runnable): LangChain Runnable instance that will process the prompts.
-         batch_size (int, optional): Number of responses to process in each batch.
+         batch_size (int, optional): Number of input rows to process in each batch.
              Defaults to 10.
-         partition_key (str | None, optional): Optional column name to group responses
+         partition_key (str | None, optional): Optional column name to group input rows
              before batching. Defaults to None.
-         response_id_integrity_check (bool, optional): If True, verifies that all input
-             response IDs are present in LLM output and retries failed responses individually.
+         integrity_check (bool, optional): If True, verifies that all input
+             response IDs are present in LLM output.
              If False, no integrity checking or retrying occurs. Defaults to False.
+         concurrency (int, optional): Maximum number of simultaneous LLM calls allowed.
+             Defaults to 10.
          **kwargs (Any): Additional keyword arguments to pass to the prompt template.

      Returns:
          pd.DataFrame: DataFrame containing the original responses merged with the
          LLM-processed results.
+     Returns:
+         tuple[pd.DataFrame, pd.DataFrame]:
+             A tuple containing two DataFrames:
+             - The first DataFrame contains the rows that were successfully processed by the LLM
+             - The second DataFrame contains the rows that could not be processed by the LLM
      """
+
      logger.info(f"Running batch and run with batch size {batch_size}")
      prompt_template = convert_to_prompt_template(prompt_template)
-     batched_response_dfs = batch_responses(
-         responses_df, batch_size=batch_size, partition_key=partition_key
+     batch_prompts = generate_prompts(
+         prompt_template,
+         input_df,
+         batch_size=batch_size,
+         partition_key=partition_key,
+         **kwargs,
      )
-     batch_prompts = generate_prompts(batched_response_dfs, prompt_template, **kwargs)
-     llm_responses, failed_ids = await call_llm(
+     processed_rows, failed_ids = await call_llm(
          batch_prompts=batch_prompts,
          llm=llm,
-         response_id_integrity_check=response_id_integrity_check,
+         integrity_check=integrity_check,
+         concurrency=concurrency,
      )
-     processed_responses = process_llm_responses(llm_responses, responses_df)
+     processed_results = process_llm_responses(processed_rows, input_df)
+
      if failed_ids:
-         new_df = responses_df[responses_df["response_id"].astype(str).isin(failed_ids)]
-         processed_failed_responses = await batch_and_run(
-             responses_df=new_df,
-             prompt_template=prompt_template,
+         retry_df = input_df[input_df["response_id"].isin(failed_ids)]
+         retry_prompts = generate_prompts(
+             prompt_template, retry_df, batch_size=1, **kwargs
+         )
+         retry_results, unprocessable_ids = await call_llm(
+             batch_prompts=retry_prompts,
              llm=llm,
-             batch_size=1,
-             partition_key=partition_key,
-             **kwargs,
+             integrity_check=integrity_check,
+             concurrency=concurrency,
          )
-         return pd.concat(objs=[processed_failed_responses, processed_responses])
-     return processed_responses
+         retry_processed_results = process_llm_responses(retry_results, retry_df)
+         unprocessable_df = retry_df.loc[retry_df["response_id"].isin(unprocessable_ids)]
+         processed_results = pd.concat([processed_results, retry_processed_results])
+     else:
+         unprocessable_df = pd.DataFrame()
+     return processed_results, unprocessable_df


  def load_prompt_from_file(file_path: str | Path) -> str:
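With this change batch_and_run returns a pair of DataFrames (successfully processed rows, and rows that could not be processed even after the one-row-per-prompt retry pass) instead of a single DataFrame. A minimal, hypothetical usage sketch, assuming batch_and_run from the module changed in this diff is in scope and that my_llm is a configured LangChain Runnable; the column values and the prompt file name topic_prompt.txt are illustrative only.

import asyncio

import pandas as pd

# `batch_and_run` is assumed to be imported from the module changed in this diff.


async def main(my_llm) -> None:
    responses = pd.DataFrame(
        {
            "response_id": [1, 2, 3],
            "response": ["Fix the potholes", "More frequent buses", "Lower fares"],
        }
    )
    processed, unprocessable = await batch_and_run(
        input_df=responses,
        prompt_template="topic_prompt.txt",  # str/Path templates are loaded from file
        llm=my_llm,
        batch_size=2,
        integrity_check=True,  # rows missing from the output are retried one by one
    )
    print(f"{len(processed)} rows processed, {len(unprocessable)} unprocessable")


# asyncio.run(main(my_llm))  # requires a real Runnable, so not executed here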
@@ -117,167 +145,227 @@ def convert_to_prompt_template(prompt_template: str | Path | PromptTemplate):
      return template


- def batch_responses(
-     responses_df: pd.DataFrame, batch_size: int = 10, partition_key: str | None = None
+ def partition_dataframe(
+     df: pd.DataFrame, partition_key: Optional[str]
+ ) -> list[pd.DataFrame]:
+     """Splits the DataFrame into partitions based on the partition_key if provided."""
+     if partition_key:
+         return [group.reset_index(drop=True) for _, group in df.groupby(partition_key)]
+     return [df]
+
+
+ def split_overflowing_batch(
+     batch: pd.DataFrame, allowed_tokens: int
  ) -> list[pd.DataFrame]:
-     """Split a DataFrame into batches, optionally partitioned by a key column.
+     """
+     Splits a DataFrame batch into smaller sub-batches such that each sub-batch's total token count
+     does not exceed the allowed token limit.

      Args:
-         responses_df (pd.DataFrame): Input DataFrame to be split into batches.
-         batch_size (int, optional): Maximum number of rows in each batch. Defaults to 10.
-         partition_key (str | None, optional): Column name to group by before batching.
-             If provided, ensures rows with the same partition key value stay together
-             and each group is batched separately. Defaults to None.
+         batch (pd.DataFrame): The input DataFrame to split.
+         allowed_tokens (int): The maximum allowed number of tokens per sub-batch.

      Returns:
-         list[pd.DataFrame]: List of DataFrame batches, where each batch contains
-             at most batch_size rows. If partition_key is used, rows within each
-             partition are kept together and batched separately.
+         list[pd.DataFrame]: A list of sub-batches, each within the token limit.
      """
-     if partition_key:
-         grouped = responses_df.groupby(partition_key)
-         batches = []
-         for _, group in grouped:
-             group_batches = [
-                 group.iloc[i : i + batch_size].reset_index(drop=True)
-                 for i in range(0, len(group), batch_size)
-             ]
-             batches.extend(group_batches)
-         return batches
-
-     return [
-         responses_df.iloc[i : i + batch_size].reset_index(drop=True)
-         for i in range(0, len(responses_df), batch_size)
-     ]
+     sub_batches = []
+     current_indices = []
+     current_token_sum = 0
+     token_counts = batch.apply(
+         lambda row: calculate_string_token_length(row.to_json()), axis=1
+     ).tolist()
+
+     for i, token_count in enumerate(token_counts):
+         if token_count > allowed_tokens:
+             logging.warning(
+                 f"Row at index {batch.index[i]} exceeds allowed token limit ({token_count} > {allowed_tokens}). Skipping row."
+             )
+             continue
+
+         if current_token_sum + token_count > allowed_tokens:
+             if current_indices:
+                 sub_batch = batch.iloc[current_indices].reset_index(drop=True)
+                 if not sub_batch.empty:
+                     sub_batches.append(sub_batch)
+             current_indices = [i]
+             current_token_sum = token_count
+         else:
+             current_indices.append(i)
+             current_token_sum += token_count
+
+     if current_indices:
+         sub_batch = batch.iloc[current_indices].reset_index(drop=True)
+         if not sub_batch.empty:
+             sub_batches.append(sub_batch)
+     return sub_batches
+
+
+ def batch_task_input_df(
+     df: pd.DataFrame,
+     allowed_tokens: int,
+     batch_size: int,
+     partition_key: Optional[str] = None,
+ ) -> list[pd.DataFrame]:
+     """
+     Partitions and batches a DataFrame according to a token limit and batch size, optionally using a partition key. Batches that exceed the token limit are further split.
+
+     Args:
+         df (pd.DataFrame): The input DataFrame to batch.
+         allowed_tokens (int): Maximum allowed tokens per batch.
+         batch_size (int): Maximum number of rows per batch before token filtering.
+         partition_key (Optional[str], optional): Column name to partition the DataFrame by.
+             Defaults to None.
+
+     Returns:
+         list[pd.DataFrame]: A list of batches, each within the specified token and size limits.
+     """
+     batches = []
+     partitions = partition_dataframe(df, partition_key)
+
+     for partition in partitions:
+         partition_batches = [
+             partition.iloc[i : i + batch_size].reset_index(drop=True)
+             for i in range(0, len(partition), batch_size)
+         ]
+         for batch in partition_batches:
+             batch_length = calculate_string_token_length(batch.to_json())
+             if batch_length <= allowed_tokens:
+                 batches.append(batch)
+             else:
+                 sub_batches = split_overflowing_batch(batch, allowed_tokens)
+                 batches.extend(sub_batches)
+     return batches


  def generate_prompts(
-     response_dfs: list[pd.DataFrame], prompt_template: PromptTemplate, **kwargs: Any
+     prompt_template: PromptTemplate,
+     input_data: pd.DataFrame,
+     batch_size: int = 50,
+     max_prompt_length: int = 50_000,
+     partition_key: str | None = None,
+     **kwargs,
  ) -> list[BatchPrompt]:
-     """Generate a list of BatchPrompts from DataFrames using a prompt template.
+     """
+     Generate a list of BatchPrompt objects by splitting the input DataFrame into batches
+     and formatting each batch using a prompt template.
+
+     The function first calculates the token length of the prompt template to determine
+     the allowed tokens available for the input data. It then splits the input data into batches,
+     optionally partitioning by a specified key. Each batch is then formatted into a prompt string
+     using the provided prompt template, and a BatchPrompt is created containing the prompt string
+     and a list of response IDs from the batch.

      Args:
-         response_dfs (list[pd.DataFrame]): List of DataFrames, each containing a batch
-             of responses to be processed. Each DataFrame must include a 'response_id' column.
-         prompt_template (PromptTemplate): LangChain PromptTemplate object used to format
-             the prompts for each batch.
-         **kwargs (Any): Additional keyword arguments to pass to the prompt template's
-             format method.
+         prompt_template (PromptTemplate): An object with a 'template' attribute and a 'format' method
+             used to create a prompt string from a list of response dictionaries.
+         input_data (pd.DataFrame): A DataFrame containing the input responses, with at least a
+             'response_id' column.
+         batch_size (int, optional): Maximum number of rows to include in each batch. Defaults to 50.
+         max_prompt_length (int, optional): The maximum total token length allowed for the prompt,
+             including both the prompt template and the input data. Defaults to 50,000.
+         partition_key (str | None, optional): Column name used to partition the DataFrame before batching.
+             If provided, the DataFrame will be grouped by this key so that rows with the same value
+             remain in the same batch. Defaults to None.
+         **kwargs: Additional keyword arguments to pass to the prompt template's format method.

      Returns:
-         list[BatchPrompt]: List of BatchPrompt objects, each containing:
-             - prompt_string: Formatted prompt text for the batch
-             - response_ids: List of response IDs included in the batch
-
-     Note:
-         The function converts each DataFrame to a list of dictionaries and passes it
-         to the prompt template as the 'responses' variable.
+         list[BatchPrompt]: A list of BatchPrompt objects where each object contains:
+             - prompt_string: The formatted prompt string for a batch.
+             - response_ids: A list of response IDs corresponding to the rows in that batch.
      """
-     batched_prompts = []
-     for df in response_dfs:
-         prompt = prompt_template.format(
-             responses=df.to_dict(orient="records"), **kwargs
-         )
-         response_ids = df["response_id"].astype(str).to_list()
-         batched_prompts.append(
-             BatchPrompt(prompt_string=prompt, response_ids=response_ids)
-         )
-
-     return batched_prompts
+     prompt_token_length = calculate_string_token_length(prompt_template.template)
+     allowed_tokens_for_data = max_prompt_length - prompt_token_length
+     batches = batch_task_input_df(
+         input_data, allowed_tokens_for_data, batch_size, partition_key
+     )
+     prompts = [build_prompt(prompt_template, batch, **kwargs) for batch in batches]
+     return prompts


  async def call_llm(
      batch_prompts: list[BatchPrompt],
      llm: Runnable,
      concurrency: int = 10,
-     response_id_integrity_check: bool = False,
- ):
-     """Process multiple batches of prompts concurrently through an LLM with retry logic.
-
-     Args:
-         batch_prompts (list[BatchPrompt]): List of BatchPrompt objects, each containing a
-             prompt string and associated response IDs to be processed.
-         llm (Runnable): LangChain Runnable instance that will process the prompts.
-         concurrency (int, optional): Maximum number of simultaneous LLM calls allowed.
-             Defaults to 10.
-         response_id_integrity_check (bool, optional): If True, verifies that all input
-             response IDs are present in the LLM output. Failed batches are discarded and
-             their IDs are returned for retry. Defaults to False.
-
-     Returns:
-         tuple[list[dict[str, Any]], set[str]]: A tuple containing:
-             - list of successful LLM responses as dictionaries
-             - set of failed response IDs (empty if no failures or integrity check is False)
-
-     Notes:
-         - Uses exponential backoff retry strategy with up to 6 attempts per batch
-         - Failed batches (when integrity check fails) return None and are filtered out
-         - Concurrency is managed via asyncio.Semaphore to prevent overwhelming the LLM
-     """
+     integrity_check: bool = False,
+ ) -> tuple[list[dict], list[int]]:
+     """Process multiple batches of prompts concurrently through an LLM with retry logic."""
      semaphore = asyncio.Semaphore(concurrency)
-     failed_ids: set = set()

      @retry(
          wait=wait_random_exponential(min=1, max=20),
          stop=stop_after_attempt(6),
          before=before.before_log(logger=logger, log_level=logging.DEBUG),
+         before_sleep=before_sleep_log(logger, logging.ERROR),
          reraise=True,
      )
-     async def async_llm_call(batch_prompt):
+     async def async_llm_call(batch_prompt) -> tuple[list[dict], list[int]]:
          async with semaphore:
-             response = await llm.ainvoke(batch_prompt.prompt_string)
-             parsed_response = json.loads(response.content)
-
-             if response_id_integrity_check and not check_response_integrity(
-                 batch_prompt.response_ids, parsed_response
-             ):
-                 # discard this response but keep track of failed response ids
-                 failed_ids.update(batch_prompt.response_ids)
-                 return None
-
-             return parsed_response
+             try:
+                 llm_response = await llm.ainvoke(batch_prompt.prompt_string)
+                 all_results = (
+                     llm_response.dict()
+                     if hasattr(llm_response, "dict")
+                     else llm_response
+                 )
+                 responses = (
+                     all_results["responses"]
+                     if isinstance(all_results, dict)
+                     else all_results.responses
+                 )
+             except (openai.BadRequestError, ValueError) as e:
+                 logger.warning(e)
+                 return [], batch_prompt.response_ids
+             except ValidationError as e:
+                 logger.warning(e)
+                 return [], batch_prompt.response_ids
+
+             if integrity_check:
+                 failed_ids = get_missing_response_ids(
+                     batch_prompt.response_ids, all_results
+                 )
+                 return responses, failed_ids
+             else:
+                 return responses, []

      results = await asyncio.gather(
          *[async_llm_call(batch_prompt) for batch_prompt in batch_prompts]
      )
-     successful_responses = [
-         r for r in results if r is not None
-     ]  # ignore discarded responses
-     return (successful_responses, failed_ids)
+     valid_inputs = [row for result, _ in results for row in result]
+     failed_response_ids = [
+         failed_response_id
+         for _, batch_failures in results
+         for failed_response_id in batch_failures
+     ]

+     return valid_inputs, failed_response_ids


- def check_response_integrity(
-     input_response_ids: set[str], parsed_response: dict
- ) -> bool:
-     """Verify that all input response IDs are present in the LLM's parsed response.
+
+ def get_missing_response_ids(
+     input_response_ids: list[int], parsed_response: dict
+ ) -> list[int]:
+     """Identify which response IDs are missing from the LLM's parsed response.

      Args:
          input_response_ids (set[str]): Set of response IDs that were included in the
-             original prompt sent to the LLM.
+             original prompt.
          parsed_response (dict): Parsed response from the LLM containing a 'responses' key
              with a list of dictionaries, each containing a 'response_id' field.

      Returns:
-         bool: True if all input response IDs are present in the parsed response and
-             no additional IDs are present, False otherwise.
+         set[str]: Set of response IDs that are missing from the parsed response.
      """
-     response_ids_set = set(input_response_ids)

+     response_ids_set = {int(response_id) for response_id in input_response_ids}
      returned_ids_set = {
-         str(
-             element["response_id"]
-         )  # treat ids as strings to match response_ids_in_each_prompt
+         int(element["response_id"])
          for element in parsed_response["responses"]
          if element.get("response_id", False)
      }
-     # assumes: all input ids ought to be present in output
-     if returned_ids_set != response_ids_set:
-         logger.info("Failed integrity check")
-         logger.info(
-             f"Present in original but not returned from LLM: {response_ids_set - returned_ids_set}. Returned in LLM but not present in original: {returned_ids_set - response_ids_set}"
-         )
-         return False
-     return True
+
+     missing_ids = list(response_ids_set - returned_ids_set)
+     if missing_ids:
+         logger.info(f"Missing response IDs from LLM output: {missing_ids}")
+     return missing_ids


  def process_llm_responses(
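The batching added above is greedy: rows are taken in order, a running token total is kept, a row that would push the total over allowed_tokens closes the current sub-batch and starts a new one, and any single row larger than the limit on its own is logged and skipped. A self-contained sketch of that packing rule follows (not the package code; string length stands in for calculate_string_token_length so it runs without tiktoken).

def greedy_pack(rows: list[str], allowed_tokens: int) -> list[list[str]]:
    """Pack rows into buckets whose summed 'token' cost stays within the limit."""
    buckets: list[list[str]] = []
    current: list[str] = []
    current_total = 0
    for row in rows:
        cost = len(row)            # stand-in for calculate_string_token_length(row)
        if cost > allowed_tokens:  # oversized row: skipped, mirroring the warning above
            continue
        if current_total + cost > allowed_tokens:
            if current:
                buckets.append(current)
            current, current_total = [row], cost
        else:
            current.append(row)
            current_total += cost
    if current:
        buckets.append(current)
    return buckets


print(greedy_pack(["aaaa", "bb", "cccccc", "d" * 50, "ee"], allowed_tokens=8))
# [['aaaa', 'bb'], ['cccccc', 'ee']]  -- the 50-character row exceeds the limit and is dropped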
@@ -298,13 +386,57 @@ def process_llm_responses(
          - If no response_id in LLM output: DataFrame containing only the LLM results
      """
      responses.loc[:, "response_id"] = responses["response_id"].astype(int)
-     unpacked_responses = [
-         response
-         for batch_response in llm_responses
-         for response in batch_response.get("responses", [])
-     ]
-     task_responses = pd.DataFrame(unpacked_responses)
+     task_responses = pd.DataFrame(llm_responses)
      if "response_id" in task_responses.columns:
          task_responses["response_id"] = task_responses["response_id"].astype(int)
          return responses.merge(task_responses, how="inner", on="response_id")
      return task_responses
+
+
+ def calculate_string_token_length(input_text: str, model: str = None) -> int:
+     """
+     Calculates the number of tokens in a given string using the specified model's tokenizer.
+
+     Args:
+         input_text (str): The input string to tokenize.
+         model (str, optional): The model name used for tokenization. If not provided,
+             uses the MODEL_NAME environment variable or defaults to "gpt-4o".
+
+     Returns:
+         int: The number of tokens in the input string.
+     """
+     # Use the MODEL_NAME env var if no model is provided; otherwise default to "gpt-4o"
+     model = model or os.environ.get("MODEL_NAME", "gpt-4o")
+     tokenizer_encoding = tiktoken.encoding_for_model(model)
+     number_of_tokens = len(tokenizer_encoding.encode(input_text))
+     return number_of_tokens
+
+
+ def build_prompt(
+     prompt_template: PromptTemplate, input_batch: pd.DataFrame, **kwargs
+ ) -> BatchPrompt:
+     """
+     Constructs a BatchPrompt by formatting a prompt template with a batch of responses.
+
+     The function converts the input DataFrame batch into a list of dictionaries (one per row) and passes
+     this list to the prompt template's format method under the key 'responses', along with any additional
+     keyword arguments. It also extracts the 'response_id' column from the batch,
+     and uses these to create the BatchPrompt.
+
+     Args:
+         prompt_template (PromptTemplate): An object with a 'template' attribute and a 'format' method that is used
+             to generate the prompt string.
+         input_batch (pd.DataFrame): A DataFrame containing the batch of responses, which must include a 'response_id'
+             column.
+         **kwargs: Additional keyword arguments to pass to the prompt template's format method.
+
+     Returns:
+         BatchPrompt: An object containing:
+             - prompt_string: The formatted prompt string for the batch.
+             - response_ids: A list of response IDs (as strings) corresponding to the responses in the batch.
+     """
+     prompt = prompt_template.format(
+         responses=input_batch.to_dict(orient="records"), **kwargs
+     )
+     response_ids = input_batch["response_id"].astype(int).to_list()
+     return BatchPrompt(prompt_string=prompt, response_ids=response_ids)
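Finally, generate_prompts derives its per-batch token budget from these helpers: the template's own token count is measured with tiktoken and subtracted from max_prompt_length, and the remainder is what batch_task_input_df may spend on rows. A small illustrative sketch of that arithmetic (the template text is made up; the model name mirrors the "gpt-4o" default above):

import tiktoken

template_text = "Assign a theme to each of the following responses:\n{responses}"  # illustrative
max_prompt_length = 50_000  # default used by generate_prompts

encoding = tiktoken.encoding_for_model("gpt-4o")
prompt_token_length = len(encoding.encode(template_text))
allowed_tokens_for_data = max_prompt_length - prompt_token_length

print(prompt_token_length, allowed_tokens_for_data)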