data-science-document-ai 1.37.0__py3-none-any.whl → 1.51.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. {data_science_document_ai-1.37.0.dist-info → data_science_document_ai-1.51.0.dist-info}/METADATA +3 -3
  2. data_science_document_ai-1.51.0.dist-info/RECORD +60 -0
  3. {data_science_document_ai-1.37.0.dist-info → data_science_document_ai-1.51.0.dist-info}/WHEEL +1 -1
  4. src/constants.py +6 -10
  5. src/docai.py +14 -5
  6. src/docai_processor_config.yaml +0 -56
  7. src/excel_processing.py +34 -13
  8. src/io.py +69 -1
  9. src/llm.py +10 -32
  10. src/pdf_processing.py +192 -57
  11. src/postprocessing/common.py +252 -590
  12. src/postprocessing/postprocess_partner_invoice.py +139 -89
  13. src/prompts/library/arrivalNotice/other/placeholders.json +70 -0
  14. src/prompts/library/arrivalNotice/other/prompt.txt +40 -0
  15. src/prompts/library/bookingConfirmation/evergreen/placeholders.json +17 -17
  16. src/prompts/library/bookingConfirmation/evergreen/prompt.txt +1 -0
  17. src/prompts/library/bookingConfirmation/hapag-lloyd/placeholders.json +18 -18
  18. src/prompts/library/bookingConfirmation/hapag-lloyd/prompt.txt +1 -1
  19. src/prompts/library/bookingConfirmation/maersk/placeholders.json +17 -17
  20. src/prompts/library/bookingConfirmation/maersk/prompt.txt +1 -1
  21. src/prompts/library/bookingConfirmation/msc/placeholders.json +17 -17
  22. src/prompts/library/bookingConfirmation/msc/prompt.txt +1 -1
  23. src/prompts/library/bookingConfirmation/oocl/placeholders.json +17 -17
  24. src/prompts/library/bookingConfirmation/oocl/prompt.txt +3 -1
  25. src/prompts/library/bookingConfirmation/other/placeholders.json +17 -17
  26. src/prompts/library/bookingConfirmation/other/prompt.txt +1 -1
  27. src/prompts/library/bookingConfirmation/yangming/placeholders.json +17 -17
  28. src/prompts/library/bookingConfirmation/yangming/prompt.txt +1 -1
  29. src/prompts/library/bundeskasse/other/placeholders.json +25 -25
  30. src/prompts/library/bundeskasse/other/prompt.txt +8 -6
  31. src/prompts/library/commercialInvoice/other/placeholders.json +125 -0
  32. src/prompts/library/commercialInvoice/other/prompt.txt +2 -1
  33. src/prompts/library/customsAssessment/other/placeholders.json +67 -16
  34. src/prompts/library/customsAssessment/other/prompt.txt +24 -37
  35. src/prompts/library/customsInvoice/other/placeholders.json +29 -20
  36. src/prompts/library/customsInvoice/other/prompt.txt +9 -4
  37. src/prompts/library/deliveryOrder/other/placeholders.json +79 -28
  38. src/prompts/library/deliveryOrder/other/prompt.txt +26 -40
  39. src/prompts/library/draftMbl/other/placeholders.json +33 -33
  40. src/prompts/library/draftMbl/other/prompt.txt +34 -44
  41. src/prompts/library/finalMbL/other/placeholders.json +34 -34
  42. src/prompts/library/finalMbL/other/prompt.txt +34 -44
  43. src/prompts/library/packingList/other/placeholders.json +98 -0
  44. src/prompts/library/packingList/other/prompt.txt +1 -1
  45. src/prompts/library/partnerInvoice/other/placeholders.json +2 -23
  46. src/prompts/library/partnerInvoice/other/prompt.txt +7 -18
  47. src/prompts/library/preprocessing/carrier/placeholders.json +0 -16
  48. src/prompts/library/shippingInstruction/other/placeholders.json +115 -0
  49. src/prompts/library/shippingInstruction/other/prompt.txt +28 -15
  50. src/setup.py +13 -61
  51. src/utils.py +189 -29
  52. data_science_document_ai-1.37.0.dist-info/RECORD +0 -59
  53. src/prompts/library/draftMbl/hapag-lloyd/prompt.txt +0 -44
  54. src/prompts/library/draftMbl/maersk/prompt.txt +0 -17
  55. src/prompts/library/finalMbL/hapag-lloyd/prompt.txt +0 -44
  56. src/prompts/library/finalMbL/maersk/prompt.txt +0 -17
@@ -1,426 +1,22 @@
1
1
  import asyncio
2
- import datetime
3
2
  import json
4
3
  import os
5
4
  import re
6
5
  from datetime import timezone
7
6
 
8
- import numpy as np
9
7
  import pandas as pd
10
- import requests
11
8
  from nltk.corpus import stopwords
12
9
  from rapidfuzz import process
13
- from vertexai.preview.language_models import TextEmbeddingModel
14
10
 
15
11
  from src.constants import formatting_rules
16
- from src.io import get_storage_client, logger
12
+ from src.io import logger
17
13
  from src.postprocessing.postprocess_partner_invoice import process_partner_invoice
18
14
  from src.prompts.prompt_library import prompt_library
19
- from src.tms import call_tms, set_tms_service_token
15
+ from src.utils import batch_fetch_all_mappings, get_tms_mappings
20
16
 
21
17
  tms_domain = os.environ["TMS_DOMAIN"]
22
18
 
23
19
 
24
- class EmbeddingsManager: # noqa: D101
25
- def __init__(self, params): # noqa: D107
26
- self.params = params
27
- self.embeddings_dict = {}
28
- self.embed_model = setup_embed_model()
29
- self.bucket = self.get_bucket_storage()
30
- self.embedding_folder = self.embed_model._model_id
31
- self.embedding_dimension = 768 # TODO: to be reduced
32
-
33
- def get_bucket_storage(self):
34
- """
35
- Retrieve the bucket storage object.
36
-
37
- Returns:
38
- The bucket storage object.
39
- """
40
- params = self.params
41
- storage_client = get_storage_client(params)
42
- bucket = storage_client.bucket(params["doc_ai_bucket_name"])
43
- return bucket
44
-
45
- def _find_most_similar_option(self, input_string, option_ids, option_embeddings):
46
- """
47
- Find the most similar option to the given input string based on embeddings.
48
-
49
- Args:
50
- model: The model used for generating embeddings.
51
- input_string (str): The input string to find the most similar option for.
52
- option_ids (list): The list of option IDs.
53
- option_embeddings (np.ndarray): The embeddings of the options.
54
-
55
- Returns:
56
- The ID of the most similar option.
57
- """
58
- try:
59
- input_embedding = self.embed_model.get_embeddings(
60
- [input_string], output_dimensionality=self.embedding_dimension
61
- )[0].values
62
- similarities = np.dot(option_embeddings, input_embedding)
63
- idx = np.argmax(similarities)
64
- return option_ids[idx]
65
- except Exception as e:
66
- logger.error(f"Embeddings error: {e}")
67
- return None
68
-
69
- def load_embeddings(self):
70
- """
71
- Load embeddings for container types, ports, and terminals.
72
-
73
- Returns:
74
- None
75
- """
76
- for data_field in [
77
- "container_types",
78
- "ports",
79
- "terminals",
80
- "depots",
81
- "item_codes_label",
82
- ]:
83
- self.embeddings_dict[data_field] = load_embed_by_data_field(
84
- self.bucket,
85
- f"{self.embedding_folder}/{data_field}/output",
86
- self.embedding_dimension,
87
- )
88
-
89
- async def update_embeddings(self):
90
- """
91
- Update the embeddings dictionary.
92
-
93
- Returns:
94
- dict: The updated embeddings dictionary with the following keys:
95
- - "container_types": A tuple containing the container types and their embeddings.
96
- - "ports": A tuple containing the ports and their embeddings.
97
- - "terminals": A tuple containing the terminal IDs and their embeddings.
98
- """
99
- # Update embeddings dict here.
100
- # Ensure this method is async if you're calling async operations.
101
- set_tms_service_token()
102
- (
103
- container_types,
104
- container_type_embeddings,
105
- ) = self.setup_container_type_embeddings(
106
- *self.embeddings_dict.get("container_types", ([], []))
107
- )
108
-
109
- ports, port_embeddings = self.setup_ports_embeddings(
110
- *self.embeddings_dict.get("ports", ([], []))
111
- )
112
-
113
- # Setup terminal embeddings
114
- # Since retrieving terminal attributes requires calling TMS' api to extract terminals by each port,
115
- # we only do it for new ports.
116
- prev_port_ids, _ = self.embeddings_dict.get("ports", ([], []))
117
- added_port_ids = [port for port in ports if port not in prev_port_ids]
118
- if added_port_ids:
119
- terminal_ids, terminal_embeddings = self.setup_terminal_embeddings(
120
- added_port_ids
121
- )
122
- else:
123
- terminal_ids, terminal_embeddings = self.embeddings_dict["terminals"]
124
-
125
- depot_names, depot_embeddings = self.setup_depot_embeddings(
126
- *self.embeddings_dict.get("depots", ([], []))
127
- )
128
-
129
- item_code_names, item_code_embeddings = self.setup_item_code_embeddings(
130
- *self.embeddings_dict.get("item_codes_label", ([], []))
131
- )
132
-
133
- self.embeddings_dict = {
134
- "container_types": (container_types, container_type_embeddings),
135
- "ports": (ports, port_embeddings),
136
- "terminals": (terminal_ids, terminal_embeddings),
137
- "depots": (depot_names, depot_embeddings),
138
- "item_codes_label": (item_code_names, item_code_embeddings),
139
- }
140
- return self.embeddings_dict
141
-
142
- def batch_embed(self, option_strings: list[dict], suffix: str):
143
- """
144
- Compute embeddings for a batch of option strings and uploads them to a cloud storage bucket.
145
-
146
- Args:
147
- option_strings (list): A list of option strings to compute embeddings for.
148
- suffix (str): A suffix to be used in the storage path for the embeddings:
149
- input & output will be stored under "{bucket}/{parent_folder}/{suffix}/"
150
-
151
- Returns:
152
- tuple: A tuple containing the option IDs and embeddings.
153
- """
154
- now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
155
- input_path = f"{self.embedding_folder}/{suffix}/input/{now}.jsonl"
156
- blob = self.bucket.blob(input_path)
157
-
158
- # Convert each dictionary to a JSON string and join them with newlines
159
- option_strings = [
160
- {**option, "task_type": "SEMANTIC_SIMILARITY", "output_dimensionality": 256}
161
- for option in option_strings
162
- ]
163
- jsonl_string = "\n".join(json.dumps(d) for d in option_strings)
164
-
165
- # Convert the combined string to bytes
166
- jsonl_bytes = jsonl_string.encode("utf-8")
167
-
168
- # Upload the bytes to the blob
169
- blob.upload_from_string(jsonl_bytes, content_type="text/plain")
170
-
171
- # Compute embeddings for the options
172
- embedding_path = f"{self.embedding_folder}/{suffix}/output"
173
- assert len(option_strings) <= 30000 # Limit for batch embedding
174
- batch_resp = self.embed_model.batch_predict(
175
- dataset=f"gs://{self.bucket.name}/{input_path}", # noqa
176
- destination_uri_prefix=f"gs://{self.bucket.name}/{embedding_path}", # noqa
177
- )
178
-
179
- if batch_resp.state.name != "JOB_STATE_SUCCEEDED":
180
- logger.warning(
181
- f"Batch prediction job failed with state {batch_resp.state.name}"
182
- )
183
- else:
184
- logger.info(f"Embeddings for {suffix} computed successfully.")
185
-
186
- option_ids, option_embeddings = load_embed_by_data_field(
187
- self.bucket, embedding_path, self.embedding_dimension
188
- )
189
- return option_ids, option_embeddings
190
-
191
- def setup_container_type_embeddings(
192
- self, computed_container_type_ids, computed_container_type_embeddings
193
- ):
194
- """
195
- Set up container type embeddings.
196
-
197
- Args:
198
- computed_container_type_ids (list): The list of already computed container type IDs.
199
- computed_container_type_embeddings (list): The list of already computed container type embeddings.
200
-
201
- Returns:
202
- tuple: A tuple containing the updated container type IDs and embeddings.
203
- """
204
- url = (
205
- f"https://tms.forto.{tms_domain}/api/transport-units/api/types/list" # noqa
206
- )
207
- resp = call_tms(requests.get, url)
208
- container_types = resp.json()
209
-
210
- container_attribute_strings = [
211
- dict(
212
- title=container_type["code"],
213
- content=" | ".join(
214
- [container_type["code"]]
215
- + [
216
- f"{v}"
217
- for k, v in container_type["containerAttributes"].items()
218
- if k in ["isoSizeType", "isoTypeGroup", "containerCategory"]
219
- ]
220
- + [container_type.get(k, "") for k in ["displayName", "notes"]]
221
- ),
222
- )
223
- for container_type in container_types
224
- if container_type["type"] == "Container"
225
- and container_type["code"] not in computed_container_type_ids
226
- and container_type.get("containerAttributes") is not None
227
- ]
228
- if not container_attribute_strings:
229
- logger.info("No new container types found.")
230
- return computed_container_type_ids, computed_container_type_embeddings
231
-
232
- logger.info("Computing embeddings for container types...")
233
- container_type_ids, container_type_embeddings = self.batch_embed(
234
- container_attribute_strings, "container_types"
235
- )
236
- return container_type_ids, container_type_embeddings
237
-
238
- def setup_ports_embeddings(self, computed_port_ids, computed_port_embeddings):
239
- """
240
- Set up port embeddings.
241
-
242
- Steps:
243
- - Retrieve active ports from the TMS API
244
- - Compute embeddings for new tradelane-enabled ports
245
- - Return ALL port IDs and embeddings.
246
-
247
- Args:
248
- computed_port_ids (list): The list of previously computed port IDs.
249
- computed_port_embeddings (list): The list of previously computed port embeddings.
250
-
251
- Returns:
252
- tuple: A tuple containing ALL port IDs and embeddings.
253
- """
254
- url = f"https://tms.forto.{tms_domain}/api/transport-network/api/ports?pageSize=1000000&status=active" # noqa
255
- resp = call_tms(requests.get, url)
256
- resp_json = resp.json()
257
- if len(resp_json["data"]) != resp_json["_paging"]["totalRecords"]:
258
- logger.error("Not all ports were returned.")
259
-
260
- new_sea_ports = [
261
- port
262
- for port in resp_json["data"]
263
- if "sea" in port["modes"] and port["id"] not in computed_port_ids
264
- ]
265
- if not new_sea_ports:
266
- logger.info("No new ports found.")
267
- return computed_port_ids, computed_port_embeddings
268
-
269
- port_attribute_strings = [
270
- dict(
271
- title=port["id"],
272
- content=" ".join(
273
- [
274
- "port for shipping",
275
- add_text_without_space(
276
- port["name"]
277
- ), # for cases like QUINHON - Quinhon
278
- port["id"],
279
- ]
280
- ),
281
- )
282
- for port in new_sea_ports
283
- ]
284
-
285
- logger.info("Computing embeddings for ports.")
286
- port_ids, port_embeddings = self.batch_embed(port_attribute_strings, "ports")
287
- return port_ids, port_embeddings
288
-
289
- def setup_depot_embeddings(self, computed_depot_names, computed_depot_embeddings):
290
- """
291
- Set up depot embeddings.
292
-
293
- Steps:
294
- - Retrieve active depot from the TMS API
295
- - Compute embeddings for new tdepot
296
- - Return ALL depot names and embeddings.
297
-
298
- Args:
299
- computed_depot_names (list): The list of previously computed depot names.
300
- computed_depot_embeddings (list): The list of previously computed depot embeddings.
301
-
302
- Returns:
303
- tuple: A tuple containing ALL depot names and embeddings.
304
- """
305
- url = f"https://tms.forto.{tms_domain}/api/transport-network/api/depots?pageSize=1000000" # noqa
306
- resp = call_tms(requests.get, url)
307
- resp_json = resp.json()
308
-
309
- new_depots = [
310
- depot
311
- for depot in resp_json["data"]
312
- if depot["name"] not in computed_depot_names
313
- ]
314
- if not new_depots:
315
- logger.info("No new depots found.")
316
- return computed_depot_names, computed_depot_embeddings
317
-
318
- depot_attribute_strings = [
319
- dict(
320
- title=depot["name"],
321
- content=" | ".join(
322
- [
323
- "depot",
324
- "name - " + depot["name"],
325
- "address - " + depot["address"]["fullAddress"],
326
- ]
327
- ),
328
- )
329
- for depot in resp_json["data"]
330
- ]
331
-
332
- logger.info("Computing embeddings for depots.")
333
- depot_names, depot_embeddings = self.batch_embed(
334
- depot_attribute_strings, "depots"
335
- )
336
- return depot_names, depot_embeddings
337
-
338
- def setup_terminal_embeddings(self, added_port_ids):
339
- """
340
- Set up terminal embeddings for `added_port_ids`, using `model`, uploaded to `bucket`.
341
-
342
- Args:
343
- added_port_ids (list): A list of added port IDs.
344
-
345
- Returns:
346
- tuple: A tuple containing the ALL terminal IDs and terminal embeddings.
347
- Not just for the added port IDs.
348
- """
349
- terminal_attibute_strings = [
350
- setup_terminal_attributes(port_id) for port_id in added_port_ids
351
- ]
352
- terminal_attibute_strings = sum(terminal_attibute_strings, [])
353
- if not terminal_attibute_strings:
354
- logger.info("No new terminals found.")
355
- return [], np.array([])
356
-
357
- terminal_ids, terminal_embeddings = self.batch_embed(
358
- terminal_attibute_strings, "terminals"
359
- )
360
- return terminal_ids, terminal_embeddings
361
-
362
- def setup_item_code_embeddings(
363
- self, computed_item_code_names, computed_item_code_embeddings
364
- ):
365
- """
366
- Set up item_code embeddings.
367
-
368
- Steps:
369
- - Retrieve active item_code from the TMS API
370
- - Compute embeddings for new titem_code
371
- - Return ALL item_code names and embeddings.
372
-
373
- Args:
374
- computed_item_code_names (list): The list of previously computed item_code names.
375
- computed_item_code_embeddings (list): The list of previously computed item_code embeddings.
376
-
377
- Returns:
378
- tuple: A tuple containing ALL item_code names and embeddings.
379
- """
380
- url = f"https://tms.forto.{tms_domain}/api/catalog/item-codes?transportTypes=fcl&pageSize=1000000" # noqa
381
- resp = call_tms(requests.get, url)
382
- resp_json = resp.json()
383
-
384
- new_item_codes = [
385
- item_code
386
- for item_code in resp_json["results"]
387
- if item_code["id"] not in computed_item_code_names
388
- ]
389
- if not new_item_codes:
390
- logger.info("No new item_codes found.")
391
- return computed_item_code_names, computed_item_code_embeddings
392
-
393
- item_code_attribute_strings = [
394
- dict(
395
- title=item_code["id"],
396
- content=" | ".join(
397
- [
398
- item_code["id"],
399
- item_code["label"],
400
- ]
401
- ),
402
- )
403
- for item_code in resp_json["results"]
404
- ]
405
-
406
- logger.info("Computing embeddings for item_codes.")
407
- item_code_names, item_code_embeddings = self.batch_embed(
408
- item_code_attribute_strings, "item_codes_label"
409
- )
410
- return item_code_names, item_code_embeddings
411
-
412
-
413
- def setup_embed_model():
414
- """
415
- Set up and return a text embedding model.
416
-
417
- Returns:
418
- TextEmbeddingModel: The initialized text embedding model.
419
- """
420
- model = TextEmbeddingModel.from_pretrained("text-multilingual-embedding-002")
421
- return model
422
-
423
-
424
20
  def convert_container_number(container_number):
425
21
  """
426
22
  Convert a container number to ISO standard.
@@ -488,16 +84,16 @@ def clean_shipment_id(shipment_id):
488
84
  """
489
85
  if not shipment_id:
490
86
  return
491
- # '#S123456@-1' -> 'S123456'
492
- # Find the pattern of a shipment ID that starts with 'S' followed by 5 to 7 digits
493
- match = re.findall(r"S\d{5,7}", shipment_id)
87
+ # '#S1234565@-1' -> 'S1234565'
88
+ # Find the pattern of a shipment ID that starts with 'S' followed by 7 to 8 digits
89
+ match = re.findall(r"S\d{6,8}", shipment_id)
494
90
  stripped_value = match[0] if match else None
495
91
 
496
92
  if not stripped_value:
497
93
  return None
498
94
 
499
95
  # Check if length is valid (should be either 7 or 8)
500
- if len(stripped_value) not in (6, 7, 8):
96
+ if len(stripped_value) not in (7, 8, 9):
501
97
  return None
502
98
 
503
99
  return stripped_value
@@ -538,9 +134,12 @@ def extract_number(data_field_value):
538
134
  formatted_value: string
539
135
 
540
136
  """
137
+ # Remove container size pattern like 20FT, 40HC, etc from 1 x 40HC
138
+ value = remove_unwanted_patterns(data_field_value)
139
+
541
140
  formatted_value = ""
542
- for c in data_field_value:
543
- if c.isnumeric() or c in [",", "."]:
141
+ for c in value:
142
+ if c.isnumeric() or c in [",", ".", "-"]:
544
143
  formatted_value += c
545
144
 
546
145
  # First and last characters should not be [",", "."]
@@ -570,106 +169,6 @@ def extract_string(data_field_value):
570
169
  return formatted_value if formatted_value not in ["''", ""] else None
571
170
 
572
171
 
573
- def extract_google_embed_resp(prediction_string, embedding_dimension):
574
- """
575
- Extract relevant information from the Google Embed API response.
576
-
577
- Args:
578
- prediction_string (str): The prediction string returned by the Google Embed API.
579
-
580
- Returns:
581
- dict: A dictionary containing the extracted information.
582
- - _id (str): The title of the instance.
583
- - attr_text (str): The content of the instance.
584
- - embedding (list): The embeddings values from the predictions.
585
-
586
- """
587
- res = json.loads(prediction_string)
588
- return dict(
589
- _id=res["instance"]["title"],
590
- attr_text=res["instance"]["content"],
591
- embedding=res["predictions"][0]["embeddings"]["values"][:embedding_dimension],
592
- )
593
-
594
-
595
- def load_embed_by_data_field(bucket, embedding_path, embedding_dimension):
596
- """
597
- Load embeddings by data field from the specified bucket and embedding path.
598
-
599
- Args:
600
- bucket (Bucket): The bucket object representing the storage bucket.
601
- embedding_path (str): The path to the embeddings in the bucket (different by data_field).
602
-
603
- Returns:
604
- tuple: A tuple containing the option IDs and option embeddings.
605
- - option_ids (list): A list of option IDs.
606
- - option_embeddings (ndarray): An array of option embeddings.
607
- """
608
- # Retrieve the embeddings from the output files
609
- blobs = bucket.list_blobs(prefix=embedding_path)
610
- all_blob_data = []
611
- for blob in blobs:
612
- blob_data = blob.download_as_bytes().decode("utf-8").splitlines()
613
- embeddings = [
614
- extract_google_embed_resp(data, embedding_dimension) for data in blob_data
615
- ]
616
- all_blob_data.extend(embeddings)
617
- option_ids = [embed["_id"] for embed in all_blob_data]
618
- option_embeddings = np.stack([embed["embedding"] for embed in all_blob_data])
619
- return option_ids, option_embeddings
620
-
621
-
622
- def setup_terminal_attributes(port_id: str):
623
- """
624
- Retrieve and format the attributes of active terminals at a given port.
625
-
626
- Args:
627
- port_id (str): The ID of the port.
628
-
629
- Returns:
630
- list: A list of dictionaries containing the formatted attributes of active terminals.
631
- Each dictionary has the following keys:
632
- - title: The terminal's short code.
633
- - content: A string representation of the terminal's attributes, including its name,
634
- searchable name, and full address.
635
- """
636
- url = f"https://gateway.forto.{tms_domain}/api/transport-network/api/ports/{port_id}/terminals/list" # noqa
637
- resp = call_tms(requests.get, url)
638
- terminals = resp.json()
639
- if len(terminals) == 0:
640
- return []
641
- active_terminals = [term for term in terminals if term["isActive"]]
642
- if len(active_terminals) == 0:
643
- logger.warning(f"No active terminals found at port {port_id}.")
644
- return []
645
-
646
- terminal_attibute_strings = [
647
- dict(
648
- title=term["name"],
649
- content=" | ".join(
650
- [
651
- "shipping terminal",
652
- "code - " + term["terminalShortCode"],
653
- "name - " + modify_terminal_name(term["searchableName"]),
654
- "country - " + term["address"]["country"],
655
- ]
656
- ),
657
- )
658
- for term in active_terminals
659
- ]
660
- return terminal_attibute_strings
661
-
662
-
663
- def modify_terminal_name(text):
664
- # Find the first occurrence of a word starting with 'K' followed by a number
665
- # and replace it with 'KAAI' - meaning Quay in Dutch
666
- match = re.search(r"K(\d+)", text)
667
- if match:
668
- # Append "KAAI" followed by the number if a match is found
669
- text += f" KAAI {match.group(1)}"
670
- return text
671
-
672
-
673
172
  def remove_none_values(d):
674
173
  if isinstance(d, dict):
675
174
  # Create a new dictionary to store non-None values
@@ -731,25 +230,6 @@ def convert_invoice_type(data_field_value, params):
731
230
  return None
732
231
 
733
232
 
734
- def validate_reverse_charge_value(reverse_charge_sentence_value):
735
- """
736
- Validates the reverseChargeSentence value before assigning to line items.
737
-
738
- Args:
739
- reverse_charge_sentence_value (bool): The formatted value of reverseChargeSentence (True or False).
740
-
741
- Returns:
742
- bool: The validated reverseChargeSentence value.
743
- """
744
- if isinstance(reverse_charge_sentence_value, bool):
745
- return reverse_charge_sentence_value
746
- else:
747
- logger.warning(
748
- f"Invalid reverseChargeSentence value: {reverse_charge_sentence_value}. Defaulting to False."
749
- )
750
- return False
751
-
752
-
753
233
  # Function to create KVP dictionary using apply method
754
234
  def create_kvp_dictionary(df_raw: pd.DataFrame):
755
235
  """Create a key-value pair dictionary from the given DataFrame.
@@ -842,6 +322,14 @@ def remove_unwanted_patterns(lineitem: str):
842
322
  # Remove "HIGH CUBE"
843
323
  lineitem = lineitem.replace("HIGH CUBE", "")
844
324
 
325
+ # Remove container size e.g., 20FT, 40HC, etc.
326
+ pattern = [
327
+ f"{s}{t}"
328
+ for s in ("20|22|40|45".split("|"))
329
+ for t in ("FT|HC|DC|HD|GP|OT|RF|FR|TK|DV".split("|"))
330
+ ]
331
+ lineitem = re.sub(r"|".join(pattern), "", lineitem, flags=re.IGNORECASE).strip()
332
+
845
333
  return lineitem
846
334
 
847
335
 
@@ -872,62 +360,92 @@ def clean_item_description(lineitem: str, remove_numbers: bool = True):
872
360
  # Remove the currency codes
873
361
  lineitem = re.sub(currency_codes_pattern, "", lineitem, flags=re.IGNORECASE)
874
362
 
363
+ # remove other patterns
364
+ lineitem = remove_unwanted_patterns(lineitem)
365
+
875
366
  # Remove numbers from the line item
876
367
  if (
877
368
  remove_numbers
878
369
  ): # Do not remove numbers for the reverse charge sentence as it contains Article number
879
370
  lineitem = re.sub(r"\d+", "", lineitem)
880
371
 
881
- # remove other patterns
882
- lineitem = remove_unwanted_patterns(lineitem)
883
-
884
372
  # remove special chars
885
373
  lineitem = re.sub(r"[^A-Za-z0-9\s]", " ", lineitem).strip()
886
374
 
375
+ # Remove x from lineitem like 10 x
376
+ lineitem = re.sub(r"\b[xX]\b", " ", lineitem).strip()
377
+
887
378
  return re.sub(r"\s{2,}", " ", lineitem).strip()
888
379
 
889
380
 
890
381
  async def format_label(
891
- entity_k, entity_value, embed_manager, document_type_code, params
382
+ entity_k,
383
+ entity_value,
384
+ document_type_code,
385
+ params,
386
+ mime_type,
387
+ container_map,
388
+ terminal_map,
389
+ depot_map,
892
390
  ):
893
391
  llm_client = params["LlmClient"]
894
392
  if isinstance(entity_value, dict): # if it's a nested entity
895
393
  format_tasks = [
896
- format_label(sub_k, sub_v, embed_manager, document_type_code, params)
394
+ format_label(
395
+ sub_k,
396
+ sub_v,
397
+ document_type_code,
398
+ params,
399
+ mime_type,
400
+ container_map,
401
+ terminal_map,
402
+ depot_map,
403
+ )
897
404
  for sub_k, sub_v in entity_value.items()
898
405
  ]
899
406
  return entity_k, {k: v for k, v in await asyncio.gather(*format_tasks)}
900
407
  if isinstance(entity_value, list):
901
408
  format_tasks = await asyncio.gather(
902
409
  *[
903
- format_label(entity_k, sub_v, embed_manager, document_type_code, params)
410
+ format_label(
411
+ entity_k,
412
+ sub_v,
413
+ document_type_code,
414
+ params,
415
+ mime_type,
416
+ container_map,
417
+ terminal_map,
418
+ depot_map,
419
+ )
904
420
  for sub_v in entity_value
905
421
  ]
906
422
  )
907
423
  return entity_k, [v for _, v in format_tasks]
424
+
425
+ if mime_type == "application/pdf":
426
+ if isinstance(entity_value, tuple):
427
+ page = entity_value[1]
428
+ entity_value = entity_value[0]
429
+ else:
430
+ page = -1
431
+
908
432
  entity_key = entity_k.lower()
909
- embeddings_dict = embed_manager.embeddings_dict
910
433
  formatted_value = None
911
434
 
912
435
  if entity_key.startswith("port"):
913
436
  formatted_value = await get_port_code_ai(
914
- entity_value, llm_client, embed_manager, *embeddings_dict["ports"]
437
+ entity_value, llm_client, doc_type=document_type_code
915
438
  )
439
+
916
440
  elif (entity_key == "containertype") or (entity_key == "containersize"):
917
- formatted_value = embed_manager._find_most_similar_option(
918
- "container type " + entity_value,
919
- *embeddings_dict["container_types"],
920
- )
441
+ formatted_value = container_map.get(entity_value)
442
+
921
443
  elif check_formatting_rule(entity_k, document_type_code, "terminal"):
922
- formatted_value = embed_manager._find_most_similar_option(
923
- "shipping terminal " + str(entity_value),
924
- *embeddings_dict["terminals"],
925
- )
444
+ formatted_value = terminal_map.get(entity_value)
445
+
926
446
  elif check_formatting_rule(entity_k, document_type_code, "depot"):
927
- formatted_value = embed_manager._find_most_similar_option(
928
- "depot " + str(entity_value),
929
- *embeddings_dict["depots"],
930
- )
447
+ formatted_value = depot_map.get(entity_value)
448
+
931
449
  elif entity_key.startswith(("eta", "etd", "duedate", "issuedate", "servicedate")):
932
450
  try:
933
451
  cleaned_data_field_value = clean_date_string(entity_value)
@@ -947,21 +465,26 @@ async def format_label(
947
465
  except ValueError as e:
948
466
  logger.info(f"ParserError: {e}")
949
467
 
950
- elif entity_key in ["invoicenumber", "creditnoteinvoicenumber"]:
468
+ elif (
469
+ entity_key in ["invoicenumber", "creditnoteinvoicenumber"]
470
+ and document_type_code == "bundeskasse"
471
+ ):
951
472
  formatted_value = clean_invoice_number(entity_value)
952
473
 
953
474
  elif entity_key in ("shipmentid", "partnerreference"):
954
- # Clean the shipment ID to match Forto's standard (starts with 'S' followed by 5 to 7 digits)
475
+ # Clean the shipment ID to match Forto's standard (starts with 'S' followed by 7 or 8 digits)
955
476
  formatted_value = clean_shipment_id(entity_value)
956
477
 
957
478
  elif entity_key == "containernumber":
958
479
  # Remove all non-alphanumeric characters like ' ', '-', etc.
959
480
  formatted_value = convert_container_number(entity_value)
960
481
 
961
- elif (
962
- document_type_code in ["finalMbL", "draftMbl"] and entity_key == "measurements"
482
+ elif any(
483
+ numeric_indicator in entity_key
484
+ for numeric_indicator in ["measurements", "weight"]
963
485
  ):
964
- formatted_value = decimal_convertor(extract_number(entity_value))
486
+ formatted_value = extract_number(entity_value)
487
+
965
488
  elif any(
966
489
  packaging_type in entity_key
967
490
  for packaging_type in ["packagingtype", "packagetype", "currency"]
@@ -977,11 +500,19 @@ async def format_label(
977
500
  elif "reversechargesentence" in entity_key:
978
501
  formatted_value = clean_item_description(entity_value, remove_numbers=False)
979
502
 
503
+ elif "quantity" in entity_key:
504
+ if document_type_code in ["partnerInvoice", "customsInvoice", "bundeskasse"]:
505
+ # For partner invoice, quantity can be mentioned as whole number
506
+ # Apply decimal convertor for 46,45 --> 46.45 but not for 1.000 --> 1000
507
+ formatted_value = decimal_convertor(
508
+ extract_number(entity_value), quantity=True
509
+ )
510
+ else:
511
+ formatted_value = extract_number(entity_value)
512
+
980
513
  elif any(
981
514
  numeric_indicator in entity_key
982
515
  for numeric_indicator in [
983
- "weight",
984
- "quantity",
985
516
  "value",
986
517
  "amount",
987
518
  "price",
@@ -999,21 +530,21 @@ async def format_label(
999
530
  "documentValue": entity_value,
1000
531
  "formattedValue": formatted_value,
1001
532
  }
533
+ if mime_type == "application/pdf":
534
+ result["page"] = page
535
+
1002
536
  return entity_k, result
1003
537
 
1004
538
 
1005
async def get_port_code_ai(port: str, llm_client, doc_type=None):
    """Resolve a free-text port name to a TMS port code.

    First asks the LLM for a candidate code, then validates/maps it through
    the TMS "ports" mapping service; returns ``None`` when no mapping exists.
    """
    llm_candidate = await get_port_code_llm(port, llm_client, doc_type=doc_type)

    mappings = await get_tms_mappings(port, "ports", llm_candidate)
    return mappings.get(port)
1014
545
 
1015
546
 
1016
- async def get_port_code_llm(port: str, llm_client):
547
+ async def get_port_code_llm(port: str, llm_client, doc_type=None):
1017
548
  if (
1018
549
  "postprocessing" in prompt_library.library.keys()
1019
550
  and "port_code" in prompt_library.library["postprocessing"].keys()
@@ -1040,7 +571,7 @@ async def get_port_code_llm(port: str, llm_client):
1040
571
  }
1041
572
 
1042
573
  response = await llm_client.get_unified_json_genai(
1043
- prompt, response_schema=response_schema, model="chatgpt"
574
+ prompt, response_schema=response_schema, model="chatgpt", doc_type=doc_type
1044
575
  )
1045
576
  try:
1046
577
  mapped_port = response["port"]
@@ -1050,7 +581,7 @@ async def get_port_code_llm(port: str, llm_client):
1050
581
  return None
1051
582
 
1052
583
 
1053
def decimal_convertor(value, quantity=False):
    """Normalise a numeric string from EU notation to English notation.

    Examples: ``4.123,45 -> 4123.45``, ``2.500 -> 2500``, ``65,45 -> 65.45``,
    ``8.500000 -> 8.50``. With ``quantity=True`` only the plain comma-decimal
    and over-long-decimal rules apply (thousand separators are left alone, so
    e.g. ``1.000`` is not expanded to ``1000``). Returns ``None`` for ``None``.
    """
    if value is None:
        return None

    # Strip all whitespace before pattern matching.
    value = value.strip().replace(" ", "")

    # Remember and drop a leading minus so the digit patterns below apply.
    is_negative = value.startswith("-")
    if is_negative:
        value = value[1:]

    def _fix_decimals(v):
        # Plain comma decimal: 65,45 -> 65.45
        if re.match(r"^\d+,\d{1,2}$", v):
            return v.replace(",", ".")
        # 3+ digits after the decimal point: keep only two (8.500000 -> 8.50).
        if re.match(r"^\d+\.\d{3,}$", v):
            return v[: v.index(".") + 3]
        return v

    if not quantity:
        # EU decimal with thousand separators: 4.123,45 -> 4123.45
        if re.match(r"^\d{1,3}(\.\d{3})*,\d{1,2}$", value):
            value = value.replace(".", "").replace(",", ".")
        # EU integer with thousand separators: 2.500 -> 2500
        elif re.match(r"^\d{1,3}(\.\d{3})+$", value):
            value = value.replace(".", "")
        # English decimal with thousand separators: 4,123.45 -> 4123.45
        elif re.match(r"^\d{1,3}(,\d{3})*\.\d{1,2}$", value):
            value = value.replace(",", "")
        # English integer with thousand separators: 2,500 -> 2500
        elif re.match(r"^\d{1,3}(,\d{3})+$", value):
            value = value.replace(",", "")

        value = _fix_decimals(value)
    else:
        # Quantities: only normalise the decimal part, never separators.
        value = _fix_decimals(value)

    return "-" + value if is_negative else value
1082
633
 
1083
634
 
1084
- async def format_all_entities(result, embed_manager, document_type_code, params):
635
async def collect_mapping_requests(entity_value, document_type_code):
    """Collect all unique container types, terminals, and depots from the entity value.

    Recursively walks the (possibly nested) extraction result and gathers the
    raw values that will need a TMS mapping lookup, so they can be fetched in
    one batch instead of per entity.

    Args:
        entity_value: Nested dict/list structure of extracted entities; leaves
            may be plain values or ``(value, page)`` tuples.
        document_type_code: Document type used by ``check_formatting_rule`` to
            decide which keys count as terminals/depots.

    Returns:
        Tuple of three sets: ``(container_types, terminals, depots)``.
    """
    container_types = set()
    terminals = set()
    depots = set()

    def _leaf_value(value):
        # Leaves may be (value, page) tuples; keep only the value part,
        # e.g. take "20DV" from ("20DV", 0).
        return value[0] if isinstance(value, tuple) else value

    def walk(key, value):
        key_lower = key.lower()

        # nested dict
        if isinstance(value, dict):
            for k, v in value.items():
                walk(k, v)

        # list of values
        elif isinstance(value, list):
            for item in value:
                walk(key, item)

        # leaf node: classify by key. (Fixed idiom: the original used
        # side-effecting conditional expressions `x.add(..) if .. else x.add(..)`.)
        elif key_lower in ("containertype", "containersize"):
            container_types.add(_leaf_value(value))
        elif check_formatting_rule(key, document_type_code, "terminal"):
            terminals.add(_leaf_value(value))
        elif check_formatting_rule(key, document_type_code, "depot"):
            depots.add(_leaf_value(value))

    walk("root", entity_value)

    return container_types, terminals, depots
674
+
675
+
676
async def format_all_labels(entity_data, document_type_code, params, mime_type):
    """Format all labels in the entity data using cached mappings."""
    # Gather every container type / terminal / depot that will need a lookup.
    containers, terminals, depots = await collect_mapping_requests(
        entity_data, document_type_code
    )

    # Resolve all mappings in one batched call instead of one per entity.
    container_map, terminal_map, depot_map = await batch_fetch_all_mappings(
        containers, terminals, depots
    )

    # Recursively format every label, reusing the prefetched mappings.
    return await format_label(
        "root",
        entity_data,
        document_type_code,
        params,
        mime_type,
        container_map,
        terminal_map,
        depot_map,
    )
701
+
702
+
703
+ async def format_all_entities(result, document_type_code, params, mime_type):
1085
704
  """Format the entity values in the result dictionary."""
1086
705
  # Since we treat `customsInvoice` same as `partnerInvoice`
1087
706
  document_type_code = (
@@ -1096,15 +715,13 @@ async def format_all_entities(result, embed_manager, document_type_code, params)
1096
715
  return {}
1097
716
 
1098
717
  # Format all entities recursively
1099
- _, aggregated_data = await format_label(
1100
- None, result, embed_manager, document_type_code, params
718
+ _, aggregated_data = await format_all_labels(
719
+ result, document_type_code, params, mime_type
1101
720
  )
1102
721
 
1103
722
  # Process partner invoice on lineitem mapping and reverse charge sentence
1104
723
  if document_type_code in ["partnerInvoice", "bundeskasse"]:
1105
- process_partner_invoice(
1106
- params, aggregated_data, embed_manager, document_type_code
1107
- )
724
+ await process_partner_invoice(params, aggregated_data, document_type_code)
1108
725
 
1109
726
  logger.info("Data Extraction completed successfully")
1110
727
  return aggregated_data
@@ -1134,3 +751,48 @@ def remove_stop_words(lineitem: str):
1134
751
  .upper()
1135
752
  .strip()
1136
753
  )
754
+
755
+
756
def llm_prediction_to_tuples(llm_prediction, number_of_pages=-1, page_number=None):
    """Convert LLM prediction dictionary to tuples of (value, page_number).

    Two shapes are handled:
      * Single-page documents (``number_of_pages == 1``): every truthy leaf is
        paired with page 0 (or the propagated ``page_number``); falsy leaves
        become ``None``.
      * Multi-page documents: leaves arrive as ``{"value": ..., "page_number": ...}``
        dicts and are collapsed to ``(value, page)`` tuples, with ``-1`` when the
        page marker is not a valid integer. Dicts/lists are mutated in place.

    Args:
        llm_prediction: Nested dict/list/leaf structure from the LLM.
        number_of_pages: Total pages in the document; 1 triggers the simple path.
        page_number: Page to propagate on the single-page path.

    Returns:
        The converted structure (same object for the multi-page dict/list path).
    """
    # If only 1 page, simply pair each value with page number 0
    if number_of_pages == 1:
        effective_page = 0 if page_number is None else page_number
        if isinstance(llm_prediction, dict):
            return {
                k: llm_prediction_to_tuples(
                    v, number_of_pages, page_number=effective_page
                )
                for k, v in llm_prediction.items()
            }
        elif isinstance(llm_prediction, list):
            return [
                llm_prediction_to_tuples(v, number_of_pages, page_number=effective_page)
                for v in llm_prediction
            ]
        else:
            # Falsy leaves (None, "", 0) carry no usable prediction.
            return (llm_prediction, effective_page) if llm_prediction else None

    # logic for multi-page predictions
    if isinstance(llm_prediction, dict):
        if "page_number" in llm_prediction.keys() and "value" in llm_prediction.keys():
            if llm_prediction["value"]:
                try:
                    _page_number = int(llm_prediction["page_number"])
                except (TypeError, ValueError):
                    # Fixed: was a bare `except:` that also swallowed
                    # KeyboardInterrupt/SystemExit; int() only raises these two.
                    _page_number = -1
                return (llm_prediction["value"], _page_number)
            return None

        for key, value in llm_prediction.items():
            llm_prediction[key] = llm_prediction_to_tuples(
                value, number_of_pages, page_number
            )

    elif isinstance(llm_prediction, list):
        for i, item in enumerate(llm_prediction):
            llm_prediction[i] = llm_prediction_to_tuples(
                item, number_of_pages, page_number
            )

    return llm_prediction