data-science-document-ai 1.43.7__py3-none-any.whl → 1.45.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: data-science-document-ai
3
- Version: 1.43.7
3
+ Version: 1.45.0
4
4
  Summary: "Document AI repo for data science"
5
5
  Author: Naomi Nguyen
6
6
  Author-email: naomi.nguyen@forto.com
@@ -1,16 +1,16 @@
1
- src/constants.py,sha256=rpYIecVLIBLh98YrJ8e5gdvM0bqrXJZWIKgFkUSn69g,3513
1
+ src/constants.py,sha256=HKHP9MqkLrC6pHgOt0XX2F8j6kbupXJ4HscClDwMBaM,3656
2
2
  src/constants_sandbox.py,sha256=Iu6HdjCoNSmOX0AwoL9qUQkhq_ZnIN5U9e-Q2UfNuGc,547
3
3
  src/docai.py,sha256=dHuR0ehVjUi1CnoNvdp_yxJtpU_HFXqAZ61ywdz7BEo,5655
4
4
  src/docai_processor_config.yaml,sha256=81NUGs-u8UFJm6mc0ZOeeNQlhe9h0f35GhjTcwErvTA,1717
5
- src/excel_processing.py,sha256=PdypkXHf-hln5cq5TyJ_IVybZk-rJF1NKZ50KXuOSdY,3390
6
- src/io.py,sha256=tOJpMyI-mP1AaXKG4UFudH47MHWzjWBgVahFJUcjGfs,4749
5
+ src/excel_processing.py,sha256=_vP2q1xEIeyjO8TvZlSTeEM-M1PMceyDSuYGfyZeceY,3361
6
+ src/io.py,sha256=rYjXVLlriEacw1uNuPIYhg12bXNu48Qs9GYMY2YcVTE,5563
7
7
  src/llm.py,sha256=OE4IEIqcM-hYK9U7e0x1rAfcqdpeo4iXPHBp64L5Qz0,8199
8
8
  src/log_setup.py,sha256=RhHnpXqcl-ii4EJzRt47CF2R-Q3YPF68tepg_Kg7tkw,2895
9
- src/pdf_processing.py,sha256=DaFM8ioERj7YeC8Yjki_dfSnKt0lf7DB14ks9i4OAfA,17741
10
- src/postprocessing/common.py,sha256=fU3ECfnR0rpF21DnVYM2YM7kPEB4gRJuMasyrNupsaA,23026
9
+ src/pdf_processing.py,sha256=lzvoza9itpEyl-rcBQbIcWuFxUAvF_Qyc-OpuPQWWMk,20354
10
+ src/postprocessing/common.py,sha256=ao9_hnBXgLv4HOyj_6I00CSDGRiwG8IP_HPg_1Yjzmw,25883
11
11
  src/postprocessing/postprocess_booking_confirmation.py,sha256=nK32eDiBNbauyQz0oCa9eraysku8aqzrcoRFoWVumDU,4827
12
12
  src/postprocessing/postprocess_commercial_invoice.py,sha256=3I8ijluTZcOs_sMnFZxfkAPle0UFQ239EMuvZfDZVPg,1028
13
- src/postprocessing/postprocess_partner_invoice.py,sha256=koGR7dN37FqJcepdzkrzNBHuBBUuCp_3CrteScASqyE,10590
13
+ src/postprocessing/postprocess_partner_invoice.py,sha256=LZcMZfJeLdcbYqPemO8gn9SmJxv-NPmb4uVCT3lKg18,12341
14
14
  src/prompts/library/bookingConfirmation/evergreen/placeholders.json,sha256=IpM9nmSPdyroliZfXB1-NDCjiHZX_Ff5BH7-scNhGqE,1406
15
15
  src/prompts/library/bookingConfirmation/evergreen/prompt.txt,sha256=5ivskCG831M2scW3oqQaoltXIyHV-n6DYUygWycXxjw,2755
16
16
  src/prompts/library/bookingConfirmation/hapag-lloyd/placeholders.json,sha256=hMPNt9s3LuxR85AxYy7bPcCDleug6gSwVjefm3ismWY,1405
@@ -53,7 +53,7 @@ src/prompts/library/shippingInstruction/other/prompt.txt,sha256=dT2e-dPuvuz0rVYp
53
53
  src/prompts/prompt_library.py,sha256=VJWHeXN-s501C2GiidIIvQQuZdU6T1R27hE2dKBiI40,2555
54
54
  src/setup.py,sha256=M-p5c8M9ejKcSZ9N86VtmtPc4TYLxe1_4_dxf6jpfVc,7262
55
55
  src/tms.py,sha256=UXbIo1QE--hIX6NZi5Qyp2R_CP338syrY9pCTPrfgnE,1741
56
- src/utils.py,sha256=iUFjfIKXl_MwkPXPMfK0ZAB9aZ__N6e8mWTBbBiPki4,16568
57
- data_science_document_ai-1.43.7.dist-info/METADATA,sha256=lajB-JuTBbL2uMTIlvdZ3rJiw5n9BFzTcXnIEYfgIj4,2152
58
- data_science_document_ai-1.43.7.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
59
- data_science_document_ai-1.43.7.dist-info/RECORD,,
56
+ src/utils.py,sha256=Ow5_Jals88o8mbZ1BoHfZpHZoCfig_UQb5aalH-mpWE,17278
57
+ data_science_document_ai-1.45.0.dist-info/METADATA,sha256=VblAnSZ_nlqjlEJtl0-ETS6tuELw9pThEKwxAxXomjA,2152
58
+ data_science_document_ai-1.45.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
59
+ data_science_document_ai-1.45.0.dist-info/RECORD,,
src/constants.py CHANGED
@@ -26,6 +26,9 @@ project_parameters = {
26
26
  "fuzzy_threshold_item_code": 70,
27
27
  "fuzzy_threshold_reverse_charge": 80,
28
28
  "fuzzy_threshold_invoice_classification": 70,
29
+ # Chunking params
30
+ "chunk_size": 1, # page (do not change this without changing the page number logic)
31
+ "chunk_after": 10, # pages
29
32
  # Big Query
30
33
  "g_ai_gbq_db_schema": "document_ai",
31
34
  "g_ai_gbq_db_table_out": "document_ai_api_calls_v1",
src/excel_processing.py CHANGED
@@ -4,8 +4,6 @@ import logging
4
4
 
5
5
  from ddtrace import tracer
6
6
 
7
- from src.postprocessing.common import llm_prediction_to_tuples
8
-
9
7
  logger = logging.getLogger(__name__)
10
8
 
11
9
  import asyncio
@@ -78,6 +76,7 @@ async def extract_data_from_excel(
78
76
  "bundeskasse",
79
77
  "commercialInvoice",
80
78
  "packingList",
79
+ "bookingConfirmation",
81
80
  ]
82
81
  else generate_schema_structure(params, input_doc_type)
83
82
  )
src/io.py CHANGED
@@ -156,4 +156,27 @@ def download_dir_from_bucket(bucket, directory_cloud, directory_local) -> bool:
156
156
  return result
157
157
 
158
158
 
159
+ def bq_logs(data_to_insert, params):
160
+ """Insert logs into Google BigQuery.
161
+
162
+ Args:
163
+ data_to_insert (list): The data to insert into BigQuery.
164
+ params (dict): The parameters dictionary.
165
+ """
166
+ # Use the pre-initialized BigQuery client
167
+ bq_client = params["bq_client"]
168
+ # Get the table string
169
+ table_string = f"{params['g_ai_project_name']}.{params['g_ai_gbq_db_schema']}.{params['g_ai_gbq_db_table_out']}"
170
+
171
+ logger.info(f"Log table: {table_string}")
172
+ # Insert the rows into the table
173
+ insert_logs = bq_client.insert_rows_json(table_string, data_to_insert)
174
+
175
+ # Check if there were any errors inserting the rows
176
+ if not insert_logs:
177
+ logger.info("New rows have been added.")
178
+ else:
179
+ logger.info("Errors occurred while inserting rows: ", insert_logs)
180
+
181
+
159
182
  # type: ignore
src/pdf_processing.py CHANGED
@@ -36,6 +36,7 @@ from src.utils import (
36
36
  get_pdf_page_count,
37
37
  get_processor_name,
38
38
  run_background_tasks,
39
+ split_pdf_into_chunks,
39
40
  transform_schema_strings,
40
41
  validate_based_on_schema,
41
42
  )
@@ -195,15 +196,11 @@ async def process_file_w_llm(params, file_content, input_doc_type, llm_client):
195
196
  result (dict): The structured data extracted from the document, formatted as JSON.
196
197
  """
197
198
  # Bundeskasse invoices contains all the required information in the first 3 pages.
198
- file_content = (
199
- extract_top_pages(file_content, num_pages=5)
200
- if input_doc_type == "bundeskasse"
201
- else file_content
202
- )
203
- number_of_pages = get_pdf_page_count(file_content)
199
+ if input_doc_type == "bundeskasse":
200
+ file_content = extract_top_pages(file_content, num_pages=5)
204
201
 
205
- # convert file_content to required document
206
- document = llm_client.prepare_document_for_gemini(file_content)
202
+ number_of_pages = get_pdf_page_count(file_content)
203
+ logger.info(f"processing {input_doc_type} with {number_of_pages} pages...")
207
204
 
208
205
  # get the schema placeholder from the Doc AI and generate the response structure
209
206
  response_schema = (
@@ -215,26 +212,28 @@ async def process_file_w_llm(params, file_content, input_doc_type, llm_client):
215
212
  "bundeskasse",
216
213
  "commercialInvoice",
217
214
  "packingList",
215
+ "bookingConfirmation",
218
216
  ]
219
217
  else generate_schema_structure(params, input_doc_type)
220
218
  )
221
219
 
222
220
  carrier = "other"
223
- if (
224
- "preprocessing" in prompt_library.library.keys()
225
- and "carrier" in prompt_library.library["preprocessing"].keys()
226
- and input_doc_type
227
- in prompt_library.library["preprocessing"]["carrier"]["placeholders"].keys()
228
- ):
229
- carrier_schema = prompt_library.library["preprocessing"]["carrier"][
230
- "placeholders"
231
- ][input_doc_type]
221
+ carrier_schema = (
222
+ prompt_library.library.get("preprocessing", {})
223
+ .get("carrier", {})
224
+ .get("placeholders", {})
225
+ .get(input_doc_type)
226
+ )
232
227
 
228
+ if carrier_schema:
233
229
  carrier_prompt = prompt_library.library["preprocessing"]["carrier"]["prompt"]
234
230
  carrier_prompt = carrier_prompt.replace(
235
231
  "DOCUMENT_TYPE_PLACEHOLDER", input_doc_type
236
232
  )
237
233
 
234
+ # convert file_content to required document
235
+ document = llm_client.prepare_document_for_gemini(file_content)
236
+
238
237
  # identify carrier for customized prompting
239
238
  carrier = await identify_carrier(
240
239
  document,
@@ -244,37 +243,115 @@ async def process_file_w_llm(params, file_content, input_doc_type, llm_client):
244
243
  doc_type=input_doc_type,
245
244
  )
246
245
 
247
- if input_doc_type == "bookingConfirmation":
248
- response_schema = prompt_library.library[input_doc_type][carrier][
249
- "placeholders"
250
- ]
251
-
246
+ # Select prompt
252
247
  if (
253
- input_doc_type in prompt_library.library.keys()
254
- and carrier in prompt_library.library[input_doc_type].keys()
248
+ input_doc_type not in prompt_library.library
249
+ or carrier not in prompt_library.library[input_doc_type]
255
250
  ):
256
- # get the related prompt from predefined prompt library
257
- prompt = prompt_library.library[input_doc_type][carrier]["prompt"]
251
+ return {}
258
252
 
259
- # Update schema to extract value-page_number pairs
260
- if number_of_pages > 1:
261
- response_schema = transform_schema_strings(response_schema)
253
+ # get the related prompt from predefined prompt library
254
+ prompt = prompt_library.library[input_doc_type][carrier]["prompt"]
262
255
 
263
- # Update the prompt to instruct LLM to include page numbers
264
- prompt += "\nFor each field, provide the page number where the information was found. The page numbering starts from 0."
256
+ # Add page-number extraction for moderately large docs
257
+ use_chunking = number_of_pages >= params["chunk_after"]
265
258
 
266
- # generate the result with LLM (gemini)
267
- result = await llm_client.get_unified_json_genai(
268
- prompt=prompt,
269
- document=document,
270
- response_schema=response_schema,
271
- doc_type=input_doc_type,
259
+ # Update schema and prompt to extract value-page_number pairs
260
+ if not use_chunking and number_of_pages > 1:
261
+ response_schema = transform_schema_strings(response_schema)
262
+ prompt += "\nFor each field, provide the page number where the information was found. The page numbering starts from 0."
263
+
264
+ tasks = []
265
+ # Process in chunks if number of pages exceeds threshold and Process all chunks concurrently
266
+ for chunk in (
267
+ split_pdf_into_chunks(file_content, chunk_size=params["chunk_size"])
268
+ if use_chunking
269
+ else [file_content]
270
+ ):
271
+ tasks.append(
272
+ process_chunk_with_retry(
273
+ chunk, prompt, response_schema, llm_client, input_doc_type
274
+ )
272
275
  )
273
276
 
274
- result = llm_prediction_to_tuples(result, number_of_pages)
277
+ results = await asyncio.gather(*tasks, return_exceptions=True)
275
278
 
276
- return result
277
- return {}
279
+ if use_chunking:
280
+ return merge_llm_results(results, response_schema)
281
+ else:
282
+ return llm_prediction_to_tuples(results[0], number_of_pages=number_of_pages)
283
+
284
+
285
+ async def process_chunk_with_retry(
286
+ chunk_content, prompt, response_schema, llm_client, input_doc_type, retries=2
287
+ ):
288
+ """Process a chunk with retries in case of failure."""
289
+ for attempt in range(1, retries + 1):
290
+ try:
291
+ return await process_chunk(
292
+ chunk_content=chunk_content,
293
+ prompt=prompt,
294
+ response_schema=response_schema,
295
+ llm_client=llm_client,
296
+ input_doc_type=input_doc_type,
297
+ )
298
+ except Exception as e:
299
+ logger.error(f"Chunk failed on attempt {attempt}: {e}")
300
+ if attempt == retries:
301
+ raise
302
+ await asyncio.sleep(1) # small backoff
303
+
304
+
305
+ async def process_chunk(
306
+ chunk_content, prompt, response_schema, llm_client, input_doc_type
307
+ ):
308
+ """Process a chunk with Gemini."""
309
+ document = llm_client.prepare_document_for_gemini(chunk_content)
310
+ return await llm_client.get_unified_json_genai(
311
+ prompt=prompt,
312
+ document=document,
313
+ response_schema=response_schema,
314
+ doc_type=input_doc_type,
315
+ )
316
+
317
+
318
+ def merge_llm_results(results, response_schema):
319
+ """Merge LLM results from multiple chunks."""
320
+ merged = {}
321
+ for i, result in enumerate(results):
322
+ if not isinstance(result, dict):
323
+ continue
324
+ # Add page number to all values coming from this chunk
325
+ result = llm_prediction_to_tuples(result, number_of_pages=1, page_number=i)
326
+
327
+ # Merge the result into the final merged dictionary
328
+ for key, value in result.items():
329
+ field_type = (
330
+ response_schema["properties"].get(key, {}).get("type", "").upper()
331
+ )
332
+
333
+ if key not in merged:
334
+ if field_type == "ARRAY":
335
+ # append the values as a list
336
+ merged[key] = (
337
+ value if isinstance(value, list) else ([value] if value else [])
338
+ )
339
+ else:
340
+ merged[key] = value
341
+ continue
342
+
343
+ if field_type == "ARRAY":
344
+ # append list contents across chunks
345
+ if isinstance(value, list):
346
+ merged[key].extend(value)
347
+ else:
348
+ merged[key].append(value)
349
+
350
+ # take first non-null value only
351
+ if merged[key] in (None, "", [], {}):
352
+ merged[key] = value
353
+
354
+ return merged
278
355
 
279
356
 
280
357
  async def extract_data_from_pdf_w_llm(params, input_doc_type, file_content, llm_client):
@@ -12,7 +12,7 @@ from src.constants import formatting_rules
12
12
  from src.io import logger
13
13
  from src.postprocessing.postprocess_partner_invoice import process_partner_invoice
14
14
  from src.prompts.prompt_library import prompt_library
15
- from src.utils import get_tms_mappings
15
+ from src.utils import batch_fetch_all_mappings, get_tms_mappings
16
16
 
17
17
  tms_domain = os.environ["TMS_DOMAIN"]
18
18
 
@@ -372,18 +372,45 @@ def clean_item_description(lineitem: str, remove_numbers: bool = True):
372
372
  return re.sub(r"\s{2,}", " ", lineitem).strip()
373
373
 
374
374
 
375
- async def format_label(entity_k, entity_value, document_type_code, params, mime_type):
375
+ async def format_label(
376
+ entity_k,
377
+ entity_value,
378
+ document_type_code,
379
+ params,
380
+ mime_type,
381
+ container_map,
382
+ terminal_map,
383
+ depot_map,
384
+ ):
376
385
  llm_client = params["LlmClient"]
377
386
  if isinstance(entity_value, dict): # if it's a nested entity
378
387
  format_tasks = [
379
- format_label(sub_k, sub_v, document_type_code, params, mime_type)
388
+ format_label(
389
+ sub_k,
390
+ sub_v,
391
+ document_type_code,
392
+ params,
393
+ mime_type,
394
+ container_map,
395
+ terminal_map,
396
+ depot_map,
397
+ )
380
398
  for sub_k, sub_v in entity_value.items()
381
399
  ]
382
400
  return entity_k, {k: v for k, v in await asyncio.gather(*format_tasks)}
383
401
  if isinstance(entity_value, list):
384
402
  format_tasks = await asyncio.gather(
385
403
  *[
386
- format_label(entity_k, sub_v, document_type_code, params, mime_type)
404
+ format_label(
405
+ entity_k,
406
+ sub_v,
407
+ document_type_code,
408
+ params,
409
+ mime_type,
410
+ container_map,
411
+ terminal_map,
412
+ depot_map,
413
+ )
387
414
  for sub_v in entity_value
388
415
  ]
389
416
  )
@@ -405,13 +432,13 @@ async def format_label(entity_k, entity_value, document_type_code, params, mime_
405
432
  )
406
433
 
407
434
  elif (entity_key == "containertype") or (entity_key == "containersize"):
408
- formatted_value = get_tms_mappings(entity_value, "container_types")
435
+ formatted_value = container_map.get(entity_value)
409
436
 
410
437
  elif check_formatting_rule(entity_k, document_type_code, "terminal"):
411
- formatted_value = get_tms_mappings(entity_value, "terminals")
438
+ formatted_value = terminal_map.get(entity_value)
412
439
 
413
440
  elif check_formatting_rule(entity_k, document_type_code, "depot"):
414
- formatted_value = get_tms_mappings(entity_value, "depots")
441
+ formatted_value = depot_map.get(entity_value)
415
442
 
416
443
  elif entity_key.startswith(("eta", "etd", "duedate", "issuedate", "servicedate")):
417
444
  try:
@@ -507,7 +534,8 @@ async def get_port_code_ai(port: str, llm_client, doc_type=None):
507
534
  """Get port code using AI model."""
508
535
  port_llm = await get_port_code_llm(port, llm_client, doc_type=doc_type)
509
536
 
510
- return get_tms_mappings(port, "ports", port_llm)
537
+ result = await get_tms_mappings(port, "ports", port_llm)
538
+ return result.get(port, None)
511
539
 
512
540
 
513
541
  async def get_port_code_llm(port: str, llm_client, doc_type=None):
@@ -598,6 +626,74 @@ def decimal_convertor(value, quantity=False):
598
626
  return value
599
627
 
600
628
 
629
+ async def collect_mapping_requests(entity_value, document_type_code):
630
+ """Collect all unique container types, terminals, and depots from the entity value."""
631
+ # Sets to store unique values
632
+ container_types = set()
633
+ terminals = set()
634
+ depots = set()
635
+
636
+ def walk(key, value):
637
+ key_lower = key.lower()
638
+
639
+ # nested dict
640
+ if isinstance(value, dict):
641
+ for k, v in value.items():
642
+ walk(k, v)
643
+
644
+ # list of values
645
+ elif isinstance(value, list):
646
+ for item in value:
647
+ walk(key, item)
648
+
649
+ # leaf node
650
+ else:
651
+ if key_lower in ("containertype", "containersize"):
652
+ # Take only "20DV" from ('20DV', 0) if it's a tuple
653
+ container_types.add(value[0]) if isinstance(
654
+ value, tuple
655
+ ) else container_types.add(value)
656
+
657
+ elif check_formatting_rule(key, document_type_code, "terminal"):
658
+ terminals.add(value[0]) if isinstance(value, tuple) else terminals.add(
659
+ value
660
+ )
661
+
662
+ elif check_formatting_rule(key, document_type_code, "depot"):
663
+ depots.add(value[0]) if isinstance(value, tuple) else depots.add(value)
664
+
665
+ walk("root", entity_value)
666
+
667
+ return container_types, terminals, depots
668
+
669
+
670
+ async def format_all_labels(entity_data, document_type_code, params, mime_type):
671
+ """Format all labels in the entity data using cached mappings."""
672
+ # Collect all mapping values needed
673
+ container_req, terminal_req, depot_req = await collect_mapping_requests(
674
+ entity_data, document_type_code
675
+ )
676
+
677
+ # Batch fetch mappings
678
+ container_map, terminal_map, depot_map = await batch_fetch_all_mappings(
679
+ container_req, terminal_req, depot_req
680
+ )
681
+
682
+ # Format labels using cached mappings
683
+ _, result = await format_label(
684
+ "root",
685
+ entity_data,
686
+ document_type_code,
687
+ params,
688
+ mime_type,
689
+ container_map,
690
+ terminal_map,
691
+ depot_map,
692
+ )
693
+
694
+ return _, result
695
+
696
+
601
697
  async def format_all_entities(result, document_type_code, params, mime_type):
602
698
  """Format the entity values in the result dictionary."""
603
699
  # Since we treat `customsInvoice` same as `partnerInvoice`
@@ -613,13 +709,13 @@ async def format_all_entities(result, document_type_code, params, mime_type):
613
709
  return {}
614
710
 
615
711
  # Format all entities recursively
616
- _, aggregated_data = await format_label(
617
- None, result, document_type_code, params, mime_type
712
+ _, aggregated_data = await format_all_labels(
713
+ result, document_type_code, params, mime_type
618
714
  )
619
715
 
620
716
  # Process partner invoice on lineitem mapping and reverse charge sentence
621
717
  if document_type_code in ["partnerInvoice", "bundeskasse"]:
622
- process_partner_invoice(params, aggregated_data, document_type_code)
718
+ await process_partner_invoice(params, aggregated_data, document_type_code)
623
719
 
624
720
  logger.info("Data Extraction completed successfully")
625
721
  return aggregated_data
@@ -651,41 +747,46 @@ def remove_stop_words(lineitem: str):
651
747
  )
652
748
 
653
749
 
654
- def llm_prediction_to_tuples(llm_prediction, number_of_pages=-1):
750
+ def llm_prediction_to_tuples(llm_prediction, number_of_pages=-1, page_number=None):
655
751
  """Convert LLM prediction dictionary to tuples of (value, page_number)."""
656
-
657
752
  # If only 1 page, simply pair each value with page number 0
658
753
  if number_of_pages == 1:
754
+ effective_page = 0 if page_number is None else page_number
659
755
  if isinstance(llm_prediction, dict):
660
756
  return {
661
- k: llm_prediction_to_tuples(v, number_of_pages)
757
+ k: llm_prediction_to_tuples(
758
+ v, number_of_pages, page_number=effective_page
759
+ )
662
760
  for k, v in llm_prediction.items()
663
761
  }
664
762
  elif isinstance(llm_prediction, list):
665
763
  return [
666
- llm_prediction_to_tuples(v, number_of_pages) for v in llm_prediction
764
+ llm_prediction_to_tuples(v, number_of_pages, page_number=effective_page)
765
+ for v in llm_prediction
667
766
  ]
668
767
  else:
669
- return (llm_prediction, 0) if llm_prediction else None
768
+ return (llm_prediction, effective_page) if llm_prediction else None
670
769
 
671
770
  # logic for multi-page predictions
672
771
  if isinstance(llm_prediction, dict):
673
772
  if "page_number" in llm_prediction.keys() and "value" in llm_prediction.keys():
674
773
  if llm_prediction["value"]:
675
774
  try:
676
- page_number = int(llm_prediction["page_number"])
775
+ _page_number = int(llm_prediction["page_number"])
677
776
  except: # noqa: E722
678
- page_number = -1
679
- return (llm_prediction["value"], page_number)
777
+ _page_number = -1
778
+ return (llm_prediction["value"], _page_number)
680
779
  return None
681
780
 
682
781
  for key, value in llm_prediction.items():
683
782
  llm_prediction[key] = llm_prediction_to_tuples(
684
- llm_prediction.get(key, value), number_of_pages
783
+ llm_prediction.get(key, value), number_of_pages, page_number
685
784
  )
686
785
 
687
786
  elif isinstance(llm_prediction, list):
688
787
  for i, item in enumerate(llm_prediction):
689
- llm_prediction[i] = llm_prediction_to_tuples(item, number_of_pages)
788
+ llm_prediction[i] = llm_prediction_to_tuples(
789
+ item, number_of_pages, page_number
790
+ )
690
791
 
691
792
  return llm_prediction
@@ -1,7 +1,5 @@
1
1
  """This module contains the postprocessing functions for the partner invoice."""
2
- from concurrent.futures import ThreadPoolExecutor
3
-
4
- from fuzzywuzzy import fuzz
2
+ from rapidfuzz import fuzz, process
5
3
 
6
4
  from src.io import logger
7
5
  from src.utils import get_tms_mappings
@@ -136,7 +134,7 @@ def update_recipient_and_vendor(aggregated_data, is_recipient_forto):
136
134
  ] = "Dasbachstraße 15, 54292 Trier, Germany"
137
135
 
138
136
 
139
- def process_partner_invoice(params, aggregated_data, document_type_code):
137
+ async def process_partner_invoice(params, aggregated_data, document_type_code):
140
138
  """Process the partner invoice data."""
141
139
  # Post process bundeskasse invoices
142
140
  if document_type_code == "bundeskasse":
@@ -160,27 +158,76 @@ def process_partner_invoice(params, aggregated_data, document_type_code):
160
158
  reverse_charge_info["formattedValue"] = reverse_charge_value
161
159
  reverse_charge = aggregated_data.pop("reverseChargeSentence", None)
162
160
 
163
- # Process each line item
164
- for line_item in line_items:
165
- if line_item.get("lineItemDescription", None) is not None:
166
- line_item["itemCode"] = associate_forto_item_code(
167
- line_item["lineItemDescription"]["formattedValue"],
168
- params,
169
- )
161
+ # Process everything in one go
162
+ processed_items = await process_line_items_batch(params, line_items, reverse_charge)
170
163
 
171
- # Add page number for the consistency
172
- line_item["itemCode"]["page"] = line_item["lineItemDescription"]["page"]
164
+ # Update your main data structure
165
+ aggregated_data["lineItem"] = processed_items
173
166
 
174
- if reverse_charge:
175
- # Distribute reverseChargeSentence to all line items
176
- line_item["reverseChargeSentence"] = reverse_charge
177
- line_item["reverseChargeSentence"]["page"] = reverse_charge["page"]
178
167
 
168
+ async def process_line_items_batch(
169
+ params: dict, line_items: list[dict], reverse_charge=None
170
+ ):
171
+ """
172
+ Processes all line items efficiently using a "Split-Apply-Combine" strategy.
173
+ """
174
+ # To store items that need external API lookup
175
+ pending_line_items = {}
176
+
177
+ # Check Fuzzy Matching
178
+ logger.info(f"Mapping line item codes with Fuzzy matching....")
179
+ for i, item in enumerate(line_items):
180
+ description_obj = item.get("lineItemDescription")
181
+
182
+ if not description_obj or not description_obj.get("formattedValue"):
183
+ continue
184
+ # Get the formatted description text
185
+ desc = description_obj["formattedValue"]
186
+
187
+ # Find Fuzzy Match
188
+ matched_code = find_matching_lineitem(
189
+ desc,
190
+ params["lookup_data"]["item_code"],
191
+ params["fuzzy_threshold_item_code"],
192
+ )
193
+
194
+ if matched_code:
195
+ # Set the code to the line item
196
+ item["itemCode"] = {
197
+ "documentValue": desc,
198
+ "formattedValue": matched_code,
199
+ "page": description_obj.get("page"),
200
+ }
201
+ else:
202
+ # Store for batch API call
203
+ pending_line_items[i] = desc
204
+
205
+ # Batch API Call for Embedding lookups
206
+ if pending_line_items:
207
+ values_to_fetch = list(set(pending_line_items.values()))
208
+ logger.info(f"Mapping {len(values_to_fetch)} line items from Embedding API...")
209
+
210
+ # Await the batch response {"desc1": "code1", "desc2": "code2"}
211
+ api_results = await get_tms_mappings(
212
+ input_list=values_to_fetch, embedding_type="line_items"
213
+ )
214
+
215
+ # Merge API results back into original list
216
+ for index, desc in pending_line_items.items():
217
+ # Get result from API response, or None if API failed for that item
218
+ forto_code = api_results.get(desc)
219
+
220
+ # Update the original item
221
+ line_items[index]["itemCode"] = {
222
+ "documentValue": desc,
223
+ "formattedValue": forto_code, # Might be None if API failed
224
+ "page": line_items[index]["lineItemDescription"].get("page"),
225
+ }
179
226
 
180
- def compute_score(args):
181
- """Compute the fuzzy matching score between a new line item and a key."""
182
- new_lineitem, key = args
183
- return key, fuzz.ratio(new_lineitem, key)
227
+ # Add reverse charge here if exists
228
+ if reverse_charge:
229
+ [item.update({"reverseChargeSentence": reverse_charge}) for item in line_items]
230
+ return line_items
184
231
 
185
232
 
186
233
  def get_fuzzy_match_score(target: str, sentences: list, threshold: int):
@@ -195,16 +242,18 @@ def get_fuzzy_match_score(target: str, sentences: list, threshold: int):
195
242
  tuple: (best_match, score) if above threshold, else (None, 0)
196
243
  """
197
244
  # Use multiprocessing to find the best match
198
- with ThreadPoolExecutor() as executor:
199
- results = executor.map(compute_score, [(target, s) for s in sentences])
245
+ result = process.extractOne(
246
+ target, sentences, scorer=fuzz.WRatio, score_cutoff=threshold
247
+ )
200
248
 
201
- # Find the best match and score
202
- best_match, best_score = max(results, key=lambda x: x[1], default=(None, 0))
249
+ if result is None:
250
+ return None, False
203
251
 
204
- # return best_match, best_score
205
- # If the best match score is above a threshold (e.g., 80), return it
206
- if best_score >= threshold:
207
- return best_match, True
252
+ match, score, index = result
253
+
254
+ # return best_match if the best match score is above a threshold (e.g., 80)
255
+ if match:
256
+ return match, True
208
257
 
209
258
  return None, False
210
259
 
@@ -236,46 +285,59 @@ def find_matching_lineitem(new_lineitem: str, kvp_dict: dict, threshold=90):
236
285
  Returns:
237
286
  str: The best matching 'Forto SLI' value from the dictionary.
238
287
  """
239
- new_lineitem = new_lineitem.upper()
240
-
241
288
  # Check if the new line item is already in the dictionary
242
289
  if new_lineitem in kvp_dict:
243
290
  return kvp_dict[new_lineitem]
244
291
 
245
292
  # Get the best fuzzy match score for the extracted line item
246
- best_match, _ = get_fuzzy_match_score(
247
- new_lineitem, list(kvp_dict.keys()), threshold
293
+ match, _ = get_fuzzy_match_score(
294
+ new_lineitem,
295
+ list(kvp_dict.keys()),
296
+ threshold,
248
297
  )
249
298
 
250
- return kvp_dict.get(best_match, None)
299
+ if match:
300
+ # find the code from the kvp_dict
301
+ return kvp_dict[match]
251
302
 
303
+ return None
252
304
 
253
- def associate_forto_item_code(input_string, params):
254
- """
255
- Finds a match for the input string using fuzzy matching first, then embedding fallback.
256
-
257
- 1. Tries to find a fuzzy match for input_string against the keys in
258
- mapping_data using RapidFuzz, requiring a score >= fuzzy_threshold.
259
- 2. If found, returns the corresponding value from mapping_data.
260
- 3. If not found above threshold, calls the embedding_fallback function.
261
305
 
306
+ async def associate_forto_item_code(line_item_data, params):
307
+ """
308
+ Associates Forto item codes to a list of line item descriptions.
262
309
  Args:
263
- input_string: The string to find a match for.
264
- params: Parameters containing the lookup data and fuzzy threshold.
310
+ line_item_data (dict): A dictionary where keys are original descriptions and values are cleaned descriptions.
311
+ params (dict): Parameters containing lookup data and thresholds.
265
312
 
266
313
  Returns:
267
- The matched value (from fuzzy match or embedding), or None if no match found.
314
+ list: A list of dictionaries with 'description' and 'itemCode' keys.
268
315
  """
269
- # Get the Forto item code using fuzzy matching
270
- forto_item_code = find_matching_lineitem(
271
- new_lineitem=input_string,
272
- kvp_dict=params["lookup_data"]["item_code"], # TODO: Parse the KVP dictionary
273
- threshold=params["fuzzy_threshold_item_code"],
274
- )
275
316
 
276
- if forto_item_code is None:
277
- # 2. Fallback to embedding function if no good fuzzy match
278
- forto_item_code = get_tms_mappings(input_string, "line_items")
317
+ result = []
318
+ pending_line_items = {}
319
+ for desc, f_desc in line_item_data.items():
320
+ # Get the Forto item code using fuzzy matching
321
+ code = find_matching_lineitem(
322
+ new_lineitem=f_desc,
323
+ kvp_dict=params["lookup_data"]["item_code"],
324
+ threshold=params["fuzzy_threshold_item_code"],
325
+ )
326
+ if code:
327
+ result.append({"description": desc, "itemCode": code})
328
+ else:
329
+ pending_line_items[desc] = f_desc
330
+
331
+ # Batch API Call for Embedding lookups
332
+ if pending_line_items:
333
+ api_results = await get_tms_mappings(
334
+ input_list=list(pending_line_items.values()),
335
+ embedding_type="line_items",
336
+ )
337
+
338
+ # Merge API results back into original list
339
+ for desc, f_desc in pending_line_items.items():
340
+ code = api_results.get(f_desc)
341
+ result.append({"description": desc, "itemCode": code})
279
342
 
280
- result = {"documentValue": input_string, "formattedValue": forto_item_code}
281
343
  return result
src/utils.py CHANGED
@@ -6,16 +6,16 @@ import json
6
6
  import os
7
7
  import pickle
8
8
  from datetime import datetime
9
- from typing import Literal
9
+ from typing import Any, Dict, List, Literal, Optional
10
10
 
11
+ import httpx
11
12
  import numpy as np
12
13
  import openpyxl
13
14
  import pandas as pd
14
- import requests
15
15
  from google.cloud import documentai_v1beta3 as docu_ai_beta
16
16
  from pypdf import PdfReader, PdfWriter
17
17
 
18
- from src.io import get_storage_client, logger
18
+ from src.io import bq_logs, get_storage_client, logger
19
19
 
20
20
 
21
21
  def get_pdf_page_count(pdf_bytes):
@@ -31,29 +31,6 @@ def get_pdf_page_count(pdf_bytes):
31
31
  return len(reader.pages)
32
32
 
33
33
 
34
- def bq_logs(data_to_insert, params):
35
- """Insert logs into Google BigQuery.
36
-
37
- Args:
38
- data_to_insert (list): The data to insert into BigQuery.
39
- params (dict): The parameters dictionary.
40
- """
41
- # Use the pre-initialized BigQuery client
42
- bq_client = params["bq_client"]
43
- # Get the table string
44
- table_string = f"{params['g_ai_project_name']}.{params['g_ai_gbq_db_schema']}.{params['g_ai_gbq_db_table_out']}"
45
-
46
- logger.info(f"Log table: {table_string}")
47
- # Insert the rows into the table
48
- insert_logs = bq_client.insert_rows_json(table_string, data_to_insert)
49
-
50
- # Check if there were any errors inserting the rows
51
- if not insert_logs:
52
- logger.info("New rows have been added.")
53
- else:
54
- logger.info("Errors occurred while inserting rows: ", insert_logs)
55
-
56
-
57
34
  async def get_data_set_schema_from_docai(
58
35
  schema_client, project_id=None, location=None, processor_id=None, name=None
59
36
  ):
@@ -383,9 +360,9 @@ def extract_top_pages(pdf_bytes, num_pages=4):
383
360
  return output.getvalue()
384
361
 
385
362
 
386
- def get_tms_mappings(
387
- input_list: list[str], embedding_type: str, llm_ports: list[str] = None
388
- ):
363
+ async def get_tms_mappings(
364
+ input_list: List[str], embedding_type: str, llm_ports: Optional[List[str]] = None
365
+ ) -> Dict[str, Any]:
389
366
  """Get TMS mappings for the given values.
390
367
 
391
368
  Args:
@@ -395,39 +372,66 @@ def get_tms_mappings(
395
372
  llm_ports (list[str], optional): List of LLM ports to use. Defaults to None.
396
373
 
397
374
  Returns:
398
- dict: A dictionary with the mapping results.
375
+ dict or string: A dictionary or a string with the mapping results.
399
376
  """
400
- # To test the API locally, port-forward the embedding service in the sandbox to 8080:80
401
- # If you want to launch uvicorn from the tms-embedding repo, then use --port 8080 in the config file
402
377
  base_url = (
403
378
  "http://0.0.0.0:8080/"
404
379
  if os.getenv("CLUSTER") is None
405
380
  else "http://tms-mappings.api.svc.cluster.local./"
406
381
  )
407
382
 
383
+ # Ensure clean inputs
384
+ if not input_list:
385
+ return {}
386
+
408
387
  # Ensure input_list is a list
409
388
  if not isinstance(input_list, list):
410
389
  input_list = [input_list]
411
390
 
412
391
  # Always send a dict with named keys
413
392
  payload = {embedding_type: input_list}
393
+
414
394
  if llm_ports:
415
395
  payload["llm_ports"] = llm_ports if isinstance(llm_ports, list) else [llm_ports]
416
396
 
417
397
  # Make the POST request to the TMS mappings API
418
- url = f"{base_url}/{embedding_type}"
419
- response = requests.post(url=url, json=payload)
398
+ url = f"{base_url}{embedding_type}"
420
399
 
421
- if response.status_code != 200:
422
- logger.error(
423
- f"Error from TMS mappings API: {response.status_code} - {response.text}"
424
- )
400
+ # Use a timeout so the code doesn't hang forever
401
+ timeout = httpx.Timeout(60.0, connect=10.0)
402
+
403
+ async with httpx.AsyncClient(timeout=timeout) as client:
404
+ try:
405
+ response = await client.post(url, json=payload)
406
+ response.raise_for_status()
425
407
 
426
- formatted_values = (
427
- response.json().get("response", {}).get("data", {}).get(input_list[0], None)
408
+ # Structure expected: {"response": {"data": {"desc1": "code1", "desc2": "code2"}}}
409
+ return response.json().get("response", {}).get("data", {})
410
+
411
+ except httpx.HTTPStatusError as exc:
412
+ logger.error(
413
+ f"Error from TMS mappings API: {exc.response.status_code} - {exc.response.text}"
414
+ )
415
+ return {}
416
+
417
+
418
+ async def batch_fetch_all_mappings(container_types, terminals, depots):
419
+ """Batch fetch all mappings for container types, terminals, and depots."""
420
+ # run batch calls concurrently
421
+ results = await asyncio.gather(
422
+ get_tms_mappings(list(container_types), "container_types"),
423
+ get_tms_mappings(list(terminals), "terminals"),
424
+ get_tms_mappings(list(depots), "depots"),
428
425
  )
429
426
 
430
- return formatted_values
427
+ batch_container_map, batch_terminal_map, batch_depot_map = results
428
+
429
+ # Convert lists of tuples to dicts if necessary
430
+ return (
431
+ dict(batch_container_map or {}),
432
+ dict(batch_terminal_map or {}),
433
+ dict(batch_depot_map or {}),
434
+ )
431
435
 
432
436
 
433
437
  def transform_schema_strings(schema):
@@ -502,3 +506,21 @@ def estimate_page_count(sheet):
502
506
  else:
503
507
  return None
504
508
  return np.ceil(pg_cnt / 500)
509
+
510
+
511
+ def split_pdf_into_chunks(file_content: bytes, chunk_size: int = 1):
512
+ """Split PDF into smaller page chunks."""
513
+ pdf = PdfReader(io.BytesIO(file_content))
514
+ total_pages = len(pdf.pages)
515
+
516
+ # TODO: update the chunk_size based on doc length. However, it breaks the page number extraction logic.
517
+ for i in range(0, total_pages, chunk_size):
518
+ writer = PdfWriter()
519
+ for j in range(i, min(i + chunk_size, total_pages)):
520
+ writer.add_page(pdf.pages[j])
521
+
522
+ buffer = io.BytesIO()
523
+ writer.write(buffer)
524
+ buffer.seek(0)
525
+
526
+ yield buffer.getvalue()