data-science-document-ai 1.42.5__py3-none-any.whl → 1.56.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {data_science_document_ai-1.42.5.dist-info → data_science_document_ai-1.56.1.dist-info}/METADATA +2 -2
  2. {data_science_document_ai-1.42.5.dist-info → data_science_document_ai-1.56.1.dist-info}/RECORD +34 -31
  3. src/constants.py +7 -10
  4. src/docai_processor_config.yaml +0 -56
  5. src/excel_processing.py +24 -14
  6. src/io.py +23 -0
  7. src/llm.py +0 -29
  8. src/pdf_processing.py +156 -51
  9. src/postprocessing/common.py +172 -28
  10. src/postprocessing/postprocess_partner_invoice.py +194 -59
  11. src/prompts/library/arrivalNotice/other/placeholders.json +70 -0
  12. src/prompts/library/arrivalNotice/other/prompt.txt +40 -0
  13. src/prompts/library/bundeskasse/other/placeholders.json +5 -5
  14. src/prompts/library/bundeskasse/other/prompt.txt +7 -5
  15. src/prompts/library/commercialInvoice/other/placeholders.json +125 -0
  16. src/prompts/library/commercialInvoice/other/prompt.txt +1 -1
  17. src/prompts/library/customsAssessment/other/placeholders.json +70 -0
  18. src/prompts/library/customsAssessment/other/prompt.txt +24 -37
  19. src/prompts/library/customsInvoice/other/prompt.txt +4 -3
  20. src/prompts/library/deliveryOrder/other/placeholders.json +80 -27
  21. src/prompts/library/deliveryOrder/other/prompt.txt +26 -40
  22. src/prompts/library/draftMbl/other/placeholders.json +33 -33
  23. src/prompts/library/draftMbl/other/prompt.txt +34 -44
  24. src/prompts/library/finalMbL/other/placeholders.json +80 -0
  25. src/prompts/library/finalMbL/other/prompt.txt +34 -44
  26. src/prompts/library/packingList/other/placeholders.json +98 -0
  27. src/prompts/library/partnerInvoice/other/prompt.txt +8 -7
  28. src/prompts/library/preprocessing/carrier/placeholders.json +0 -16
  29. src/prompts/library/shippingInstruction/other/placeholders.json +115 -0
  30. src/prompts/library/shippingInstruction/other/prompt.txt +26 -14
  31. src/prompts/prompt_library.py +0 -4
  32. src/setup.py +15 -16
  33. src/utils.py +120 -68
  34. src/prompts/library/draftMbl/hapag-lloyd/prompt.txt +0 -45
  35. src/prompts/library/draftMbl/maersk/prompt.txt +0 -19
  36. src/prompts/library/finalMbL/hapag-lloyd/prompt.txt +0 -44
  37. src/prompts/library/finalMbL/maersk/prompt.txt +0 -19
  38. {data_science_document_ai-1.42.5.dist-info → data_science_document_ai-1.56.1.dist-info}/WHEEL +0 -0
src/pdf_processing.py CHANGED
@@ -9,6 +9,7 @@ logger = logging.getLogger(__name__)
 import asyncio
 from collections import defaultdict

+from ddtrace import tracer
 from fastapi import HTTPException
 from google.cloud.documentai_v1 import Document as docaiv1_document

@@ -31,9 +32,10 @@ from src.postprocessing.postprocess_partner_invoice import (
 from src.prompts.prompt_library import prompt_library
 from src.utils import (
     extract_top_pages,
-    generate_schema_structure,
+    get_pdf_page_count,
     get_processor_name,
     run_background_tasks,
+    split_pdf_into_chunks,
     transform_schema_strings,
     validate_based_on_schema,
 )
@@ -193,38 +195,32 @@ async def process_file_w_llm(params, file_content, input_doc_type, llm_client):
         result (dict): The structured data extracted from the document, formatted as JSON.
     """
     # Bundeskasse invoices contains all the required information in the first 3 pages.
-    file_content = (
-        extract_top_pages(file_content, num_pages=5)
-        if input_doc_type == "bundeskasse"
-        else file_content
-    )
+    if input_doc_type == "bundeskasse":
+        file_content = extract_top_pages(file_content, num_pages=5)

-    # convert file_content to required document
-    document = llm_client.prepare_document_for_gemini(file_content)
+    number_of_pages = get_pdf_page_count(file_content)
+    logger.info(f"processing {input_doc_type} with {number_of_pages} pages...")

-    # get the schema placeholder from the Doc AI and generate the response structure
-    response_schema = (
-        prompt_library.library[input_doc_type]["other"]["placeholders"]
-        if input_doc_type in ["partnerInvoice", "customsInvoice", "bundeskasse"]
-        else generate_schema_structure(params, input_doc_type)
-    )
+    # get the schema placeholder
+    response_schema = prompt_library.library[input_doc_type]["other"]["placeholders"]

     carrier = "other"
-    if (
-        "preprocessing" in prompt_library.library.keys()
-        and "carrier" in prompt_library.library["preprocessing"].keys()
-        and input_doc_type
-        in prompt_library.library["preprocessing"]["carrier"]["placeholders"].keys()
-    ):
-        carrier_schema = prompt_library.library["preprocessing"]["carrier"][
-            "placeholders"
-        ][input_doc_type]
+    carrier_schema = (
+        prompt_library.library.get("preprocessing", {})
+        .get("carrier", {})
+        .get("placeholders", {})
+        .get(input_doc_type)
+    )

+    if carrier_schema:
         carrier_prompt = prompt_library.library["preprocessing"]["carrier"]["prompt"]
         carrier_prompt = carrier_prompt.replace(
             "DOCUMENT_TYPE_PLACEHOLDER", input_doc_type
         )

+        # convert file_content to required document
+        document = llm_client.prepare_document_for_gemini(file_content)
+
         # identify carrier for customized prompting
         carrier = await identify_carrier(
             document,
@@ -234,30 +230,119 @@ async def process_file_w_llm(params, file_content, input_doc_type, llm_client):
             doc_type=input_doc_type,
         )

-    if input_doc_type == "bookingConfirmation":
-        response_schema = prompt_library.library[input_doc_type][carrier][
-            "placeholders"
-        ]
-
+    # Select prompt
     if (
-        input_doc_type in prompt_library.library.keys()
-        and carrier in prompt_library.library[input_doc_type].keys()
+        input_doc_type not in prompt_library.library
+        or carrier not in prompt_library.library[input_doc_type]
     ):
-        # get the related prompt from predefined prompt library
-        prompt = prompt_library.library[input_doc_type][carrier]["prompt"]
-
-        # generate the result with LLM (gemini)
-        result = await llm_client.get_unified_json_genai(
-            prompt=prompt,
-            document=document,
-            response_schema=response_schema,
-            doc_type=input_doc_type,
+        return {}
+
+    # get the related prompt from predefined prompt library
+    prompt = prompt_library.library[input_doc_type][carrier]["prompt"]
+
+    # Add page-number extraction for moderately large docs
+    use_chunking = number_of_pages >= params["chunk_after"]
+
+    # Update schema and prompt to extract value-page_number pairs
+    if not use_chunking and number_of_pages > 1:
+        response_schema = transform_schema_strings(response_schema)
+        prompt += "\nFor each field, provide the page number where the information was found. The page numbering starts from 0."
+
+    tasks = []
+    # Process in chunks if number of pages exceeds threshold and Process all chunks concurrently
+    for chunk in (
+        split_pdf_into_chunks(file_content, chunk_size=params["chunk_size"])
+        if use_chunking
+        else [file_content]
+    ):
+        tasks.append(
+            process_chunk_with_retry(
+                chunk,
+                prompt,
+                response_schema,
+                llm_client,
+                input_doc_type,
+            )
         )

-        result = llm_prediction_to_tuples(result)
+    results = await asyncio.gather(*tasks, return_exceptions=True)

-        return result
-    return {}
+    if use_chunking:
+        return merge_llm_results(results, response_schema)
+    else:
+        return llm_prediction_to_tuples(results[0], number_of_pages=number_of_pages)
+
+
+async def process_chunk_with_retry(
+    chunk_content, prompt, response_schema, llm_client, input_doc_type, retries=2
+):
+    """Process a chunk with retries in case of failure."""
+    for attempt in range(1, retries + 1):
+        try:
+            return await process_chunk(
+                chunk_content=chunk_content,
+                prompt=prompt,
+                response_schema=response_schema,
+                llm_client=llm_client,
+                input_doc_type=input_doc_type,
+            )
+        except Exception as e:
+            logger.error(f"Chunk failed on attempt {attempt}: {e}")
+            if attempt == retries:
+                raise
+            await asyncio.sleep(1)  # small backoff
+
+
+async def process_chunk(
+    chunk_content, prompt, response_schema, llm_client, input_doc_type
+):
+    """Process a chunk with Gemini."""
+    document = llm_client.prepare_document_for_gemini(chunk_content)
+    return await llm_client.get_unified_json_genai(
+        prompt=prompt,
+        document=document,
+        response_schema=response_schema,
+        doc_type=input_doc_type,
+    )
+
+
+def merge_llm_results(results, response_schema):
+    """Merge LLM results from multiple chunks."""
+    merged = {}
+    for i, result in enumerate(results):
+        if not isinstance(result, dict):
+            continue
+        # Add page number to all values coming from this chunk
+        result = llm_prediction_to_tuples(result, number_of_pages=1, page_number=i)
+
+        # Merge the result into the final merged dictionary
+        for key, value in result.items():
+            field_type = (
+                response_schema["properties"].get(key, {}).get("type", "").upper()
+            )
+
+            if key not in merged:
+                if field_type == "ARRAY":
+                    # append the values as a list
+                    merged[key] = (
+                        value if isinstance(value, list) else ([value] if value else [])
+                    )
+                else:
+                    merged[key] = value
+                continue
+
+            if field_type == "ARRAY":
+                # append list contents across chunks
+                if isinstance(value, list):
+                    merged[key].extend(value)
+                else:
+                    merged[key].append(value)
+
+            # take first non-null value only
+            if merged[key] in (None, "", [], {}):
+                merged[key] = value
+
+    return merged


 async def extract_data_from_pdf_w_llm(params, input_doc_type, file_content, llm_client):
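Note: merge_llm_results above tags every value from chunk i with page number i (via llm_prediction_to_tuples) and then folds the per-chunk dicts together. A simplified, self-contained sketch of the merge semantics, with hypothetical field names and page tagging shown inline rather than via the package's helpers:

    # Hypothetical outputs from two PDF chunks, already tagged as (value, page).
    schema = {"properties": {"invoiceNumber": {"type": "STRING"},
                             "lineItems": {"type": "ARRAY"}}}
    chunk0 = {"invoiceNumber": ("INV-1", 0), "lineItems": [("item A", 0)]}
    chunk1 = {"invoiceNumber": None, "lineItems": [("item B", 1)]}

    merged = {}
    for result in (chunk0, chunk1):
        for key, value in result.items():
            is_array = schema["properties"][key]["type"] == "ARRAY"
            if key not in merged:
                merged[key] = list(value) if is_array else value
            elif is_array:
                merged[key].extend(value)
            elif merged[key] in (None, "", [], {}):
                merged[key] = value  # scalars keep the first non-null value

    # ARRAY fields concatenate across chunks; scalar fields keep the first non-null hit.
    assert merged == {"invoiceNumber": ("INV-1", 0),
                      "lineItems": [("item A", 0), ("item B", 1)]}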
@@ -334,15 +419,9 @@ async def extract_data_by_doctype(
     processor_client,
     if_use_docai,
     if_use_llm,
+    llm_client,
     isBetaTest=False,
 ):
-    # Select LLM client (Using 2.5 Flash model for Bundeskasse)
-    llm_client = (
-        params["LlmClient_Flash"]
-        if input_doc_type == "bundeskasse"
-        else params["LlmClient"]
-    )
-
     async def extract_w_docai():
         return await extract_data_from_pdf_w_docai(
             params=params,
@@ -391,6 +470,7 @@ async def data_extraction_manual_flow(
     meta,
     processor_client,
     schema_client,
+    use_default_logging=False,
 ):
     """
     Process a PDF file and extract data from it.
@@ -411,6 +491,15 @@
     """
     # Get the start time for processing
     start_time = asyncio.get_event_loop().time()
+
+    # Select LLM client (Using 2.5 Pro model only for PI and customsInvoice)
+    llm_client = (
+        params["LlmClient_Flash"]
+        if meta.documentTypeCode not in ["customsInvoice", "partnerInvoice"]
+        else params["LlmClient"]
+    )
+
+    page_count = None
     # Validate the file type
     if mime_type == "application/pdf":
         # Enable Doc Ai only for certain document types.
@@ -432,8 +521,10 @@
             processor_client,
             if_use_docai=if_use_docai,
             if_use_llm=if_use_llm,
+            llm_client=llm_client,
             isBetaTest=False,
         )
+        page_count = get_pdf_page_count(file_content)

     elif "excel" in mime_type or "spreadsheet" in mime_type:
         # Extract data from the Excel file
@@ -442,8 +533,19 @@
             input_doc_type=meta.documentTypeCode,
             file_content=file_content,
             mime_type=mime_type,
+            llm_client=llm_client,
         )

+        # Get sheet count from dd-trace span (set in extract_data_from_excel)
+        # Note: we use the span metric instead of len(extracted_data) because
+        # some sheets may fail extraction and not appear in extracted_data
+        span = tracer.current_span()
+        page_count = span.get_metric("est_page_count") if span else len(extracted_data)
+        if page_count > 100:
+            logger.warning(
+                f"Check logic. Count of sheets in excel file is weirdly large: {page_count}"
+            )
+
     else:
         raise HTTPException(
             status_code=400,
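Note: the est_page_count metric read above is assumed to be set inside extract_data_from_excel, which is not part of this diff. A hedged sketch of what that producer side could look like with the ddtrace API (the function name and argument are hypothetical):

    from ddtrace import tracer

    def record_sheet_count(sheet_names):
        # Attach the sheet count to the active APM span, if any, so the caller
        # can read it back later with span.get_metric("est_page_count").
        span = tracer.current_span()
        if span:
            span.set_metric("est_page_count", len(sheet_names))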
@@ -451,7 +553,7 @@
         )
     # Create the result dictionary with the extracted data
     extracted_data = await format_all_entities(
-        extracted_data, meta.documentTypeCode, params
+        extracted_data, meta.documentTypeCode, params, mime_type
     )
     result = {
         "id": meta.id,
@@ -466,7 +568,9 @@
     logger.info(f"Time taken to process the document: {round(elapsed_time, 4)} seconds")

     # Schedule background tasks without using FastAPI's BackgroundTasks
-    if os.getenv("CLUSTER") != "ode":  # skip data export to bigquery in ODE environment
+    if (
+        os.getenv("CLUSTER") != "ode"
+    ) & use_default_logging:  # skip data export to bigquery in ODE environment
         asyncio.create_task(
             run_background_tasks(
                 params,
@@ -477,6 +581,7 @@
                 processor_version,
                 mime_type,
                 elapsed_time,
+                page_count,
             )
         )
     return result
src/postprocessing/common.py CHANGED
@@ -12,7 +12,7 @@ from src.constants import formatting_rules
 from src.io import logger
 from src.postprocessing.postprocess_partner_invoice import process_partner_invoice
 from src.prompts.prompt_library import prompt_library
-from src.utils import get_tms_mappings
+from src.utils import batch_fetch_all_mappings, get_tms_mappings

 tms_domain = os.environ["TMS_DOMAIN"]

@@ -134,8 +134,11 @@ def extract_number(data_field_value):
         formatted_value: string

     """
+    # Remove container size pattern like 20FT, 40HC, etc from 1 x 40HC
+    value = remove_unwanted_patterns(data_field_value)
+
     formatted_value = ""
-    for c in data_field_value:
+    for c in value:
         if c.isnumeric() or c in [",", ".", "-"]:
             formatted_value += c

@@ -319,6 +322,14 @@ def remove_unwanted_patterns(lineitem: str):
     # Remove "HIGH CUBE"
     lineitem = lineitem.replace("HIGH CUBE", "")

+    # Remove container size e.g., 20FT, 40HC, etc.
+    pattern = [
+        f"{s}{t}"
+        for s in ("20|22|40|45".split("|"))
+        for t in ("FT|HC|DC|HD|GP|OT|RF|FR|TK|DV".split("|"))
+    ]
+    lineitem = re.sub(r"|".join(pattern), "", lineitem, flags=re.IGNORECASE).strip()
+
     return lineitem

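Note: a quick standalone check of the container-size pattern added above (same construction, run outside the package). The alternation is unanchored, so a match is removed wherever it appears, even inside a longer token:

    import re

    sizes = "20|22|40|45".split("|")
    types = "FT|HC|DC|HD|GP|OT|RF|FR|TK|DV".split("|")
    pattern = [f"{s}{t}" for s in sizes for t in types]  # "20FT", "20HC", ..., "45DV"

    print(re.sub("|".join(pattern), "", "1 x 40HC", flags=re.IGNORECASE).strip())
    # -> "1 x"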
@@ -349,42 +360,75 @@ def clean_item_description(lineitem: str, remove_numbers: bool = True):
     # Remove the currency codes
     lineitem = re.sub(currency_codes_pattern, "", lineitem, flags=re.IGNORECASE)

+    # remove other patterns
+    lineitem = remove_unwanted_patterns(lineitem)
+
     # Remove numbers from the line item
     if (
         remove_numbers
     ):  # Do not remove numbers for the reverse charge sentence as it contains Article number
         lineitem = re.sub(r"\d+", "", lineitem)

-    # remove other patterns
-    lineitem = remove_unwanted_patterns(lineitem)
-
     # remove special chars
     lineitem = re.sub(r"[^A-Za-z0-9\s]", " ", lineitem).strip()

+    # Remove x from lineitem like 10 x
+    lineitem = re.sub(r"\b[xX]\b", " ", lineitem).strip()
+
     return re.sub(r"\s{2,}", " ", lineitem).strip()


-async def format_label(entity_k, entity_value, document_type_code, params):
+async def format_label(
+    entity_k,
+    entity_value,
+    document_type_code,
+    params,
+    mime_type,
+    container_map,
+    terminal_map,
+    depot_map,
+):
     llm_client = params["LlmClient"]
     if isinstance(entity_value, dict):  # if it's a nested entity
         format_tasks = [
-            format_label(sub_k, sub_v, document_type_code, params)
+            format_label(
+                sub_k,
+                sub_v,
+                document_type_code,
+                params,
+                mime_type,
+                container_map,
+                terminal_map,
+                depot_map,
+            )
             for sub_k, sub_v in entity_value.items()
         ]
         return entity_k, {k: v for k, v in await asyncio.gather(*format_tasks)}
     if isinstance(entity_value, list):
         format_tasks = await asyncio.gather(
             *[
-                format_label(entity_k, sub_v, document_type_code, params)
+                format_label(
+                    entity_k,
+                    sub_v,
+                    document_type_code,
+                    params,
+                    mime_type,
+                    container_map,
+                    terminal_map,
+                    depot_map,
+                )
                 for sub_v in entity_value
             ]
         )
         return entity_k, [v for _, v in format_tasks]
-    if isinstance(entity_value, tuple):
-        page = entity_value[1]
-        entity_value = entity_value[0]
-    else:
-        page = -1
+
+    if mime_type == "application/pdf":
+        if isinstance(entity_value, tuple):
+            page = entity_value[1]
+            entity_value = entity_value[0]
+        else:
+            page = -1
+
     entity_key = entity_k.lower()
     formatted_value = None

@@ -394,13 +438,13 @@ async def format_label(entity_k, entity_value, document_type_code, params):
         )

     elif (entity_key == "containertype") or (entity_key == "containersize"):
-        formatted_value = get_tms_mappings(entity_value, "container_types")
+        formatted_value = container_map.get(entity_value)

     elif check_formatting_rule(entity_k, document_type_code, "terminal"):
-        formatted_value = get_tms_mappings(entity_value, "terminals")
+        formatted_value = terminal_map.get(entity_value)

     elif check_formatting_rule(entity_k, document_type_code, "depot"):
-        formatted_value = get_tms_mappings(entity_value, "depots")
+        formatted_value = depot_map.get(entity_value)

     elif entity_key.startswith(("eta", "etd", "duedate", "issuedate", "servicedate")):
         try:
@@ -421,7 +465,10 @@ async def format_label(entity_k, entity_value, document_type_code, params):
         except ValueError as e:
             logger.info(f"ParserError: {e}")

-    elif entity_key in ["invoicenumber", "creditnoteinvoicenumber"]:
+    elif (
+        entity_key in ["invoicenumber", "creditnoteinvoicenumber"]
+        and document_type_code == "bundeskasse"
+    ):
         formatted_value = clean_invoice_number(entity_value)

     elif entity_key in ("shipmentid", "partnerreference"):
@@ -482,8 +529,10 @@ async def format_label(entity_k, entity_value, document_type_code, params):
     result = {
         "documentValue": entity_value,
         "formattedValue": formatted_value,
-        "page": page,
     }
+    if mime_type == "application/pdf":
+        result["page"] = page
+
     return entity_k, result

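Note: after this change the per-entity payload only carries a "page" entry for PDF inputs; for spreadsheets the key is omitted entirely. Hypothetical shapes (illustrative values, not taken from the package):

    # PDF input (page provenance available):
    pdf_result = {"documentValue": "INV-1", "formattedValue": "INV-1", "page": 2}

    # Excel input (no page concept, key omitted):
    excel_result = {"documentValue": "INV-1", "formattedValue": "INV-1"}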
@@ -491,7 +540,8 @@ async def get_port_code_ai(port: str, llm_client, doc_type=None):
     """Get port code using AI model."""
     port_llm = await get_port_code_llm(port, llm_client, doc_type=doc_type)

-    return get_tms_mappings(port, "ports", port_llm)
+    result = await get_tms_mappings(port, "ports", port_llm)
+    return result.get(port, None)


 async def get_port_code_llm(port: str, llm_client, doc_type=None):
@@ -582,7 +632,75 @@ def decimal_convertor(value, quantity=False):
     return value


-async def format_all_entities(result, document_type_code, params):
+async def collect_mapping_requests(entity_value, document_type_code):
+    """Collect all unique container types, terminals, and depots from the entity value."""
+    # Sets to store unique values
+    container_types = set()
+    terminals = set()
+    depots = set()
+
+    def walk(key, value):
+        key_lower = key.lower()
+
+        # nested dict
+        if isinstance(value, dict):
+            for k, v in value.items():
+                walk(k, v)
+
+        # list of values
+        elif isinstance(value, list):
+            for item in value:
+                walk(key, item)
+
+        # leaf node
+        else:
+            if key_lower in ("containertype", "containersize"):
+                # Take only "20DV" from ('20DV', 0) if it's a tuple
+                container_types.add(value[0]) if isinstance(
+                    value, tuple
+                ) else container_types.add(value)
+
+            elif check_formatting_rule(key, document_type_code, "terminal"):
+                terminals.add(value[0]) if isinstance(value, tuple) else terminals.add(
+                    value
+                )
+
+            elif check_formatting_rule(key, document_type_code, "depot"):
+                depots.add(value[0]) if isinstance(value, tuple) else depots.add(value)
+
+    walk("root", entity_value)
+
+    return container_types, terminals, depots
+
+
+async def format_all_labels(entity_data, document_type_code, params, mime_type):
+    """Format all labels in the entity data using cached mappings."""
+    # Collect all mapping values needed
+    container_req, terminal_req, depot_req = await collect_mapping_requests(
+        entity_data, document_type_code
+    )
+
+    # Batch fetch mappings
+    container_map, terminal_map, depot_map = await batch_fetch_all_mappings(
+        container_req, terminal_req, depot_req
+    )
+
+    # Format labels using cached mappings
+    _, result = await format_label(
+        "root",
+        entity_data,
+        document_type_code,
+        params,
+        mime_type,
+        container_map,
+        terminal_map,
+        depot_map,
+    )
+
+    return _, result
+
+
+async def format_all_entities(result, document_type_code, params, mime_type):
     """Format the entity values in the result dictionary."""
     # Since we treat `customsInvoice` same as `partnerInvoice`
     document_type_code = (
@@ -597,11 +715,13 @@ async def format_all_entities(result, document_type_code, params):
         return {}

     # Format all entities recursively
-    _, aggregated_data = await format_label(None, result, document_type_code, params)
+    _, aggregated_data = await format_all_labels(
+        result, document_type_code, params, mime_type
+    )

     # Process partner invoice on lineitem mapping and reverse charge sentence
     if document_type_code in ["partnerInvoice", "bundeskasse"]:
-        process_partner_invoice(params, aggregated_data, document_type_code)
+        await process_partner_invoice(params, aggregated_data, document_type_code)

     logger.info("Data Extraction completed successfully")
     return aggregated_data
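Note: collect_mapping_requests and format_all_labels above replace per-leaf TMS lookups with one batched fetch: a first pass walks the payload collecting unique values, a single call resolves them, and format_label then does plain dict lookups. A simplified, self-contained sketch of that two-phase pattern (fake_batch_fetch is a hypothetical stand-in for batch_fetch_all_mappings):

    import asyncio

    async def fake_batch_fetch(values):
        # pretend this is a single network round-trip for all values
        return {v: f"code:{v}" for v in values}

    async def main():
        entities = {"containerType": "40HC", "items": [{"containerSize": "20DV"}]}

        # Phase 1: walk the payload, collecting unique lookup keys
        wanted = set()

        def walk(value):
            if isinstance(value, dict):
                for k, v in value.items():
                    if k.lower() in ("containertype", "containersize"):
                        wanted.add(v)
                    else:
                        walk(v)
            elif isinstance(value, list):
                for item in value:
                    walk(item)

        walk(entities)

        # Phase 2: one batched fetch instead of one call per leaf value
        mapping = await fake_batch_fetch(wanted)
        assert mapping == {"40HC": "code:40HC", "20DV": "code:20DV"}

    asyncio.run(main())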
@@ -633,22 +753,46 @@ def remove_stop_words(lineitem: str):
     )


-def llm_prediction_to_tuples(llm_prediction):
+def llm_prediction_to_tuples(llm_prediction, number_of_pages=-1, page_number=None):
     """Convert LLM prediction dictionary to tuples of (value, page_number)."""
+    # If only 1 page, simply pair each value with page number 0
+    if number_of_pages == 1:
+        effective_page = 0 if page_number is None else page_number
+        if isinstance(llm_prediction, dict):
+            return {
+                k: llm_prediction_to_tuples(
+                    v, number_of_pages, page_number=effective_page
+                )
+                for k, v in llm_prediction.items()
+            }
+        elif isinstance(llm_prediction, list):
+            return [
+                llm_prediction_to_tuples(v, number_of_pages, page_number=effective_page)
+                for v in llm_prediction
+            ]
+        else:
+            return (llm_prediction, effective_page) if llm_prediction else None
+
+    # logic for multi-page predictions
     if isinstance(llm_prediction, dict):
         if "page_number" in llm_prediction.keys() and "value" in llm_prediction.keys():
             if llm_prediction["value"]:
                 try:
-                    page_number = int(llm_prediction["page_number"])
+                    _page_number = int(llm_prediction["page_number"])
                 except:  # noqa: E722
-                    page_number = -1
-                return (llm_prediction["value"], page_number)
+                    _page_number = -1
+                return (llm_prediction["value"], _page_number)
             return None
+
         for key, value in llm_prediction.items():
             llm_prediction[key] = llm_prediction_to_tuples(
-                llm_prediction.get(key, value)
+                llm_prediction.get(key, value), number_of_pages, page_number
             )
+
     elif isinstance(llm_prediction, list):
         for i, item in enumerate(llm_prediction):
-            llm_prediction[i] = llm_prediction_to_tuples(item)
+            llm_prediction[i] = llm_prediction_to_tuples(
+                item, number_of_pages, page_number
+            )
+
     return llm_prediction
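Note: a restatement of the two modes of llm_prediction_to_tuples after this change, with hypothetical field names (plain illustrative assignments, not imports from the package):

    # Multi-page mode (default, number_of_pages=-1): {"value": ..., "page_number": ...}
    # dicts produced by the transformed schema collapse into (value, page) tuples.
    multi_page_in = {"invoiceNumber": {"value": "INV-1", "page_number": "2"}}
    multi_page_out = {"invoiceNumber": ("INV-1", 2)}

    # Single-page / chunk mode (number_of_pages=1, page_number=i): every leaf is
    # paired with the chunk index, which is how merge_llm_results tags provenance.
    chunk_mode_in = {"invoiceNumber": "INV-1"}        # chunk 0
    chunk_mode_out = {"invoiceNumber": ("INV-1", 0)}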