themefinder 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,442 @@
+import asyncio
+import logging
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Optional
+
+import openai
+import pandas as pd
+import tiktoken
+from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import Runnable
+from pydantic import ValidationError
+from tenacity import (
+    before,
+    before_sleep_log,
+    retry,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+from themefinder.themefinder_logging import logger
+
+
+@dataclass
+class BatchPrompt:
+    prompt_string: str
+    response_ids: list[int]
+
+
+async def batch_and_run(
+    input_df: pd.DataFrame,
+    prompt_template: str | Path | PromptTemplate,
+    llm: Runnable,
+    batch_size: int = 10,
+    partition_key: str | None = None,
+    integrity_check: bool = False,
+    concurrency: int = 10,
+    **kwargs: Any,
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+    """Process a DataFrame of responses in batches using an LLM.
+
+    Args:
+        input_df (pd.DataFrame): DataFrame containing input to be processed.
+            Must include a 'response_id' column.
+        prompt_template (str | Path | PromptTemplate): Template for LLM prompts.
+            Can be a prompt file name (resolved in the package's prompts directory),
+            a Path object, or a PromptTemplate.
+        llm (Runnable): LangChain Runnable instance that will process the prompts.
+        batch_size (int, optional): Number of input rows to process in each batch.
+            Defaults to 10.
+        partition_key (str | None, optional): Optional column name to group input rows
+            before batching. Defaults to None.
+        integrity_check (bool, optional): If True, verifies that all input
+            response IDs are present in the LLM output.
+            If False, no integrity checking or retrying occurs. Defaults to False.
+        concurrency (int, optional): Maximum number of simultaneous LLM calls allowed.
+            Defaults to 10.
+        **kwargs (Any): Additional keyword arguments to pass to the prompt template.
+
+    Returns:
+        tuple[pd.DataFrame, pd.DataFrame]: A tuple of two DataFrames:
+            - The first contains the rows that were successfully processed by the LLM,
+              merged with the LLM results.
+            - The second contains the rows that could not be processed by the LLM.
+    """
+
+    logger.info(f"Running batch and run with batch size {batch_size}")
+    prompt_template = convert_to_prompt_template(prompt_template)
+    batch_prompts = generate_prompts(
+        prompt_template,
+        input_df,
+        batch_size=batch_size,
+        partition_key=partition_key,
+        **kwargs,
+    )
+    processed_rows, failed_ids = await call_llm(
+        batch_prompts=batch_prompts,
+        llm=llm,
+        integrity_check=integrity_check,
+        concurrency=concurrency,
+    )
+    processed_results = process_llm_responses(processed_rows, input_df)
+
+    if failed_ids:
+        retry_df = input_df[input_df["response_id"].isin(failed_ids)]
+        retry_prompts = generate_prompts(
+            prompt_template, retry_df, batch_size=1, **kwargs
+        )
+        retry_results, unprocessable_ids = await call_llm(
+            batch_prompts=retry_prompts,
+            llm=llm,
+            integrity_check=integrity_check,
+            concurrency=concurrency,
+        )
+        retry_processed_results = process_llm_responses(retry_results, retry_df)
+        unprocessable_df = retry_df.loc[retry_df["response_id"].isin(unprocessable_ids)]
+        processed_results = pd.concat([processed_results, retry_processed_results])
+    else:
+        unprocessable_df = pd.DataFrame()
+    return processed_results, unprocessable_df
+
+
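Editor's example (not part of the package): a minimal sketch of how batch_and_run might be called, assuming an OpenAI-backed chat model from langchain-openai with structured output and valid API credentials in the environment; the schema, prompt wording, and model name are illustrative assumptions, and batch_and_run is the coroutine defined above.

import asyncio

import pandas as pd
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI  # assumed available; any structured-output Runnable works
from pydantic import BaseModel


class SentimentItem(BaseModel):
    response_id: int
    sentiment: str


class SentimentBatch(BaseModel):
    # call_llm expects the parsed output to expose a "responses" list
    responses: list[SentimentItem]


async def main() -> None:
    df = pd.DataFrame({"response_id": [1, 2], "response": ["Great service", "Too slow"]})
    template = PromptTemplate.from_template(
        "Label the sentiment of each response, echoing its response_id.\n{responses}"
    )
    # with_structured_output returns parsed objects whose .dict() carries "responses"
    llm = ChatOpenAI(model="gpt-4o").with_structured_output(SentimentBatch)
    processed, unprocessable = await batch_and_run(
        df, template, llm, batch_size=2, integrity_check=True
    )
    print(processed)      # originals merged with the model's sentiment column
    print(unprocessable)  # rows the model never returned, even after retries


asyncio.run(main())
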
+def load_prompt_from_file(file_path: str | Path) -> str:
+    """Load a prompt template from a text file in the prompts directory.
+
+    Args:
+        file_path (str | Path): Name of the prompt file (without .txt extension)
+            or Path object pointing to the file.
+
+    Returns:
+        str: Content of the prompt template file.
+    """
+    parent_dir = Path(__file__).parent
+    with Path.open(parent_dir / "prompts" / f"{file_path}.txt") as file:
+        return file.read()
+
+
+def convert_to_prompt_template(prompt_template: str | Path | PromptTemplate):
+    """Convert various input types to a LangChain PromptTemplate.
+
+    Args:
+        prompt_template (str | Path | PromptTemplate): Input template that can be either:
+            - str: Name of a prompt file in the prompts directory (without .txt extension)
+            - Path: Path object pointing to a prompt file
+            - PromptTemplate: Already initialized LangChain PromptTemplate
+
+    Returns:
+        PromptTemplate: Initialized LangChain PromptTemplate object.
+
+    Raises:
+        TypeError: If prompt_template is not one of the expected types.
+        FileNotFoundError: If using str/Path input and the prompt file doesn't exist.
+    """
+    if isinstance(prompt_template, str | Path):
+        prompt_content = load_prompt_from_file(prompt_template)
+        template = PromptTemplate.from_template(template=prompt_content)
+    elif isinstance(prompt_template, PromptTemplate):
+        template = prompt_template
+    else:
+        msg = "Invalid prompt_template type. Expected str, Path, or PromptTemplate."
+        raise TypeError(msg)
+    return template
+
+
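Editor's example (not part of the package): a ready-made PromptTemplate passes through convert_to_prompt_template unchanged, while a str or Path is resolved against the package's prompts directory; the file name below is hypothetical.

from langchain_core.prompts import PromptTemplate

inline = PromptTemplate.from_template("Summarise each response: {responses}")
assert convert_to_prompt_template(inline) is inline  # returned as-is

# A string would be resolved to <package>/prompts/<name>.txt instead, e.g.:
# convert_to_prompt_template("sentiment")  # hypothetical prompt file
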
+def partition_dataframe(
+    df: pd.DataFrame, partition_key: Optional[str]
+) -> list[pd.DataFrame]:
+    """Splits the DataFrame into partitions based on the partition_key if provided."""
+    if partition_key:
+        return [group.reset_index(drop=True) for _, group in df.groupby(partition_key)]
+    return [df]
+
+
+def split_overflowing_batch(
+    batch: pd.DataFrame, allowed_tokens: int
+) -> list[pd.DataFrame]:
+    """
+    Splits a DataFrame batch into smaller sub-batches such that each sub-batch's total token count
+    does not exceed the allowed token limit.
+
+    Args:
+        batch (pd.DataFrame): The input DataFrame to split.
+        allowed_tokens (int): The maximum allowed number of tokens per sub-batch.
+
+    Returns:
+        list[pd.DataFrame]: A list of sub-batches, each within the token limit.
+    """
+    sub_batches = []
+    current_indices = []
+    current_token_sum = 0
+    token_counts = batch.apply(
+        lambda row: calculate_string_token_length(row.to_json()), axis=1
+    ).tolist()
+
+    for i, token_count in enumerate(token_counts):
+        if token_count > allowed_tokens:
+            logging.warning(
+                f"Row at index {batch.index[i]} exceeds allowed token limit ({token_count} > {allowed_tokens}). Skipping row."
+            )
+            continue
+
+        if current_token_sum + token_count > allowed_tokens:
+            if current_indices:
+                sub_batch = batch.iloc[current_indices].reset_index(drop=True)
+                if not sub_batch.empty:
+                    sub_batches.append(sub_batch)
+            current_indices = [i]
+            current_token_sum = token_count
+        else:
+            current_indices.append(i)
+            current_token_sum += token_count
+
+    if current_indices:
+        sub_batch = batch.iloc[current_indices].reset_index(drop=True)
+        if not sub_batch.empty:
+            sub_batches.append(sub_batch)
+    return sub_batches
+
+
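Editor's example (not part of the package): the greedy split in miniature; rows are packed until the running token total would exceed the budget, and any single row over the budget is dropped with a warning. Token counts come from tiktoken, so exact numbers may vary slightly.

import pandas as pd

batch = pd.DataFrame(
    {"response_id": [1, 2, 3], "response": ["short answer", "another short answer", "word " * 400]}
)
sub_batches = split_overflowing_batch(batch, allowed_tokens=100)
print(len(sub_batches))                        # 1: the oversized third row is skipped
print(sub_batches[0]["response_id"].tolist())  # [1, 2]
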
+def batch_task_input_df(
+    df: pd.DataFrame,
+    allowed_tokens: int,
+    batch_size: int,
+    partition_key: Optional[str] = None,
+) -> list[pd.DataFrame]:
+    """
+    Partitions and batches a DataFrame according to a token limit and batch size, optionally using a partition key. Batches that exceed the token limit are further split.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame to batch.
+        allowed_tokens (int): Maximum allowed tokens per batch.
+        batch_size (int): Maximum number of rows per batch before token filtering.
+        partition_key (Optional[str], optional): Column name to partition the DataFrame by.
+            Defaults to None.
+
+    Returns:
+        list[pd.DataFrame]: A list of batches, each within the specified token and size limits.
+    """
+    batches = []
+    partitions = partition_dataframe(df, partition_key)
+
+    for partition in partitions:
+        partition_batches = [
+            partition.iloc[i : i + batch_size].reset_index(drop=True)
+            for i in range(0, len(partition), batch_size)
+        ]
+        for batch in partition_batches:
+            batch_length = calculate_string_token_length(batch.to_json())
+            if batch_length <= allowed_tokens:
+                batches.append(batch)
+            else:
+                sub_batches = split_overflowing_batch(batch, allowed_tokens)
+                batches.extend(sub_batches)
+    return batches
+
+
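Editor's example (not part of the package): partitioning keeps groups separate, then each partition is chunked to batch_size and re-split only if a chunk exceeds the token budget.

import pandas as pd

df = pd.DataFrame(
    {
        "response_id": [1, 2, 3, 4],
        "topic": ["a", "a", "b", "b"],
        "response": ["yes", "no", "maybe", "unsure"],
    }
)
batches = batch_task_input_df(df, allowed_tokens=500, batch_size=2, partition_key="topic")
print([b["response_id"].tolist() for b in batches])  # [[1, 2], [3, 4]] - topics never mixed
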
+def generate_prompts(
+    prompt_template: PromptTemplate,
+    input_data: pd.DataFrame,
+    batch_size: int = 50,
+    max_prompt_length: int = 50_000,
+    partition_key: str | None = None,
+    **kwargs,
+) -> list[BatchPrompt]:
+    """
+    Generate a list of BatchPrompt objects by splitting the input DataFrame into batches
+    and formatting each batch using a prompt template.
+
+    The function first calculates the token length of the prompt template to determine
+    the allowed tokens available for the input data. It then splits the input data into batches,
+    optionally partitioning by a specified key. Each batch is then formatted into a prompt string
+    using the provided prompt template, and a BatchPrompt is created containing the prompt string
+    and a list of response IDs from the batch.
+
+    Args:
+        prompt_template (PromptTemplate): An object with a 'template' attribute and a 'format' method
+            used to create a prompt string from a list of response dictionaries.
+        input_data (pd.DataFrame): A DataFrame containing the input responses, with at least a
+            'response_id' column.
+        batch_size (int, optional): Maximum number of rows to include in each batch. Defaults to 50.
+        max_prompt_length (int, optional): The maximum total token length allowed for the prompt,
+            including both the prompt template and the input data. Defaults to 50,000.
+        partition_key (str | None, optional): Column name used to partition the DataFrame before batching.
+            If provided, the DataFrame will be grouped by this key so that rows from different groups
+            are never mixed in the same batch. Defaults to None.
+        **kwargs: Additional keyword arguments to pass to the prompt template's format method.
+
+    Returns:
+        list[BatchPrompt]: A list of BatchPrompt objects where each object contains:
+            - prompt_string: The formatted prompt string for a batch.
+            - response_ids: A list of response IDs corresponding to the rows in that batch.
+    """
+    prompt_token_length = calculate_string_token_length(prompt_template.template)
+    allowed_tokens_for_data = max_prompt_length - prompt_token_length
+    batches = batch_task_input_df(
+        input_data, allowed_tokens_for_data, batch_size, partition_key
+    )
+    prompts = [build_prompt(prompt_template, batch, **kwargs) for batch in batches]
+    return prompts
+
+
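Editor's example (not part of the package): the only hard requirement on the template is a {responses} placeholder, which receives each batch as a list of row dicts.

import pandas as pd
from langchain_core.prompts import PromptTemplate

template = PromptTemplate.from_template(
    "Assign a theme to each response, echoing its response_id.\n{responses}"
)
df = pd.DataFrame({"response_id": [1, 2, 3], "response": ["x", "y", "z"]})
for batch_prompt in generate_prompts(template, df, batch_size=2):
    print(batch_prompt.response_ids)        # [1, 2] then [3]
    print(batch_prompt.prompt_string[:60])  # formatted prompt text for that batch
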
+async def call_llm(
+    batch_prompts: list[BatchPrompt],
+    llm: Runnable,
+    concurrency: int = 10,
+    integrity_check: bool = False,
+) -> tuple[list[dict], list[int]]:
+    """Process multiple batches of prompts concurrently through an LLM with retry logic."""
+    semaphore = asyncio.Semaphore(concurrency)
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=20),
+        stop=stop_after_attempt(6),
+        before=before.before_log(logger=logger, log_level=logging.DEBUG),
+        before_sleep=before_sleep_log(logger, logging.ERROR),
+        reraise=True,
+    )
+    async def async_llm_call(batch_prompt) -> tuple[list[dict], list[int]]:
+        async with semaphore:
+            try:
+                llm_response = await llm.ainvoke(batch_prompt.prompt_string)
+                all_results = (
+                    llm_response.dict()
+                    if hasattr(llm_response, "dict")
+                    else llm_response
+                )
+                responses = (
+                    all_results["responses"]
+                    if isinstance(all_results, dict)
+                    else all_results.responses
+                )
+            except (openai.BadRequestError, ValueError) as e:
+                logger.warning(e)
+                return [], batch_prompt.response_ids
+            except ValidationError as e:
+                logger.warning(e)
+                return [], batch_prompt.response_ids
+
+            if integrity_check:
+                failed_ids = get_missing_response_ids(
+                    batch_prompt.response_ids, all_results
+                )
+                return responses, failed_ids
+            else:
+                return responses, []
+
+    results = await asyncio.gather(
+        *[async_llm_call(batch_prompt) for batch_prompt in batch_prompts]
+    )
+    valid_inputs = [row for result, _ in results for row in result]
+    failed_response_ids = [
+        failed_response_id
+        for _, batch_failures in results
+        for failed_response_id in batch_failures
+    ]
+
+    return valid_inputs, failed_response_ids
+
+
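Editor's example (not part of the package): call_llm only needs a Runnable whose ainvoke result is a dict (or has a .dict() method) containing a 'responses' list, so a RunnableLambda stub is enough to exercise the plumbing; the theme values below are made up.

import asyncio

from langchain_core.runnables import RunnableLambda

stub_llm = RunnableLambda(lambda prompt: {"responses": [{"response_id": 1, "theme": "pricing"}]})
prompts = [BatchPrompt(prompt_string="...", response_ids=[1, 2])]
rows, failed = asyncio.run(call_llm(prompts, stub_llm, integrity_check=True))
print(rows)    # [{'response_id': 1, 'theme': 'pricing'}]
print(failed)  # [2] - never returned by the stub, so flagged for the retry pass
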
+def get_missing_response_ids(
+    input_response_ids: list[int], parsed_response: dict
+) -> list[int]:
+    """Identify which response IDs are missing from the LLM's parsed response.
+
+    Args:
+        input_response_ids (list[int]): Response IDs that were included in the
+            original prompt.
+        parsed_response (dict): Parsed response from the LLM containing a 'responses' key
+            with a list of dictionaries, each containing a 'response_id' field.
+
+    Returns:
+        list[int]: Response IDs that are missing from the parsed response.
+    """
+
+    response_ids_set = {int(response_id) for response_id in input_response_ids}
+    returned_ids_set = {
+        int(element["response_id"])
+        for element in parsed_response["responses"]
+        if element.get("response_id", False)
+    }
+
+    missing_ids = list(response_ids_set - returned_ids_set)
+    if missing_ids:
+        logger.info(f"Missing response IDs from LLM output: {missing_ids}")
+    return missing_ids
+
+
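Editor's example (not part of the package): IDs 1-3 went out, only 1 and 3 came back, so 2 is reported missing.

parsed = {
    "responses": [
        {"response_id": 1, "theme": "delivery"},
        {"response_id": 3, "theme": "pricing"},
    ]
}
print(get_missing_response_ids([1, 2, 3], parsed))  # [2]
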
+def process_llm_responses(
+    llm_responses: list[dict[str, Any]], responses: pd.DataFrame
+) -> pd.DataFrame:
+    """Process and merge LLM responses with the original DataFrame.
+
+    Args:
+        llm_responses (list[dict[str, Any]]): List of individual LLM response dictionaries,
+            each corresponding to one input row and typically containing a 'response_id' field.
+        responses (pd.DataFrame): Original DataFrame containing the input responses, must
+            include a 'response_id' column.
+
+    Returns:
+        pd.DataFrame: A merged DataFrame containing:
+            - If response_id exists in LLM output: Original responses joined with LLM results
+              on response_id (inner join)
+            - If no response_id in LLM output: DataFrame containing only the LLM results
+    """
+    responses.loc[:, "response_id"] = responses["response_id"].astype(int)
+    task_responses = pd.DataFrame(llm_responses)
+    if "response_id" in task_responses.columns:
+        task_responses["response_id"] = task_responses["response_id"].astype(int)
+        return responses.merge(task_responses, how="inner", on="response_id")
+    return task_responses
+
+
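Editor's example (not part of the package): LLM rows are joined back onto the originals on response_id, so originals with no matching LLM row drop out of the merged frame.

import pandas as pd

original = pd.DataFrame({"response_id": [1, 2], "response": ["Great", "Slow"]})
llm_rows = [{"response_id": 1, "sentiment": "positive"}]
merged = process_llm_responses(llm_rows, original)
print(merged)  # one row: response_id=1, response="Great", sentiment="positive"
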
+def calculate_string_token_length(input_text: str, model: str | None = None) -> int:
+    """
+    Calculates the number of tokens in a given string using the specified model's tokenizer.
+
+    Args:
+        input_text (str): The input string to tokenize.
+        model (str, optional): The model name used for tokenization. If not provided,
+            uses the MODEL_NAME environment variable or defaults to "gpt-4o".
+
+    Returns:
+        int: The number of tokens in the input string.
+    """
+    # Fall back to the MODEL_NAME env var, then to "gpt-4o", if no model is provided
+    model = model or os.environ.get("MODEL_NAME", "gpt-4o")
+    tokenizer_encoding = tiktoken.encoding_for_model(model)
+    number_of_tokens = len(tokenizer_encoding.encode(input_text))
+    return number_of_tokens
+
+
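Editor's example (not part of the package): token counts come from tiktoken's encoding for the chosen model, so treat the printed numbers as indicative rather than stable across tokenizer versions.

text = "Please tell us what you think about the proposed changes."
print(calculate_string_token_length(text))                  # MODEL_NAME env var or "gpt-4o"
print(calculate_string_token_length(text, model="gpt-4o"))  # explicit model choice
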
+def build_prompt(
+    prompt_template: PromptTemplate, input_batch: pd.DataFrame, **kwargs
+) -> BatchPrompt:
+    """
+    Constructs a BatchPrompt by formatting a prompt template with a batch of responses.
+
+    The function converts the input DataFrame batch into a list of dictionaries (one per row) and passes
+    this list to the prompt template's format method under the key 'responses', along with any additional
+    keyword arguments. It also extracts the 'response_id' column from the batch,
+    and uses these to create the BatchPrompt.
+
+    Args:
+        prompt_template (PromptTemplate): An object with a 'template' attribute and a 'format' method that is used
+            to generate the prompt string.
+        input_batch (pd.DataFrame): A DataFrame containing the batch of responses, which must include a 'response_id'
+            column.
+        **kwargs: Additional keyword arguments to pass to the prompt template's format method.
+
+    Returns:
+        BatchPrompt: An object containing:
+            - prompt_string: The formatted prompt string for the batch.
+            - response_ids: A list of response IDs (as integers) corresponding to the responses in the batch.
+    """
+    prompt = prompt_template.format(
+        responses=input_batch.to_dict(orient="records"), **kwargs
+    )
+    response_ids = input_batch["response_id"].astype(int).to_list()
+    return BatchPrompt(prompt_string=prompt, response_ids=response_ids)
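
Editor's example (not part of the package): build_prompt in isolation; the batch rows are injected under {responses} and any extra template variables arrive via **kwargs. The question text is illustrative.

import pandas as pd
from langchain_core.prompts import PromptTemplate

template = PromptTemplate.from_template("{question}\n{responses}")
batch = pd.DataFrame({"response_id": [7, 8], "response": ["yes", "no"]})
batch_prompt = build_prompt(template, batch, question="Do you support the proposal?")
print(batch_prompt.response_ids)   # [7, 8]
print(batch_prompt.prompt_string)  # question followed by the rows rendered as dicts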