themefinder 0.6.2__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

themefinder/__init__.py CHANGED
@@ -5,6 +5,8 @@ from .core import (
     theme_generation,
     theme_mapping,
     theme_refinement,
+    theme_target_alignment,
+    detail_detection,
 )
 
 __all__ = [
@@ -13,6 +15,8 @@ __all__ = [
     "theme_generation",
     "theme_condensation",
     "theme_refinement",
+    "theme_target_alignment",
     "theme_mapping",
+    "detail_detection",
 ]
 __version__ = "0.1.0"
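
With the expanded __all__, the two new pipeline stages are importable directly from the package alongside the existing ones. A minimal sketch of the new surface (names taken from the export list above; their signatures appear in core.py below):

    from themefinder import find_themes, theme_target_alignment, detail_detection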
themefinder/core.py CHANGED
@@ -3,10 +3,17 @@ from pathlib import Path
 
 import pandas as pd
 from langchain_core.prompts import PromptTemplate
-from langchain_core.runnables import Runnable
+from langchain.schema.runnable import RunnableWithFallbacks
 
 from .llm_batch_processor import batch_and_run, load_prompt_from_file
-from .models import SentimentAnalysisOutput, ThemeMappingOutput
+from .models import (
+    SentimentAnalysisResponses,
+    ThemeGenerationResponses,
+    ThemeCondensationResponses,
+    ThemeRefinementResponses,
+    ThemeMappingResponses,
+    DetailDetectionResponses,
+)
 from .themefinder_logging import logger
 
 CONSULTATION_SYSTEM_PROMPT = load_prompt_from_file("consultation_system_prompt")
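
The llm argument is now annotated as RunnableWithFallbacks rather than a plain Runnable. A hedged sketch of constructing such an object with LangChain's with_fallbacks helper (the model names are placeholders, not something this package prescribes):

    from langchain_openai import ChatOpenAI

    primary = ChatOpenAI(model="gpt-4o", temperature=0)
    backup = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    llm = primary.with_fallbacks([backup])  # a RunnableWithFallbacks, as the new type hints expect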
@@ -14,11 +21,12 @@ CONSULTATION_SYSTEM_PROMPT = load_prompt_from_file("consultation_system_prompt")
 
 async def find_themes(
     responses_df: pd.DataFrame,
-    llm: Runnable,
+    llm: RunnableWithFallbacks,
     question: str,
     target_n_themes: int | None = None,
     system_prompt: str = CONSULTATION_SYSTEM_PROMPT,
     verbose: bool = True,
+    concurrency: int = 10,
 ) -> dict[str, str | pd.DataFrame]:
     """Process survey responses through a multi-stage theme analysis pipeline.
 
@@ -32,7 +40,7 @@ async def find_themes(
 
     Args:
         responses_df (pd.DataFrame): DataFrame containing survey responses
-        llm (Runnable): Language model instance for text analysis
+        llm (RunnableWithFallbacks): Language model instance for text analysis
         question (str): The survey question
         target_n_themes (int | None, optional): Target number of themes to consolidate to.
             If None, skip theme target alignment step. Defaults to None.
@@ -40,6 +48,7 @@
             Defaults to CONSULTATION_SYSTEM_PROMPT.
         verbose (bool): Whether to show information messages during processing.
             Defaults to True.
+        concurrency (int): Number of concurrent API calls to make. Defaults to 10.
 
     Returns:
         dict[str, str | pd.DataFrame]: Dictionary containing results from each pipeline stage:
@@ -56,21 +65,28 @@
         llm,
         question=question,
         system_prompt=system_prompt,
+        concurrency=concurrency,
     )
     theme_df, _ = await theme_generation(
         sentiment_df,
         llm,
         question=question,
         system_prompt=system_prompt,
+        concurrency=concurrency,
     )
     condensed_theme_df, _ = await theme_condensation(
-        theme_df, llm, question=question, system_prompt=system_prompt
+        theme_df,
+        llm,
+        question=question,
+        system_prompt=system_prompt,
+        concurrency=concurrency,
     )
     refined_theme_df, _ = await theme_refinement(
         condensed_theme_df,
         llm,
         question=question,
         system_prompt=system_prompt,
+        concurrency=concurrency,
     )
     if target_n_themes is not None:
         refined_theme_df, _ = await theme_target_alignment(
@@ -79,6 +95,7 @@
             question=question,
             target_n_themes=target_n_themes,
             system_prompt=system_prompt,
+            concurrency=concurrency,
         )
     mapping_df, mapping_unprocessables = await theme_mapping(
         sentiment_df[["response_id", "response"]],
@@ -86,6 +103,14 @@
         question=question,
         refined_themes_df=refined_theme_df,
         system_prompt=system_prompt,
+        concurrency=concurrency,
+    )
+    detailed_df, _ = await detail_detection(
+        responses_df[["response_id", "response"]],
+        llm,
+        question=question,
+        system_prompt=system_prompt,
+        concurrency=concurrency,
     )
 
     logger.info("Finished finding themes")
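
Taken together, these hunks thread a single concurrency limit through every stage and add a final detail_detection pass; the returned dictionary (extended in the next hunk) gains a "detailed_responses" key. A minimal usage sketch, assuming placeholder data and a fallback-wrapped chat model like the one sketched earlier:

    import asyncio

    import pandas as pd
    from langchain_openai import ChatOpenAI
    from themefinder import find_themes

    llm = ChatOpenAI(model="gpt-4o").with_fallbacks([ChatOpenAI(model="gpt-4o-mini")])
    responses_df = pd.DataFrame(
        {"response_id": [1, 2], "response": ["Cut waiting times.", "Hire more staff."]}
    )

    results = asyncio.run(
        find_themes(
            responses_df,
            llm,
            question="How should the service improve?",
            target_n_themes=5,
            concurrency=4,  # new in this release: caps simultaneous LLM calls
        )
    )
    print(results["themes"])
    print(results["detailed_responses"])  # new key returned by this release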
@@ -97,17 +122,19 @@
         "sentiment": sentiment_df,
         "themes": refined_theme_df,
         "mapping": mapping_df,
+        "detailed_responses": detailed_df,
         "unprocessables": pd.concat([sentiment_unprocessables, mapping_unprocessables]),
     }
 
 
 async def sentiment_analysis(
     responses_df: pd.DataFrame,
-    llm: Runnable,
+    llm: RunnableWithFallbacks,
     question: str,
     batch_size: int = 20,
     prompt_template: str | Path | PromptTemplate = "sentiment_analysis",
     system_prompt: str = CONSULTATION_SYSTEM_PROMPT,
+    concurrency: int = 10,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Perform sentiment analysis on survey responses using an LLM.
 
@@ -117,7 +144,7 @@ async def sentiment_analysis(
     Args:
         responses_df (pd.DataFrame): DataFrame containing survey responses to analyze.
             Must contain 'response_id' and 'response' columns.
-        llm (Runnable): Language model instance to use for sentiment analysis.
+        llm (RunnableWithFallbacks): Language model instance to use for sentiment analysis.
         question (str): The survey question.
         batch_size (int, optional): Number of responses to process in each batch.
             Defaults to 20.
@@ -126,6 +153,7 @@
             or PromptTemplate instance. Defaults to "sentiment_analysis".
         system_prompt (str): System prompt to guide the LLM's behavior.
             Defaults to CONSULTATION_SYSTEM_PROMPT.
+        concurrency (int): Number of concurrent API calls to make. Defaults to 10.
 
     Returns:
         tuple[pd.DataFrame, pd.DataFrame]:
@@ -134,32 +162,33 @@
             - The second DataFrame contains the rows that could not be processed by the LLM
 
     Note:
-        The function uses validation_check to ensure responses maintain
+        The function uses integrity_check to ensure responses maintain
         their original order and association after processing.
     """
     logger.info(f"Running sentiment analysis on {len(responses_df)} responses")
-    processed_rows, unprocessable_rows = await batch_and_run(
+    sentiment, unprocessable = await batch_and_run(
         responses_df,
         prompt_template,
-        llm,
+        llm.with_structured_output(SentimentAnalysisResponses),
         batch_size=batch_size,
         question=question,
-        validation_check=True,
-        task_validation_model=SentimentAnalysisOutput,
+        integrity_check=True,
         system_prompt=system_prompt,
+        concurrency=concurrency,
     )
 
-    return processed_rows, unprocessable_rows
+    return sentiment, unprocessable
 
 
 async def theme_generation(
     responses_df: pd.DataFrame,
-    llm: Runnable,
+    llm: RunnableWithFallbacks,
     question: str,
     batch_size: int = 50,
     partition_key: str | None = "position",
     prompt_template: str | Path | PromptTemplate = "theme_generation",
     system_prompt: str = CONSULTATION_SYSTEM_PROMPT,
+    concurrency: int = 10,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Generate themes from survey responses using an LLM.
 
@@ -168,7 +197,7 @@
     Args:
         responses_df (pd.DataFrame): DataFrame containing survey responses.
             Must include 'response_id' and 'response' columns.
-        llm (Runnable): Language model instance to use for theme generation.
+        llm (RunnableWithFallbacks): Language model instance to use for theme generation.
         question (str): The survey question.
         batch_size (int, optional): Number of responses to process in each batch.
             Defaults to 50.
@@ -181,6 +210,7 @@
             or PromptTemplate instance. Defaults to "theme_generation".
         system_prompt (str): System prompt to guide the LLM's behavior.
             Defaults to CONSULTATION_SYSTEM_PROMPT.
+        concurrency (int): Number of concurrent API calls to make. Defaults to 10.
 
     Returns:
         tuple[pd.DataFrame, pd.DataFrame]:
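
The *Responses models imported at the top of core.py are not part of this diff, but the way they are used (llm.with_structured_output(SentimentAnalysisResponses) above, and all_results["responses"] / all_results.responses in call_llm further down) implies a Pydantic wrapper holding a responses list. A purely illustrative sketch of that shape; field and class names beyond those visible in the diff are guesses, and the real definitions live in themefinder/models.py:

    from pydantic import BaseModel


    class SentimentAnalysisResponse(BaseModel):  # hypothetical per-row record
        response_id: int
        position: str  # field names assumed, not shown in this diff


    class SentimentAnalysisResponses(BaseModel):
        # The wrapper must expose a "responses" collection: call_llm reads
        # llm_response.dict()["responses"] after structured-output parsing.
        responses: list[SentimentAnalysisResponse]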
@@ -193,22 +223,24 @@
     generated_themes, _ = await batch_and_run(
         responses_df,
         prompt_template,
-        llm,
+        llm.with_structured_output(ThemeGenerationResponses),
         batch_size=batch_size,
         partition_key=partition_key,
         question=question,
         system_prompt=system_prompt,
+        concurrency=concurrency,
     )
     return generated_themes, _
 
 
 async def theme_condensation(
     themes_df: pd.DataFrame,
-    llm: Runnable,
+    llm: RunnableWithFallbacks,
     question: str,
     batch_size: int = 75,
     prompt_template: str | Path | PromptTemplate = "theme_condensation",
     system_prompt: str = CONSULTATION_SYSTEM_PROMPT,
+    concurrency: int = 10,
     **kwargs,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Condense and combine similar themes identified from survey responses.
@@ -219,7 +251,7 @@
     Args:
         themes_df (pd.DataFrame): DataFrame containing the initial themes identified
             from survey responses.
-        llm (Runnable): Language model instance to use for theme condensation.
+        llm (RunnableWithFallbacks): Language model instance to use for theme condensation.
         question (str): The survey question.
         batch_size (int, optional): Number of themes to process in each batch.
             Defaults to 100.
@@ -228,6 +260,7 @@
             or PromptTemplate instance. Defaults to "theme_condensation".
         system_prompt (str): System prompt to guide the LLM's behavior.
             Defaults to CONSULTATION_SYSTEM_PROMPT.
+        concurrency (int): Number of concurrent API calls to make. Defaults to 10.
 
     Returns:
         tuple[pd.DataFrame, pd.DataFrame]:
@@ -247,10 +280,11 @@
     themes_df, _ = await batch_and_run(
         themes_df,
         prompt_template,
-        llm,
+        llm.with_structured_output(ThemeCondensationResponses),
         batch_size=batch_size,
         question=question,
         system_prompt=system_prompt,
+        concurrency=concurrency,
         **kwargs,
     )
     themes_df = themes_df.sample(frac=1).reset_index(drop=True)
@@ -263,10 +297,11 @@
     themes_df, _ = await batch_and_run(
         themes_df,
         prompt_template,
-        llm,
+        llm.with_structured_output(ThemeCondensationResponses),
         batch_size=batch_size,
         question=question,
         system_prompt=system_prompt,
+        concurrency=concurrency,
         **kwargs,
     )
 
@@ -276,11 +311,12 @@
 
 async def theme_refinement(
     condensed_themes_df: pd.DataFrame,
-    llm: Runnable,
+    llm: RunnableWithFallbacks,
     question: str,
     batch_size: int = 10000,
     prompt_template: str | Path | PromptTemplate = "theme_refinement",
     system_prompt: str = CONSULTATION_SYSTEM_PROMPT,
+    concurrency: int = 10,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Refine and standardize condensed themes using an LLM.
 
@@ -292,7 +328,7 @@
     Args:
         condensed_themes (pd.DataFrame): DataFrame containing the condensed themes
             from the previous pipeline stage.
-        llm (Runnable): Language model instance to use for theme refinement.
+        llm (RunnableWithFallbacks): Language model instance to use for theme refinement.
         question (str): The survey question.
         batch_size (int, optional): Number of themes to process in each batch.
             Defaults to 10000.
@@ -301,6 +337,7 @@
             or PromptTemplate instance. Defaults to "theme_refinement".
         system_prompt (str): System prompt to guide the LLM's behavior.
             Defaults to CONSULTATION_SYSTEM_PROMPT.
+        concurrency (int): Number of concurrent API calls to make. Defaults to 10.
 
     Returns:
         tuple[pd.DataFrame, pd.DataFrame]:
@@ -319,22 +356,24 @@
     refined_themes, _ = await batch_and_run(
         condensed_themes_df,
         prompt_template,
-        llm,
+        llm.with_structured_output(ThemeRefinementResponses),
         batch_size=batch_size,
         question=question,
         system_prompt=system_prompt,
+        concurrency=concurrency,
     )
     return refined_themes, _
 
 
 async def theme_target_alignment(
     refined_themes_df: pd.DataFrame,
-    llm: Runnable,
+    llm: RunnableWithFallbacks,
     question: str,
     target_n_themes: int = 10,
     batch_size: int = 10000,
     prompt_template: str | Path | PromptTemplate = "theme_target_alignment",
     system_prompt: str = CONSULTATION_SYSTEM_PROMPT,
+    concurrency: int = 10,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Align themes to target number using an LLM.
 
@@ -346,7 +385,7 @@
     Args:
         refined_themes_df (pd.DataFrame): DataFrame containing the refined themes
             from the previous pipeline stage.
-        llm (Runnable): Language model instance to use for theme alignment.
+        llm (RunnableWithFallbacks): Language model instance to use for theme alignment.
         question (str): The survey question.
         target_n_themes (int, optional): Target number of themes to consolidate to.
             Defaults to 10.
@@ -357,6 +396,7 @@
             or PromptTemplate instance. Defaults to "theme_target_alignment".
         system_prompt (str): System prompt to guide the LLM's behavior.
             Defaults to CONSULTATION_SYSTEM_PROMPT.
+        concurrency (int): Number of concurrent API calls to make. Defaults to 10.
 
     Returns:
         tuple[pd.DataFrame, pd.DataFrame]:
@@ -376,23 +416,25 @@
     aligned_themes, _ = await batch_and_run(
         refined_themes_df,
         prompt_template,
-        llm,
+        llm.with_structured_output(ThemeRefinementResponses),
         batch_size=batch_size,
         question=question,
         system_prompt=system_prompt,
         target_n_themes=target_n_themes,
+        concurrency=concurrency,
     )
     return aligned_themes, _
 
 
 async def theme_mapping(
     responses_df: pd.DataFrame,
-    llm: Runnable,
+    llm: RunnableWithFallbacks,
     question: str,
     refined_themes_df: pd.DataFrame,
     batch_size: int = 20,
     prompt_template: str | Path | PromptTemplate = "theme_mapping",
     system_prompt: str = CONSULTATION_SYSTEM_PROMPT,
+    concurrency: int = 10,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Map survey responses to refined themes using an LLM.
 
@@ -402,7 +444,7 @@
     Args:
         responses_df (pd.DataFrame): DataFrame containing survey responses.
             Must include 'response_id' and 'response' columns.
-        llm (Runnable): Language model instance to use for theme mapping.
+        llm (RunnableWithFallbacks): Language model instance to use for theme mapping.
         question (str): The survey question.
         refined_themes_df (pd.DataFrame): Single-row DataFrame where each column
             represents a theme (from theme_refinement stage).
@@ -413,6 +455,7 @@
             or PromptTemplate instance. Defaults to "theme_mapping".
         system_prompt (str): System prompt to guide the LLM's behavior.
             Defaults to CONSULTATION_SYSTEM_PROMPT.
+        concurrency (int): Number of concurrent API calls to make. Defaults to 10.
 
     Returns:
         tuple[pd.DataFrame, pd.DataFrame]:
@@ -432,17 +475,70 @@
         )
         return transposed_df
 
-    mapping, _ = await batch_and_run(
+    mapping, unprocessable = await batch_and_run(
         responses_df,
         prompt_template,
-        llm,
+        llm.with_structured_output(ThemeMappingResponses),
         batch_size=batch_size,
         question=question,
         refined_themes=transpose_refined_themes(refined_themes_df).to_dict(
             orient="records"
         ),
-        validation_check=True,
-        task_validation_model=ThemeMappingOutput,
+        integrity_check=True,
+        system_prompt=system_prompt,
+        concurrency=concurrency,
+    )
+    return mapping, unprocessable
+
+
+async def detail_detection(
+    responses_df: pd.DataFrame,
+    llm: RunnableWithFallbacks,
+    question: str,
+    batch_size: int = 20,
+    prompt_template: str | Path | PromptTemplate = "detail_detection",
+    system_prompt: str = CONSULTATION_SYSTEM_PROMPT,
+    concurrency: int = 10,
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+    """Identify responses that provide high-value detailed evidence.
+
+    This function processes survey responses in batches to analyze their level of detail
+    and evidence using a language model. It identifies responses that contain specific
+    examples, data, or detailed reasoning that provide strong supporting evidence.
+
+    Args:
+        responses_df (pd.DataFrame): DataFrame containing survey responses to analyze.
+            Must contain 'response_id' and 'response' columns.
+        llm (RunnableWithFallbacks): Language model instance to use for detail detection.
+        question (str): The survey question.
+        batch_size (int, optional): Number of responses to process in each batch.
+            Defaults to 20.
+        prompt_template (str | Path | PromptTemplate, optional): Template for structuring
+            the prompt to the LLM. Can be a string identifier, path to template file,
+            or PromptTemplate instance. Defaults to "detail_detection".
+        system_prompt (str): System prompt to guide the LLM's behavior.
+            Defaults to CONSULTATION_SYSTEM_PROMPT.
+        concurrency (int): Number of concurrent API calls to make. Defaults to 10.
+
+    Returns:
+        tuple[pd.DataFrame, pd.DataFrame]:
+            A tuple containing two DataFrames:
+            - The first DataFrame contains the rows that were successfully processed by the LLM
+            - The second DataFrame contains the rows that could not be processed by the LLM
+
+    Note:
+        The function uses response_id_integrity_check to ensure responses maintain
+        their original order and association after processing.
+    """
+    logger.info(f"Running detail detection on {len(responses_df)} responses")
+    detailed, _ = await batch_and_run(
+        responses_df,
+        prompt_template,
+        llm.with_structured_output(DetailDetectionResponses),
+        batch_size=batch_size,
+        question=question,
+        integrity_check=True,
         system_prompt=system_prompt,
+        concurrency=concurrency,
     )
-    return mapping, _
+    return detailed, _
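
The new stage can also be run on its own. A sketch with placeholder data; the llm is again a fallback-wrapped chat model, and nothing here beyond the detail_detection signature above comes from the package itself:

    import asyncio

    import pandas as pd
    from langchain_openai import ChatOpenAI
    from themefinder import detail_detection

    llm = ChatOpenAI(model="gpt-4o").with_fallbacks([ChatOpenAI(model="gpt-4o-mini")])
    responses_df = pd.DataFrame(
        {"response_id": [1, 2], "response": ["Ward 3 waits averaged 6 hours last winter.", "Make it better."]}
    )

    detailed_df, unprocessable_df = asyncio.run(
        detail_detection(
            responses_df,
            llm,
            question="How should the service improve?",
            concurrency=4,
        )
    )
    print(detailed_df)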
themefinder/llm_batch_processor.py CHANGED
@@ -1,17 +1,16 @@
 import asyncio
-import json
 import logging
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional, Type
+from typing import Any, Optional
 
 import openai
 import pandas as pd
 import tiktoken
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import Runnable
-from pydantic import BaseModel, ValidationError
+from pydantic import ValidationError
 from tenacity import (
     before,
     before_sleep_log,
@@ -35,8 +34,8 @@ async def batch_and_run(
     llm: Runnable,
     batch_size: int = 10,
     partition_key: str | None = None,
-    validation_check: bool = False,
-    task_validation_model: Type[BaseModel] = None,
+    integrity_check: bool = False,
+    concurrency: int = 10,
     **kwargs: Any,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Process a DataFrame of responses in batches using an LLM.
@@ -51,11 +50,11 @@
             Defaults to 10.
         partition_key (str | None, optional): Optional column name to group input rows
             before batching. Defaults to None.
-        validation_check (bool, optional): If True, verifies that all input
-            response IDs are present in LLM output and validates the rows against the validation model,
-            failed rows are retried individually.
+        integrity_check (bool, optional): If True, verifies that all input
+            response IDs are present in LLM output.
             If False, no integrity checking or retrying occurs. Defaults to False.
-        task_validation_model (Type[BaseModel]): the pydanctic model to validate each row against
+        concurrency (int, optional): Maximum number of simultaneous LLM calls allowed.
+            Defaults to 10.
         **kwargs (Any): Additional keyword arguments to pass to the prompt template.
 
     Returns:
@@ -80,8 +79,8 @@
     processed_rows, failed_ids = await call_llm(
         batch_prompts=batch_prompts,
         llm=llm,
-        validation_check=validation_check,
-        task_validation_model=task_validation_model,
+        integrity_check=integrity_check,
+        concurrency=concurrency,
     )
     processed_results = process_llm_responses(processed_rows, input_df)
 
@@ -93,8 +92,8 @@
         retry_results, unprocessable_ids = await call_llm(
             batch_prompts=retry_prompts,
             llm=llm,
-            validation_check=validation_check,
-            task_validation_model=task_validation_model,
+            integrity_check=integrity_check,
+            concurrency=concurrency,
         )
         retry_processed_results = process_llm_responses(retry_results, retry_df)
         unprocessable_df = retry_df.loc[retry_df["response_id"].isin(unprocessable_ids)]
@@ -287,32 +286,9 @@
     batch_prompts: list[BatchPrompt],
     llm: Runnable,
     concurrency: int = 10,
-    validation_check: bool = False,
-    task_validation_model: Optional[Type[BaseModel]] = None,
+    integrity_check: bool = False,
 ) -> tuple[list[dict], list[int]]:
-    """Process multiple batches of prompts concurrently through an LLM with retry logic.
-
-    Args:
-        batch_prompts (list[BatchPrompt]): List of BatchPrompt objects, each containing a
-            prompt string and associated response IDs to be processed.
-        llm (Runnable): LangChain Runnable instance that will process the prompts.
-        concurrency (int, optional): Maximum number of simultaneous LLM calls allowed.
-            Defaults to 10.
-        validation_check (bool, optional): If True, verifies that all input
-            response IDs are present in the LLM output. Failed batches are discarded and
-            their IDs are returned for retry. Defaults to False.
-        task_validation_model (Type[BaseModel]): The Pydantic model to check the LLM outputs against
-
-    Returns:
-        tuple[list[dict[str, Any]], set[str]]: A tuple containing:
-            - list of successful LLM responses as dictionaries
-            - set of failed response IDs (empty if no failures or integrity check is False)
-
-    Notes:
-        - Uses exponential backoff retry strategy with up to 6 attempts per batch
-        - Failed batches (when integrity check fails) return None and are filtered out
-        - Concurrency is managed via asyncio.Semaphore to prevent overwhelming the LLM
-    """
+    """Process multiple batches of prompts concurrently through an LLM with retry logic."""
     semaphore = asyncio.Semaphore(concurrency)
 
     @retry(
@@ -326,24 +302,30 @@
         async with semaphore:
             try:
                 llm_response = await llm.ainvoke(batch_prompt.prompt_string)
-                all_results = json.loads(llm_response.content)
-            except (openai.BadRequestError, json.JSONDecodeError) as e:
-                failed_ids = batch_prompt.response_ids
+                all_results = (
+                    llm_response.dict()
+                    if hasattr(llm_response, "dict")
+                    else llm_response
+                )
+                responses = (
+                    all_results["responses"]
+                    if isinstance(all_results, dict)
+                    else all_results.responses
+                )
+            except (openai.BadRequestError, ValueError) as e:
                 logger.warning(e)
-                return [], failed_ids
+                return [], batch_prompt.response_ids
+            except ValidationError as e:
+                logger.warning(e)
+                return [], batch_prompt.response_ids
 
-            if validation_check:
+            if integrity_check:
                 failed_ids = get_missing_response_ids(
                     batch_prompt.response_ids, all_results
                 )
-                validated_results, invalid_rows = validate_task_data(
-                    all_results["responses"], task_validation_model
-                )
-                failed_ids.extend([r["response_id"] for r in invalid_rows])
-                return validated_results, failed_ids
+                return responses, failed_ids
             else:
-                # Flatten the list to align with valid output format
-                return [r for r in all_results["responses"]], []
+                return responses, []
 
     results = await asyncio.gather(
         *[async_llm_call(batch_prompt) for batch_prompt in batch_prompts]
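
The concurrency value plumbed through from core.py ends up here as the size of an asyncio.Semaphore that each batch call must acquire before the results are gathered. A stripped-down illustration of that pattern, not the library's own code:

    import asyncio

    async def bounded_gather(coro_factories, concurrency: int = 10):
        # At most `concurrency` coroutines run at once; the rest wait on the semaphore.
        semaphore = asyncio.Semaphore(concurrency)

        async def run_one(factory):
            async with semaphore:
                return await factory()

        return await asyncio.gather(*(run_one(f) for f in coro_factories))

    async def fake_llm_call(i: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for one batched API round trip
        return i

    results = asyncio.run(
        bounded_gather([lambda i=i: fake_llm_call(i) for i in range(25)], concurrency=5)
    )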
@@ -458,33 +440,3 @@ def build_prompt(
     )
     response_ids = input_batch["response_id"].astype(int).to_list()
     return BatchPrompt(prompt_string=prompt, response_ids=response_ids)
-
-
-def validate_task_data(
-    task_data: pd.DataFrame | list[dict], task_validation_model: Type[BaseModel] = None
-) -> tuple[list[dict], list[dict]]:
-    """
-    Validate each row in task_output against the provided Pydantic model.
-
-    Returns:
-        valid: a list of validated records (dicts).
-        invalid: a list of records (dicts) that failed validation.
-    """
-
-    records = (
-        task_data.to_dict(orient="records")
-        if isinstance(task_data, pd.DataFrame)
-        else task_data
-    )
-
-    if task_validation_model:
-        valid_records, invalid_records = [], []
-        for record in records:
-            try:
-                task_validation_model(**record)
-                valid_records.append(record)
-            except ValidationError as e:
-                invalid_records.append(record)
-                logger.info(f"Failed Validation: {e}")
-        return valid_records, invalid_records
-    return records, []