data-science-document-ai 1.42.5__py3-none-any.whl → 1.57.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. {data_science_document_ai-1.42.5.dist-info → data_science_document_ai-1.57.0.dist-info}/METADATA +2 -2
  2. data_science_document_ai-1.57.0.dist-info/RECORD +60 -0
  3. src/constants.py +13 -34
  4. src/docai_processor_config.yaml +0 -69
  5. src/excel_processing.py +24 -14
  6. src/io.py +23 -0
  7. src/llm.py +0 -29
  8. src/pdf_processing.py +183 -76
  9. src/postprocessing/common.py +172 -28
  10. src/postprocessing/postprocess_partner_invoice.py +194 -59
  11. src/prompts/library/arrivalNotice/other/placeholders.json +70 -0
  12. src/prompts/library/arrivalNotice/other/prompt.txt +40 -0
  13. src/prompts/library/bookingConfirmation/evergreen/placeholders.json +135 -21
  14. src/prompts/library/bookingConfirmation/evergreen/prompt.txt +21 -17
  15. src/prompts/library/bookingConfirmation/hapag-lloyd/placeholders.json +136 -22
  16. src/prompts/library/bookingConfirmation/hapag-lloyd/prompt.txt +52 -58
  17. src/prompts/library/bookingConfirmation/maersk/placeholders.json +135 -21
  18. src/prompts/library/bookingConfirmation/maersk/prompt.txt +10 -1
  19. src/prompts/library/bookingConfirmation/msc/placeholders.json +135 -21
  20. src/prompts/library/bookingConfirmation/msc/prompt.txt +10 -1
  21. src/prompts/library/bookingConfirmation/oocl/placeholders.json +149 -21
  22. src/prompts/library/bookingConfirmation/oocl/prompt.txt +11 -3
  23. src/prompts/library/bookingConfirmation/other/placeholders.json +149 -21
  24. src/prompts/library/bookingConfirmation/other/prompt.txt +56 -57
  25. src/prompts/library/bookingConfirmation/yangming/placeholders.json +149 -21
  26. src/prompts/library/bookingConfirmation/yangming/prompt.txt +11 -1
  27. src/prompts/library/bundeskasse/other/placeholders.json +5 -5
  28. src/prompts/library/bundeskasse/other/prompt.txt +7 -5
  29. src/prompts/library/commercialInvoice/other/placeholders.json +125 -0
  30. src/prompts/library/commercialInvoice/other/prompt.txt +1 -1
  31. src/prompts/library/customsAssessment/other/placeholders.json +70 -0
  32. src/prompts/library/customsAssessment/other/prompt.txt +24 -37
  33. src/prompts/library/customsInvoice/other/prompt.txt +4 -3
  34. src/prompts/library/deliveryOrder/other/placeholders.json +80 -27
  35. src/prompts/library/deliveryOrder/other/prompt.txt +26 -40
  36. src/prompts/library/draftMbl/other/placeholders.json +33 -33
  37. src/prompts/library/draftMbl/other/prompt.txt +34 -44
  38. src/prompts/library/finalMbL/other/placeholders.json +80 -0
  39. src/prompts/library/finalMbL/other/prompt.txt +34 -44
  40. src/prompts/library/packingList/other/placeholders.json +98 -0
  41. src/prompts/library/partnerInvoice/other/prompt.txt +8 -7
  42. src/prompts/library/preprocessing/carrier/placeholders.json +0 -16
  43. src/prompts/library/shippingInstruction/other/placeholders.json +115 -0
  44. src/prompts/library/shippingInstruction/other/prompt.txt +26 -14
  45. src/prompts/prompt_library.py +0 -4
  46. src/setup.py +25 -24
  47. src/utils.py +120 -68
  48. data_science_document_ai-1.42.5.dist-info/RECORD +0 -57
  49. src/prompts/library/draftMbl/hapag-lloyd/prompt.txt +0 -45
  50. src/prompts/library/draftMbl/maersk/prompt.txt +0 -19
  51. src/prompts/library/finalMbL/hapag-lloyd/prompt.txt +0 -44
  52. src/prompts/library/finalMbL/maersk/prompt.txt +0 -19
  53. {data_science_document_ai-1.42.5.dist-info → data_science_document_ai-1.57.0.dist-info}/WHEEL +0 -0
@@ -12,7 +12,7 @@ from src.constants import formatting_rules
12
12
  from src.io import logger
13
13
  from src.postprocessing.postprocess_partner_invoice import process_partner_invoice
14
14
  from src.prompts.prompt_library import prompt_library
15
- from src.utils import get_tms_mappings
15
+ from src.utils import batch_fetch_all_mappings, get_tms_mappings
16
16
 
17
17
  tms_domain = os.environ["TMS_DOMAIN"]
18
18
 
@@ -134,8 +134,11 @@ def extract_number(data_field_value):
134
134
  formatted_value: string
135
135
 
136
136
  """
137
+ # Remove container size pattern like 20FT, 40HC, etc from 1 x 40HC
138
+ value = remove_unwanted_patterns(data_field_value)
139
+
137
140
  formatted_value = ""
138
- for c in data_field_value:
141
+ for c in value:
139
142
  if c.isnumeric() or c in [",", ".", "-"]:
140
143
  formatted_value += c
141
144
 
@@ -319,6 +322,14 @@ def remove_unwanted_patterns(lineitem: str):
319
322
  # Remove "HIGH CUBE"
320
323
  lineitem = lineitem.replace("HIGH CUBE", "")
321
324
 
325
+ # Remove container size e.g., 20FT, 40HC, etc.
326
+ pattern = [
327
+ f"{s}{t}"
328
+ for s in ("20|22|40|45".split("|"))
329
+ for t in ("FT|HC|DC|HD|GP|OT|RF|FR|TK|DV".split("|"))
330
+ ]
331
+ lineitem = re.sub(r"|".join(pattern), "", lineitem, flags=re.IGNORECASE).strip()
332
+
322
333
  return lineitem
323
334
 
324
335
 
@@ -349,42 +360,75 @@ def clean_item_description(lineitem: str, remove_numbers: bool = True):
349
360
  # Remove the currency codes
350
361
  lineitem = re.sub(currency_codes_pattern, "", lineitem, flags=re.IGNORECASE)
351
362
 
363
+ # remove other patterns
364
+ lineitem = remove_unwanted_patterns(lineitem)
365
+
352
366
  # Remove numbers from the line item
353
367
  if (
354
368
  remove_numbers
355
369
  ): # Do not remove numbers for the reverse charge sentence as it contains Article number
356
370
  lineitem = re.sub(r"\d+", "", lineitem)
357
371
 
358
- # remove other patterns
359
- lineitem = remove_unwanted_patterns(lineitem)
360
-
361
372
  # remove special chars
362
373
  lineitem = re.sub(r"[^A-Za-z0-9\s]", " ", lineitem).strip()
363
374
 
375
+ # Remove x from lineitem like 10 x
376
+ lineitem = re.sub(r"\b[xX]\b", " ", lineitem).strip()
377
+
364
378
  return re.sub(r"\s{2,}", " ", lineitem).strip()
365
379
 
366
380
 
367
- async def format_label(entity_k, entity_value, document_type_code, params):
381
+ async def format_label(
382
+ entity_k,
383
+ entity_value,
384
+ document_type_code,
385
+ params,
386
+ mime_type,
387
+ container_map,
388
+ terminal_map,
389
+ depot_map,
390
+ ):
368
391
  llm_client = params["LlmClient"]
369
392
  if isinstance(entity_value, dict): # if it's a nested entity
370
393
  format_tasks = [
371
- format_label(sub_k, sub_v, document_type_code, params)
394
+ format_label(
395
+ sub_k,
396
+ sub_v,
397
+ document_type_code,
398
+ params,
399
+ mime_type,
400
+ container_map,
401
+ terminal_map,
402
+ depot_map,
403
+ )
372
404
  for sub_k, sub_v in entity_value.items()
373
405
  ]
374
406
  return entity_k, {k: v for k, v in await asyncio.gather(*format_tasks)}
375
407
  if isinstance(entity_value, list):
376
408
  format_tasks = await asyncio.gather(
377
409
  *[
378
- format_label(entity_k, sub_v, document_type_code, params)
410
+ format_label(
411
+ entity_k,
412
+ sub_v,
413
+ document_type_code,
414
+ params,
415
+ mime_type,
416
+ container_map,
417
+ terminal_map,
418
+ depot_map,
419
+ )
379
420
  for sub_v in entity_value
380
421
  ]
381
422
  )
382
423
  return entity_k, [v for _, v in format_tasks]
383
- if isinstance(entity_value, tuple):
384
- page = entity_value[1]
385
- entity_value = entity_value[0]
386
- else:
387
- page = -1
424
+
425
+ if mime_type == "application/pdf":
426
+ if isinstance(entity_value, tuple):
427
+ page = entity_value[1]
428
+ entity_value = entity_value[0]
429
+ else:
430
+ page = -1
431
+
388
432
  entity_key = entity_k.lower()
389
433
  formatted_value = None
390
434
 
@@ -394,13 +438,13 @@ async def format_label(entity_k, entity_value, document_type_code, params):
394
438
  )
395
439
 
396
440
  elif (entity_key == "containertype") or (entity_key == "containersize"):
397
- formatted_value = get_tms_mappings(entity_value, "container_types")
441
+ formatted_value = container_map.get(entity_value)
398
442
 
399
443
  elif check_formatting_rule(entity_k, document_type_code, "terminal"):
400
- formatted_value = get_tms_mappings(entity_value, "terminals")
444
+ formatted_value = terminal_map.get(entity_value)
401
445
 
402
446
  elif check_formatting_rule(entity_k, document_type_code, "depot"):
403
- formatted_value = get_tms_mappings(entity_value, "depots")
447
+ formatted_value = depot_map.get(entity_value)
404
448
 
405
449
  elif entity_key.startswith(("eta", "etd", "duedate", "issuedate", "servicedate")):
406
450
  try:
@@ -421,7 +465,10 @@ async def format_label(entity_k, entity_value, document_type_code, params):
421
465
  except ValueError as e:
422
466
  logger.info(f"ParserError: {e}")
423
467
 
424
- elif entity_key in ["invoicenumber", "creditnoteinvoicenumber"]:
468
+ elif (
469
+ entity_key in ["invoicenumber", "creditnoteinvoicenumber"]
470
+ and document_type_code == "bundeskasse"
471
+ ):
425
472
  formatted_value = clean_invoice_number(entity_value)
426
473
 
427
474
  elif entity_key in ("shipmentid", "partnerreference"):
@@ -482,8 +529,10 @@ async def format_label(entity_k, entity_value, document_type_code, params):
482
529
  result = {
483
530
  "documentValue": entity_value,
484
531
  "formattedValue": formatted_value,
485
- "page": page,
486
532
  }
533
+ if mime_type == "application/pdf":
534
+ result["page"] = page
535
+
487
536
  return entity_k, result
488
537
 
489
538
 
@@ -491,7 +540,8 @@ async def get_port_code_ai(port: str, llm_client, doc_type=None):
491
540
  """Get port code using AI model."""
492
541
  port_llm = await get_port_code_llm(port, llm_client, doc_type=doc_type)
493
542
 
494
- return get_tms_mappings(port, "ports", port_llm)
543
+ result = await get_tms_mappings(port, "ports", port_llm)
544
+ return result.get(port, None)
495
545
 
496
546
 
497
547
  async def get_port_code_llm(port: str, llm_client, doc_type=None):
@@ -582,7 +632,75 @@ def decimal_convertor(value, quantity=False):
582
632
  return value
583
633
 
584
634
 
585
- async def format_all_entities(result, document_type_code, params):
635
+ async def collect_mapping_requests(entity_value, document_type_code):
636
+ """Collect all unique container types, terminals, and depots from the entity value."""
637
+ # Sets to store unique values
638
+ container_types = set()
639
+ terminals = set()
640
+ depots = set()
641
+
642
+ def walk(key, value):
643
+ key_lower = key.lower()
644
+
645
+ # nested dict
646
+ if isinstance(value, dict):
647
+ for k, v in value.items():
648
+ walk(k, v)
649
+
650
+ # list of values
651
+ elif isinstance(value, list):
652
+ for item in value:
653
+ walk(key, item)
654
+
655
+ # leaf node
656
+ else:
657
+ if key_lower in ("containertype", "containersize"):
658
+ # Take only "20DV" from ('20DV', 0) if it's a tuple
659
+ container_types.add(value[0]) if isinstance(
660
+ value, tuple
661
+ ) else container_types.add(value)
662
+
663
+ elif check_formatting_rule(key, document_type_code, "terminal"):
664
+ terminals.add(value[0]) if isinstance(value, tuple) else terminals.add(
665
+ value
666
+ )
667
+
668
+ elif check_formatting_rule(key, document_type_code, "depot"):
669
+ depots.add(value[0]) if isinstance(value, tuple) else depots.add(value)
670
+
671
+ walk("root", entity_value)
672
+
673
+ return container_types, terminals, depots
674
+
675
+
676
+ async def format_all_labels(entity_data, document_type_code, params, mime_type):
677
+ """Format all labels in the entity data using cached mappings."""
678
+ # Collect all mapping values needed
679
+ container_req, terminal_req, depot_req = await collect_mapping_requests(
680
+ entity_data, document_type_code
681
+ )
682
+
683
+ # Batch fetch mappings
684
+ container_map, terminal_map, depot_map = await batch_fetch_all_mappings(
685
+ container_req, terminal_req, depot_req
686
+ )
687
+
688
+ # Format labels using cached mappings
689
+ _, result = await format_label(
690
+ "root",
691
+ entity_data,
692
+ document_type_code,
693
+ params,
694
+ mime_type,
695
+ container_map,
696
+ terminal_map,
697
+ depot_map,
698
+ )
699
+
700
+ return _, result
701
+
702
+
703
+ async def format_all_entities(result, document_type_code, params, mime_type):
586
704
  """Format the entity values in the result dictionary."""
587
705
  # Since we treat `customsInvoice` same as `partnerInvoice`
588
706
  document_type_code = (
@@ -597,11 +715,13 @@ async def format_all_entities(result, document_type_code, params):
597
715
  return {}
598
716
 
599
717
  # Format all entities recursively
600
- _, aggregated_data = await format_label(None, result, document_type_code, params)
718
+ _, aggregated_data = await format_all_labels(
719
+ result, document_type_code, params, mime_type
720
+ )
601
721
 
602
722
  # Process partner invoice on lineitem mapping and reverse charge sentence
603
723
  if document_type_code in ["partnerInvoice", "bundeskasse"]:
604
- process_partner_invoice(params, aggregated_data, document_type_code)
724
+ await process_partner_invoice(params, aggregated_data, document_type_code)
605
725
 
606
726
  logger.info("Data Extraction completed successfully")
607
727
  return aggregated_data
@@ -633,22 +753,46 @@ def remove_stop_words(lineitem: str):
633
753
  )
634
754
 
635
755
 
636
- def llm_prediction_to_tuples(llm_prediction):
756
+ def llm_prediction_to_tuples(llm_prediction, number_of_pages=-1, page_number=None):
637
757
  """Convert LLM prediction dictionary to tuples of (value, page_number)."""
758
+ # If only 1 page, simply pair each value with page number 0
759
+ if number_of_pages == 1:
760
+ effective_page = 0 if page_number is None else page_number
761
+ if isinstance(llm_prediction, dict):
762
+ return {
763
+ k: llm_prediction_to_tuples(
764
+ v, number_of_pages, page_number=effective_page
765
+ )
766
+ for k, v in llm_prediction.items()
767
+ }
768
+ elif isinstance(llm_prediction, list):
769
+ return [
770
+ llm_prediction_to_tuples(v, number_of_pages, page_number=effective_page)
771
+ for v in llm_prediction
772
+ ]
773
+ else:
774
+ return (llm_prediction, effective_page) if llm_prediction else None
775
+
776
+ # logic for multi-page predictions
638
777
  if isinstance(llm_prediction, dict):
639
778
  if "page_number" in llm_prediction.keys() and "value" in llm_prediction.keys():
640
779
  if llm_prediction["value"]:
641
780
  try:
642
- page_number = int(llm_prediction["page_number"])
781
+ _page_number = int(llm_prediction["page_number"])
643
782
  except: # noqa: E722
644
- page_number = -1
645
- return (llm_prediction["value"], page_number)
783
+ _page_number = -1
784
+ return (llm_prediction["value"], _page_number)
646
785
  return None
786
+
647
787
  for key, value in llm_prediction.items():
648
788
  llm_prediction[key] = llm_prediction_to_tuples(
649
- llm_prediction.get(key, value)
789
+ llm_prediction.get(key, value), number_of_pages, page_number
650
790
  )
791
+
651
792
  elif isinstance(llm_prediction, list):
652
793
  for i, item in enumerate(llm_prediction):
653
- llm_prediction[i] = llm_prediction_to_tuples(item)
794
+ llm_prediction[i] = llm_prediction_to_tuples(
795
+ item, number_of_pages, page_number
796
+ )
797
+
654
798
  return llm_prediction
@@ -1,7 +1,7 @@
1
1
  """This module contains the postprocessing functions for the partner invoice."""
2
- from concurrent.futures import ThreadPoolExecutor
2
+ from collections import defaultdict
3
3
 
4
- from fuzzywuzzy import fuzz
4
+ from rapidfuzz import fuzz, process
5
5
 
6
6
  from src.io import logger
7
7
  from src.utils import get_tms_mappings
@@ -105,9 +105,18 @@ def post_process_bundeskasse(aggregated_data):
105
105
  )
106
106
 
107
107
  # Check if the deferredDutyPayer is forto
108
- deferredDutyPayer = line_item.get("deferredDutyPayer", {})
109
- lower = deferredDutyPayer.get("documentValue", "").lower()
110
- if any(key in lower for key in ["de789147263644738", "forto"]):
108
+ KEYWORDS = {"de789147263644738", "forto", "009812"}
109
+
110
+ def is_forto_recipient(line_item: dict) -> bool:
111
+ values_to_check = [
112
+ line_item.get("deferredDutyPayer", {}).get("documentValue", ""),
113
+ line_item.get("vatId", {}).get("documentValue", ""),
114
+ ]
115
+
116
+ combined = " ".join(values_to_check).lower()
117
+ return any(keyword in combined for keyword in KEYWORDS)
118
+
119
+ if is_forto_recipient(line_item):
111
120
  is_recipient_forto = True
112
121
 
113
122
  update_recipient_and_vendor(aggregated_data, is_recipient_forto)
@@ -136,13 +145,32 @@ def update_recipient_and_vendor(aggregated_data, is_recipient_forto):
136
145
  ] = "Dasbachstraße 15, 54292 Trier, Germany"
137
146
 
138
147
 
139
- def process_partner_invoice(params, aggregated_data, document_type_code):
148
+ def select_unique_bank_account(bank_account):
149
+ # Select the unique bank account if multiple are present
150
+ if isinstance(bank_account, list) and bank_account:
151
+ best = defaultdict(lambda: None)
152
+
153
+ for item in bank_account:
154
+ dv = item["documentValue"]
155
+ if best[dv] is None or item["page"] < best[dv]["page"]:
156
+ best[dv] = item
157
+
158
+ unique = list(best.values())
159
+ return unique
160
+
161
+
162
+ async def process_partner_invoice(params, aggregated_data, document_type_code):
140
163
  """Process the partner invoice data."""
141
164
  # Post process bundeskasse invoices
142
165
  if document_type_code == "bundeskasse":
143
166
  post_process_bundeskasse(aggregated_data)
144
167
  return
145
168
 
169
+ if "bankAccount" in aggregated_data:
170
+ aggregated_data["bankAccount"] = select_unique_bank_account(
171
+ aggregated_data["bankAccount"]
172
+ )
173
+
146
174
  line_items = aggregated_data.get("lineItem", [])
147
175
  # Add debug logging
148
176
  logger.info(f"Processing partnerInvoice with {len(line_items)} line items")
@@ -160,27 +188,78 @@ def process_partner_invoice(params, aggregated_data, document_type_code):
160
188
  reverse_charge_info["formattedValue"] = reverse_charge_value
161
189
  reverse_charge = aggregated_data.pop("reverseChargeSentence", None)
162
190
 
163
- # Process each line item
164
- for line_item in line_items:
165
- if line_item.get("lineItemDescription", None) is not None:
166
- line_item["itemCode"] = associate_forto_item_code(
167
- line_item["lineItemDescription"]["formattedValue"],
168
- params,
169
- )
191
+ # Partner Name
192
+ partner_name = aggregated_data.get("vendorName", {}).get("documentValue", None)
193
+
194
+ # Process everything in one go
195
+ processed_items = await process_line_items_batch(
196
+ params, line_items, reverse_charge, partner_name
197
+ )
170
198
 
171
- # Add page number for the consistency
172
- line_item["itemCode"]["page"] = line_item["lineItemDescription"]["page"]
199
+ # Update your main data structure
200
+ aggregated_data["lineItem"] = processed_items
173
201
 
174
- if reverse_charge:
175
- # Distribute reverseChargeSentence to all line items
176
- line_item["reverseChargeSentence"] = reverse_charge
177
- line_item["reverseChargeSentence"]["page"] = reverse_charge["page"]
178
202
 
203
+ async def process_line_items_batch(
204
+ params: dict, line_items: list[dict], reverse_charge=None, partner_name=None
205
+ ):
206
+ """
207
+ Processes all line items efficiently using a "Split-Apply-Combine" strategy.
208
+ """
209
+ # To store items that need external API lookup
210
+ pending_line_items = {}
211
+
212
+ # Check Fuzzy Matching
213
+ logger.info(f"Mapping line item codes with Fuzzy matching....")
214
+ for i, item in enumerate(line_items):
215
+ description_obj = item.get("lineItemDescription")
216
+
217
+ if not description_obj or not description_obj.get("formattedValue"):
218
+ continue
219
+ # Get the formatted description text
220
+ desc = description_obj["formattedValue"]
221
+
222
+ # Find Fuzzy Match
223
+ matched_code = find_matching_lineitem(
224
+ desc,
225
+ params["lookup_data"]["item_code"],
226
+ params["fuzzy_threshold_item_code"],
227
+ )
228
+
229
+ if matched_code:
230
+ # Set the code to the line item
231
+ item["itemCode"] = {
232
+ "documentValue": desc,
233
+ "formattedValue": matched_code,
234
+ "page": description_obj.get("page"),
235
+ }
236
+ else:
237
+ # Store for batch API call
238
+ pending_line_items[i] = desc
239
+
240
+ # Batch API Call for Embedding lookups
241
+ if pending_line_items:
242
+ code_map = await fetch_line_item_codes(pending_line_items, partner_name, params)
243
+
244
+ for index, desc in pending_line_items.items():
245
+ line_items[index]["itemCode"] = {
246
+ "documentValue": desc,
247
+ "formattedValue": code_map.get(desc),
248
+ "page": line_items[index]["lineItemDescription"].get("page"),
249
+ }
250
+
251
+ # Add reverse charge here if exists
252
+ if reverse_charge:
253
+ [
254
+ item.update({"reverseChargeSentence": reverse_charge})
255
+ for item in line_items
256
+ if (
257
+ (item.get("itemCode") and item["itemCode"]["formattedValue"] != "CDU")
258
+ or not item.get("itemCode")
259
+ )
260
+ ]
179
261
 
180
- def compute_score(args):
181
- """Compute the fuzzy matching score between a new line item and a key."""
182
- new_lineitem, key = args
183
- return key, fuzz.ratio(new_lineitem, key)
262
+ return line_items
184
263
 
185
264
 
186
265
  def get_fuzzy_match_score(target: str, sentences: list, threshold: int):
@@ -195,16 +274,18 @@ def get_fuzzy_match_score(target: str, sentences: list, threshold: int):
195
274
  tuple: (best_match, score) if above threshold, else (None, 0)
196
275
  """
197
276
  # Use multiprocessing to find the best match
198
- with ThreadPoolExecutor() as executor:
199
- results = executor.map(compute_score, [(target, s) for s in sentences])
277
+ result = process.extractOne(
278
+ target, sentences, scorer=fuzz.WRatio, score_cutoff=threshold
279
+ )
280
+
281
+ if result is None:
282
+ return None, False
200
283
 
201
- # Find the best match and score
202
- best_match, best_score = max(results, key=lambda x: x[1], default=(None, 0))
284
+ match, score, index = result
203
285
 
204
- # return best_match, best_score
205
- # If the best match score is above a threshold (e.g., 80), return it
206
- if best_score >= threshold:
207
- return best_match, True
286
+ # return best_match if the best match score is above a threshold (e.g., 80)
287
+ if match:
288
+ return match, True
208
289
 
209
290
  return None, False
210
291
 
@@ -219,11 +300,14 @@ def if_reverse_charge_sentence(sentence: str, params):
219
300
  return False
220
301
 
221
302
  # Check if the sentence is similar to any of the reverse charge sentences
222
- _, is_reverse_charge = get_fuzzy_match_score(
223
- sentence, reverse_charge_sentences, threshold
303
+ match, _ = get_fuzzy_match_score(
304
+ sentence, list(reverse_charge_sentences.keys()), threshold
224
305
  )
225
306
 
226
- return is_reverse_charge
307
+ if match:
308
+ return reverse_charge_sentences[match]
309
+
310
+ return False
227
311
 
228
312
 
229
313
  def find_matching_lineitem(new_lineitem: str, kvp_dict: dict, threshold=90):
@@ -236,46 +320,97 @@ def find_matching_lineitem(new_lineitem: str, kvp_dict: dict, threshold=90):
236
320
  Returns:
237
321
  str: The best matching 'Forto SLI' value from the dictionary.
238
322
  """
239
- new_lineitem = new_lineitem.upper()
240
-
241
323
  # Check if the new line item is already in the dictionary
242
324
  if new_lineitem in kvp_dict:
243
325
  return kvp_dict[new_lineitem]
244
326
 
245
327
  # Get the best fuzzy match score for the extracted line item
246
- best_match, _ = get_fuzzy_match_score(
247
- new_lineitem, list(kvp_dict.keys()), threshold
328
+ match, _ = get_fuzzy_match_score(
329
+ new_lineitem,
330
+ list(kvp_dict.keys()),
331
+ threshold,
248
332
  )
249
333
 
250
- return kvp_dict.get(best_match, None)
251
-
334
+ if match:
335
+ # find the code from the kvp_dict
336
+ return kvp_dict[match]
252
337
 
253
- def associate_forto_item_code(input_string, params):
254
- """
255
- Finds a match for the input string using fuzzy matching first, then embedding fallback.
338
+ return None
256
339
 
257
- 1. Tries to find a fuzzy match for input_string against the keys in
258
- mapping_data using RapidFuzz, requiring a score >= fuzzy_threshold.
259
- 2. If found, returns the corresponding value from mapping_data.
260
- 3. If not found above threshold, calls the embedding_fallback function.
261
340
 
341
+ async def associate_forto_item_code(line_item_data, params, partner_name=None):
342
+ """
343
+ Associates Forto item codes to a list of line item descriptions.
262
344
  Args:
263
- input_string: The string to find a match for.
264
- params: Parameters containing the lookup data and fuzzy threshold.
345
+ line_item_data (dict): A dictionary where keys are original descriptions and values are cleaned descriptions.
346
+ params (dict): Parameters containing lookup data and thresholds.
347
+ partner_name (str, optional): The name of the partner for context in matching. Defaults to None.
265
348
 
266
349
  Returns:
267
- The matched value (from fuzzy match or embedding), or None if no match found.
350
+ list: A list of dictionaries with 'description' and 'itemCode' keys.
268
351
  """
269
- # Get the Forto item code using fuzzy matching
270
- forto_item_code = find_matching_lineitem(
271
- new_lineitem=input_string,
272
- kvp_dict=params["lookup_data"]["item_code"], # TODO: Parse the KVP dictionary
273
- threshold=params["fuzzy_threshold_item_code"],
352
+
353
+ result = []
354
+ pending_line_items = {}
355
+ for desc, f_desc in line_item_data.items():
356
+ # Get the Forto item code using fuzzy matching
357
+ code = find_matching_lineitem(
358
+ new_lineitem=f_desc,
359
+ kvp_dict=params["lookup_data"]["item_code"],
360
+ threshold=params["fuzzy_threshold_item_code"],
361
+ )
362
+ if code:
363
+ result.append({"description": desc, "itemCode": code})
364
+ else:
365
+ pending_line_items[desc] = f_desc
366
+
367
+ # Batch API Call for Embedding lookups
368
+ if pending_line_items:
369
+ code_map = await fetch_line_item_codes(pending_line_items, partner_name, params)
370
+
371
+ for desc, f_desc in pending_line_items.items():
372
+ result.append(
373
+ {
374
+ "description": desc,
375
+ "itemCode": code_map.get(f_desc),
376
+ }
377
+ )
378
+
379
+ return result
380
+
381
+
382
+ async def fetch_line_item_codes(
383
+ pending_line_items: dict,
384
+ partner_name: str | None,
385
+ params: dict,
386
+ ):
387
+ """Returns: {original_description: mapped_code_or_None}"""
388
+ t_mode = (
389
+ find_matching_lineitem(
390
+ partner_name.upper(),
391
+ params["lookup_data"]["intermodal_partners"],
392
+ threshold=87,
393
+ )
394
+ if partner_name
395
+ else None
274
396
  )
275
397
 
276
- if forto_item_code is None:
277
- # 2. Fallback to embedding function if no good fuzzy match
278
- forto_item_code = get_tms_mappings(input_string, "line_items")
398
+ unique_descs = list(set(pending_line_items.values()))
399
+ logger.info(f"Mapping {len(unique_descs)} line items from Embedding API...")
279
400
 
280
- result = {"documentValue": input_string, "formattedValue": forto_item_code}
401
+ # Build API input map
402
+ api_input_map = {
403
+ desc: f"{t_mode} - {desc}" if t_mode else desc for desc in unique_descs
404
+ }
405
+
406
+ api_results = await get_tms_mappings(
407
+ input_list=list(api_input_map.values()),
408
+ embedding_type="line_items",
409
+ )
410
+
411
+ # Normalize response back to original descriptions
412
+ result = {
413
+ original_desc: api_results.get(api_desc)
414
+ for original_desc, api_desc in api_input_map.items()
415
+ }
281
416
  return result