data-science-document-ai 1.13.0__py3-none-any.whl → 1.56.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {data_science_document_ai-1.13.0.dist-info → data_science_document_ai-1.56.1.dist-info}/METADATA +7 -2
  2. data_science_document_ai-1.56.1.dist-info/RECORD +60 -0
  3. {data_science_document_ai-1.13.0.dist-info → data_science_document_ai-1.56.1.dist-info}/WHEEL +1 -1
  4. src/constants.py +42 -12
  5. src/constants_sandbox.py +2 -22
  6. src/docai.py +18 -7
  7. src/docai_processor_config.yaml +0 -64
  8. src/excel_processing.py +34 -15
  9. src/io.py +74 -6
  10. src/llm.py +12 -34
  11. src/pdf_processing.py +228 -78
  12. src/postprocessing/common.py +495 -618
  13. src/postprocessing/postprocess_partner_invoice.py +383 -27
  14. src/prompts/library/arrivalNotice/other/placeholders.json +70 -0
  15. src/prompts/library/arrivalNotice/other/prompt.txt +40 -0
  16. src/prompts/library/bookingConfirmation/evergreen/placeholders.json +17 -17
  17. src/prompts/library/bookingConfirmation/evergreen/prompt.txt +1 -0
  18. src/prompts/library/bookingConfirmation/hapag-lloyd/placeholders.json +18 -18
  19. src/prompts/library/bookingConfirmation/hapag-lloyd/prompt.txt +1 -1
  20. src/prompts/library/bookingConfirmation/maersk/placeholders.json +17 -17
  21. src/prompts/library/bookingConfirmation/maersk/prompt.txt +1 -1
  22. src/prompts/library/bookingConfirmation/msc/placeholders.json +17 -17
  23. src/prompts/library/bookingConfirmation/msc/prompt.txt +1 -1
  24. src/prompts/library/bookingConfirmation/oocl/placeholders.json +17 -17
  25. src/prompts/library/bookingConfirmation/oocl/prompt.txt +3 -1
  26. src/prompts/library/bookingConfirmation/other/placeholders.json +17 -17
  27. src/prompts/library/bookingConfirmation/other/prompt.txt +1 -1
  28. src/prompts/library/bookingConfirmation/yangming/placeholders.json +17 -17
  29. src/prompts/library/bookingConfirmation/yangming/prompt.txt +1 -1
  30. src/prompts/library/bundeskasse/other/placeholders.json +113 -0
  31. src/prompts/library/bundeskasse/other/prompt.txt +48 -0
  32. src/prompts/library/commercialInvoice/other/placeholders.json +125 -0
  33. src/prompts/library/commercialInvoice/other/prompt.txt +2 -1
  34. src/prompts/library/customsAssessment/other/placeholders.json +67 -16
  35. src/prompts/library/customsAssessment/other/prompt.txt +24 -37
  36. src/prompts/library/customsInvoice/other/placeholders.json +205 -0
  37. src/prompts/library/customsInvoice/other/prompt.txt +105 -0
  38. src/prompts/library/deliveryOrder/other/placeholders.json +79 -28
  39. src/prompts/library/deliveryOrder/other/prompt.txt +26 -40
  40. src/prompts/library/draftMbl/other/placeholders.json +33 -33
  41. src/prompts/library/draftMbl/other/prompt.txt +34 -44
  42. src/prompts/library/finalMbL/other/placeholders.json +34 -34
  43. src/prompts/library/finalMbL/other/prompt.txt +34 -44
  44. src/prompts/library/packingList/other/placeholders.json +98 -0
  45. src/prompts/library/packingList/other/prompt.txt +1 -1
  46. src/prompts/library/partnerInvoice/other/placeholders.json +165 -45
  47. src/prompts/library/partnerInvoice/other/prompt.txt +82 -44
  48. src/prompts/library/preprocessing/carrier/placeholders.json +0 -16
  49. src/prompts/library/shippingInstruction/other/placeholders.json +115 -0
  50. src/prompts/library/shippingInstruction/other/prompt.txt +28 -15
  51. src/setup.py +73 -63
  52. src/utils.py +207 -30
  53. data_science_document_ai-1.13.0.dist-info/RECORD +0 -55
  54. src/prompts/library/draftMbl/hapag-lloyd/prompt.txt +0 -44
  55. src/prompts/library/draftMbl/maersk/prompt.txt +0 -17
  56. src/prompts/library/finalMbL/hapag-lloyd/prompt.txt +0 -44
  57. src/prompts/library/finalMbL/maersk/prompt.txt +0 -17
src/postprocessing/common.py
@@ -1,414 +1,22 @@
 import asyncio
-import datetime
 import json
 import os
 import re
 from datetime import timezone
 
-import numpy as np
 import pandas as pd
-import requests
-from vertexai.preview.language_models import TextEmbeddingModel
+from nltk.corpus import stopwords
+from rapidfuzz import process
 
 from src.constants import formatting_rules
-from src.io import get_storage_client, logger
+from src.io import logger
+from src.postprocessing.postprocess_partner_invoice import process_partner_invoice
 from src.prompts.prompt_library import prompt_library
-from src.tms import call_tms, set_tms_service_token
+from src.utils import batch_fetch_all_mappings, get_tms_mappings
 
 tms_domain = os.environ["TMS_DOMAIN"]
 
 
-class EmbeddingsManager:  # noqa: D101
-    def __init__(self, params):  # noqa: D107
-        self.params = params
-        self.embeddings_dict = {}
-        self.embed_model = setup_embed_model()
-        self.bucket = self.get_bucket_storage()
-        self.embedding_folder = self.embed_model._model_id
-        self.embedding_dimension = 768  # TODO: to be reduced
-
-    def get_bucket_storage(self):
-        """
-        Retrieve the bucket storage object.
-
-        Returns:
-            The bucket storage object.
-        """
-        params = self.params
-        storage_client = get_storage_client(params)
-        bucket = storage_client.bucket(params["doc_ai_bucket_name"])
-        return bucket
-
-    def _find_most_similar_option(self, input_string, option_ids, option_embeddings):
-        """
-        Find the most similar option to the given input string based on embeddings.
-
-        Args:
-            model: The model used for generating embeddings.
-            input_string (str): The input string to find the most similar option for.
-            option_ids (list): The list of option IDs.
-            option_embeddings (np.ndarray): The embeddings of the options.
-
-        Returns:
-            The ID of the most similar option.
-        """
-        try:
-            input_embedding = self.embed_model.get_embeddings(
-                [input_string], output_dimensionality=self.embedding_dimension
-            )[0].values
-            similarities = np.dot(option_embeddings, input_embedding)
-            idx = np.argmax(similarities)
-            return option_ids[idx]
-        except Exception as e:
-            logger.error(f"Embeddings error: {e}")
-            return None
-
-    def load_embeddings(self):
-        """
-        Load embeddings for container types, ports, and terminals.
-
-        Returns:
-            None
-        """
-        for data_field in [
-            "container_types",
-            "ports",
-            "terminals",
-            "depots",
-            "item_codes",
-        ]:
-            self.embeddings_dict[data_field] = load_embed_by_data_field(
-                self.bucket,
-                f"{self.embedding_folder}/{data_field}/output",
-                self.embedding_dimension,
-            )
-
-    async def update_embeddings(self):
-        """
-        Update the embeddings dictionary.
-
-        Returns:
-            dict: The updated embeddings dictionary with the following keys:
-                - "container_types": A tuple containing the container types and their embeddings.
-                - "ports": A tuple containing the ports and their embeddings.
-                - "terminals": A tuple containing the terminal IDs and their embeddings.
-        """
-        # Update embeddings dict here.
-        # Ensure this method is async if you're calling async operations.
-        set_tms_service_token()
-        (
-            container_types,
-            container_type_embeddings,
-        ) = self.setup_container_type_embeddings(
-            *self.embeddings_dict.get("container_types", ([], []))
-        )
-
-        ports, port_embeddings = self.setup_ports_embeddings(
-            *self.embeddings_dict.get("ports", ([], []))
-        )
-
-        # Setup terminal embeddings
-        # Since retrieving terminal attributes requires calling TMS' api to extract terminals by each port,
-        # we only do it for new ports.
-        prev_port_ids, _ = self.embeddings_dict.get("ports", ([], []))
-        added_port_ids = [port for port in ports if port not in prev_port_ids]
-        if added_port_ids:
-            terminal_ids, terminal_embeddings = self.setup_terminal_embeddings(
-                added_port_ids
-            )
-        else:
-            terminal_ids, terminal_embeddings = self.embeddings_dict["terminals"]
-
-        depot_names, depot_embeddings = self.setup_depot_embeddings(
-            *self.embeddings_dict.get("depots", ([], []))
-        )
-
-        item_code_names, item_code_embeddings = self.setup_item_code_embeddings(
-            *self.embeddings_dict.get("item_codes", ([], []))
-        )
-
-        self.embeddings_dict = {
-            "container_types": (container_types, container_type_embeddings),
-            "ports": (ports, port_embeddings),
-            "terminals": (terminal_ids, terminal_embeddings),
-            "depots": (depot_names, depot_embeddings),
-            "item_codes": (item_code_names, item_code_embeddings),
-        }
-        return self.embeddings_dict
-
-    def batch_embed(self, option_strings: list[dict], suffix: str):
-        """
-        Compute embeddings for a batch of option strings and uploads them to a cloud storage bucket.
-
-        Args:
-            option_strings (list): A list of option strings to compute embeddings for.
-            suffix (str): A suffix to be used in the storage path for the embeddings:
-                input & output will be stored under "{bucket}/{parent_folder}/{suffix}/"
-
-        Returns:
-            tuple: A tuple containing the option IDs and embeddings.
-        """
-        now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
-        input_path = f"{self.embedding_folder}/{suffix}/input/{now}.jsonl"
-        blob = self.bucket.blob(input_path)
-
-        # Convert each dictionary to a JSON string and join them with newlines
-        option_strings = [
-            {**option, "task_type": "SEMANTIC_SIMILARITY", "output_dimensionality": 256}
-            for option in option_strings
-        ]
-        jsonl_string = "\n".join(json.dumps(d) for d in option_strings)
-
-        # Convert the combined string to bytes
-        jsonl_bytes = jsonl_string.encode("utf-8")
-
-        # Upload the bytes to the blob
-        blob.upload_from_string(jsonl_bytes, content_type="text/plain")
-
-        # Compute embeddings for the options
-        embedding_path = f"{self.embedding_folder}/{suffix}/output"
-        assert len(option_strings) <= 30000  # Limit for batch embedding
-        batch_resp = self.embed_model.batch_predict(
-            dataset=f"gs://{self.bucket.name}/{input_path}",  # noqa
-            destination_uri_prefix=f"gs://{self.bucket.name}/{embedding_path}",  # noqa
-        )
-
-        if batch_resp.state.name != "JOB_STATE_SUCCEEDED":
-            logger.warning(
-                f"Batch prediction job failed with state {batch_resp.state.name}"
-            )
-        else:
-            logger.info(f"Embeddings for {suffix} computed successfully.")
-
-        option_ids, option_embeddings = load_embed_by_data_field(
-            self.bucket, embedding_path, self.embedding_dimension
-        )
-        return option_ids, option_embeddings
-
-    def setup_container_type_embeddings(
-        self, computed_container_type_ids, computed_container_type_embeddings
-    ):
-        """
-        Set up container type embeddings.
-
-        Args:
-            computed_container_type_ids (list): The list of already computed container type IDs.
-            computed_container_type_embeddings (list): The list of already computed container type embeddings.
-
-        Returns:
-            tuple: A tuple containing the updated container type IDs and embeddings.
-        """
-        url = (
-            f"https://tms.forto.{tms_domain}/api/transport-units/api/types/list"  # noqa
-        )
-        resp = call_tms(requests.get, url)
-        container_types = resp.json()
-
-        container_attribute_strings = [
-            dict(
-                title=container_type["code"],
-                content=" | ".join(
-                    ["container type"]
-                    + [
-                        f"{k}: {v}"
-                        for k, v in container_type["containerAttributes"].items()
-                        if k in ["isoSizeType", "isoTypeGroup"]
-                    ]
-                    + [container_type[k] for k in ["displayName", "notes"]]
-                ),
-            )
-            for container_type in container_types
-            if container_type["isActive"]
-            and container_type["code"] not in computed_container_type_ids
-        ]
-        if not container_attribute_strings:
-            logger.info("No new container types found.")
-            return computed_container_type_ids, computed_container_type_embeddings
-
-        logger.info("Computing embeddings for container types...")
-        container_type_ids, container_type_embeddings = self.batch_embed(
-            container_attribute_strings, "container_types"
-        )
-        return container_type_ids, container_type_embeddings
-
-    def setup_ports_embeddings(self, computed_port_ids, computed_port_embeddings):
-        """
-        Set up port embeddings.
-
-        Steps:
-        - Retrieve active ports from the TMS API
-        - Compute embeddings for new tradelane-enabled ports
-        - Return ALL port IDs and embeddings.
-
-        Args:
-            computed_port_ids (list): The list of previously computed port IDs.
-            computed_port_embeddings (list): The list of previously computed port embeddings.
-
-        Returns:
-            tuple: A tuple containing ALL port IDs and embeddings.
-        """
-        url = f"https://tms.forto.{tms_domain}/api/transport-network/api/ports?pageSize=1000000&status=active"  # noqa
-        resp = call_tms(requests.get, url)
-        resp_json = resp.json()
-        if len(resp_json["data"]) != resp_json["_paging"]["totalRecords"]:
-            logger.error("Not all ports were returned.")
-
-        new_sea_ports = [
-            port
-            for port in resp_json["data"]
-            if "sea" in port["modes"] and port["id"] not in computed_port_ids
-        ]
-        if not new_sea_ports:
-            logger.info("No new ports found.")
-            return computed_port_ids, computed_port_embeddings
-
-        port_attribute_strings = [
-            dict(
-                title=port["id"],
-                content=" ".join(
-                    [
-                        "port for shipping",
-                        add_text_without_space(
-                            port["name"]
-                        ),  # for cases like QUINHON - Quinhon
-                        port["id"],
-                    ]
-                ),
-            )
-            for port in new_sea_ports
-        ]
-
-        logger.info("Computing embeddings for ports.")
-        port_ids, port_embeddings = self.batch_embed(port_attribute_strings, "ports")
-        return port_ids, port_embeddings
-
-    def setup_depot_embeddings(self, computed_depot_names, computed_depot_embeddings):
-        """
-        Set up depot embeddings.
-
-        Steps:
-        - Retrieve active depot from the TMS API
-        - Compute embeddings for new tdepot
-        - Return ALL depot names and embeddings.
-
-        Args:
-            computed_depot_names (list): The list of previously computed depot names.
-            computed_depot_embeddings (list): The list of previously computed depot embeddings.
-
-        Returns:
-            tuple: A tuple containing ALL depot names and embeddings.
-        """
-        url = f"https://tms.forto.{tms_domain}/api/transport-network/api/depots?pageSize=1000000"  # noqa
-        resp = call_tms(requests.get, url)
-        resp_json = resp.json()
-
-        new_depots = [
-            depot
-            for depot in resp_json["data"]
-            if depot["name"] not in computed_depot_names
-        ]
-        if not new_depots:
-            logger.info("No new depots found.")
-            return computed_depot_names, computed_depot_embeddings
-
-        depot_attribute_strings = [
-            dict(
-                title=depot["name"],
-                content=" | ".join(
-                    [
-                        "depot",
-                        "name - " + depot["name"],
-                        "address - " + depot["address"]["fullAddress"],
-                    ]
-                ),
-            )
-            for depot in resp_json["data"]
-        ]
-
-        logger.info("Computing embeddings for depots.")
-        depot_names, depot_embeddings = self.batch_embed(
-            depot_attribute_strings, "depots"
-        )
-        return depot_names, depot_embeddings
-
-    def setup_terminal_embeddings(self, added_port_ids):
-        """
-        Set up terminal embeddings for `added_port_ids`, using `model`, uploaded to `bucket`.
-
-        Args:
-            added_port_ids (list): A list of added port IDs.
-
-        Returns:
-            tuple: A tuple containing the ALL terminal IDs and terminal embeddings.
-                Not just for the added port IDs.
-        """
-        terminal_attibute_strings = [
-            setup_terminal_attributes(port_id) for port_id in added_port_ids
-        ]
-        terminal_attibute_strings = sum(terminal_attibute_strings, [])
-        if not terminal_attibute_strings:
-            logger.info("No new terminals found.")
-            return [], np.array([])
-
-        terminal_ids, terminal_embeddings = self.batch_embed(
-            terminal_attibute_strings, "terminals"
-        )
-        return terminal_ids, terminal_embeddings
-
-    def setup_item_code_embeddings(
-        self, computed_item_code_names, computed_item_code_embeddings
-    ):
-        """
-        Set up item_code embeddings.
-
-        Steps:
-        - Retrieve active item_code from the TMS API
-        - Compute embeddings for new titem_code
-        - Return ALL item_code names and embeddings.
-
-        Args:
-            computed_item_code_names (list): The list of previously computed item_code names.
-            computed_item_code_embeddings (list): The list of previously computed item_code embeddings.
-
-        Returns:
-            tuple: A tuple containing ALL item_code names and embeddings.
-        """
-        url = f"https://tms.forto.{tms_domain}/api/catalog/item-codes?pageSize=1000000"  # noqa
-        resp = call_tms(requests.get, url)
-        resp_json = resp.json()
-
-        new_item_codes = [
-            item_code
-            for item_code in resp_json["results"]
-            if item_code["id"] not in computed_item_code_names
-        ]
-        if not new_item_codes:
-            logger.info("No new item_codes found.")
-            return computed_item_code_names, computed_item_code_embeddings
-
-        item_code_attribute_strings = [
-            dict(title=item_code["id"], content=item_code["description"])
-            for item_code in resp_json["results"]
-        ]
-
-        logger.info("Computing embeddings for item_codes.")
-        item_code_names, item_code_embeddings = self.batch_embed(
-            item_code_attribute_strings, "item_codes"
-        )
-        return item_code_names, item_code_embeddings
-
-
-def setup_embed_model():
-    """
-    Set up and return a text embedding model.
-
-    Returns:
-        TextEmbeddingModel: The initialized text embedding model.
-    """
-    model = TextEmbeddingModel.from_pretrained("text-multilingual-embedding-002")
-    return model
-
-
 def convert_container_number(container_number):
     """
     Convert a container number to ISO standard.
@@ -423,7 +31,13 @@ def convert_container_number(container_number):
         return
     # 'FFAU2932130--FX34650895-40HC' -> 'FFAU2932130'
     match = re.findall(r"[A-Z]{4}\d{7}", container_number)
-    stripped_value = match[0] if match else None
+    stripped_value = match if match else None
+
+    # LLMs do extract all the container numbers as a list of strings
+    if stripped_value and len(stripped_value) > 1:
+        return stripped_value
+    else:
+        stripped_value = stripped_value[0] if stripped_value else None
 
     if not stripped_value:
         stripped_value = "".join(
@@ -440,6 +54,51 @@ def convert_container_number(container_number):
     return formatted_value
 
 
+def clean_invoice_number(invoice_number):
+    """Post process invoice number
+
+    Args:
+        invoice_number (str): The invoice number to be cleaned.
+
+    Returns:
+        str: The cleaned invoice number if it is valid, None otherwise.
+    """
+    if not invoice_number:
+        return
+
+    # Remove all non-alphanumeric characters
+    stripped_value = re.sub(r"[^\w]", "", invoice_number)
+
+    return stripped_value
+
+
+def clean_shipment_id(shipment_id):
+    """
+    Convert shipment_id to Forto standard.
+
+    Args:
+        shipment_id (str): The Shipment ID to be converted.
+
+    Returns:
+        str: The formatted shipment_id if it is valid, None otherwise.
+    """
+    if not shipment_id:
+        return
+    # '#S1234565@-1' -> 'S1234565'
+    # Find the pattern of a shipment ID that starts with 'S' followed by 7 to 8 digits
+    match = re.findall(r"S\d{6,8}", shipment_id)
+    stripped_value = match[0] if match else None
+
+    if not stripped_value:
+        return None
+
+    # Check if length is valid (should be either 7 or 8)
+    if len(stripped_value) not in (7, 8, 9):
+        return None
+
+    return stripped_value
+
+
 # Clean the date for date obj parse in tms formatting
 def clean_date_string(date_str):
     """Remove hours and timezone information from the date string."""
@@ -475,9 +134,12 @@ def extract_number(data_field_value):
     formatted_value: string
 
     """
+    # Remove container size pattern like 20FT, 40HC, etc from 1 x 40HC
+    value = remove_unwanted_patterns(data_field_value)
+
     formatted_value = ""
-    for c in data_field_value:
-        if c.isnumeric() or c in [",", "."]:
+    for c in value:
+        if c.isnumeric() or c in [",", ".", "-"]:
             formatted_value += c
 
     # First and last characters should not be [",", "."]
@@ -507,106 +169,6 @@ def extract_string(data_field_value):
     return formatted_value if formatted_value not in ["''", ""] else None
 
 
-def extract_google_embed_resp(prediction_string, embedding_dimension):
-    """
-    Extract relevant information from the Google Embed API response.
-
-    Args:
-        prediction_string (str): The prediction string returned by the Google Embed API.
-
-    Returns:
-        dict: A dictionary containing the extracted information.
-            - _id (str): The title of the instance.
-            - attr_text (str): The content of the instance.
-            - embedding (list): The embeddings values from the predictions.
-
-    """
-    res = json.loads(prediction_string)
-    return dict(
-        _id=res["instance"]["title"],
-        attr_text=res["instance"]["content"],
-        embedding=res["predictions"][0]["embeddings"]["values"][:embedding_dimension],
-    )
-
-
-def load_embed_by_data_field(bucket, embedding_path, embedding_dimension):
-    """
-    Load embeddings by data field from the specified bucket and embedding path.
-
-    Args:
-        bucket (Bucket): The bucket object representing the storage bucket.
-        embedding_path (str): The path to the embeddings in the bucket (different by data_field).
-
-    Returns:
-        tuple: A tuple containing the option IDs and option embeddings.
-            - option_ids (list): A list of option IDs.
-            - option_embeddings (ndarray): An array of option embeddings.
-    """
-    # Retrieve the embeddings from the output files
-    blobs = bucket.list_blobs(prefix=embedding_path)
-    all_blob_data = []
-    for blob in blobs:
-        blob_data = blob.download_as_bytes().decode("utf-8").splitlines()
-        embeddings = [
-            extract_google_embed_resp(data, embedding_dimension) for data in blob_data
-        ]
-        all_blob_data.extend(embeddings)
-    option_ids = [embed["_id"] for embed in all_blob_data]
-    option_embeddings = np.stack([embed["embedding"] for embed in all_blob_data])
-    return option_ids, option_embeddings
-
-
-def setup_terminal_attributes(port_id: str):
-    """
-    Retrieve and format the attributes of active terminals at a given port.
-
-    Args:
-        port_id (str): The ID of the port.
-
-    Returns:
-        list: A list of dictionaries containing the formatted attributes of active terminals.
-            Each dictionary has the following keys:
-            - title: The terminal's short code.
-            - content: A string representation of the terminal's attributes, including its name,
-              searchable name, and full address.
-    """
-    url = f"https://gateway.forto.{tms_domain}/api/transport-network/api/ports/{port_id}/terminals/list"  # noqa
-    resp = call_tms(requests.get, url)
-    terminals = resp.json()
-    if len(terminals) == 0:
-        return []
-    active_terminals = [term for term in terminals if term["isActive"]]
-    if len(active_terminals) == 0:
-        logger.warning(f"No active terminals found at port {port_id}.")
-        return []
-
-    terminal_attibute_strings = [
-        dict(
-            title=term["name"],
-            content=" | ".join(
-                [
-                    "shipping terminal",
-                    "code - " + term["terminalShortCode"],
-                    "name - " + modify_terminal_name(term["searchableName"]),
-                    "country - " + term["address"]["country"],
-                ]
-            ),
-        )
-        for term in active_terminals
-    ]
-    return terminal_attibute_strings
-
-
-def modify_terminal_name(text):
-    # Find the first occurrence of a word starting with 'K' followed by a number
-    # and replace it with 'KAAI' - meaning Quay in Dutch
-    match = re.search(r"K(\d+)", text)
-    if match:
-        # Append "KAAI" followed by the number if a match is found
-        text += f" KAAI {match.group(1)}"
-    return text
-
-
 def remove_none_values(d):
     if isinstance(d, dict):
         # Create a new dictionary to store non-None values
@@ -641,33 +203,137 @@ def check_formatting_rule(entity_key, document_type_code, rule):
     return False
 
 
-def convert_invoice_type(data_field_value):
-    keyword_classification = {
-        "invoice": "invoice",
-        "rechnung": "invoice",
-        "factura": "invoice",
-        "fattura": "invoice",
-        "faktura": "invoice",
-        "debit": "invoice",
-        "factuur": "invoice",
-        "credit": "creditNote",
-        "credito": "creditNote",
-        "crédito": "creditNote",
-        "creditnota": "creditNote",
-        "kreditnota": "creditNote",
-        "kredytowa": "creditNote",
-        "rechnungskorrektur": "creditNote",
-        "stornobeleg": "creditNote",
-    }
+def convert_invoice_type(data_field_value, params):
+    """
+    Converts a raw invoice type string to either invoice or creditNote using fuzzy matching.
 
-    # TODO: sort according to lenght of the keys
-    for keyword, classification in keyword_classification.items():
-        if keyword in data_field_value.lower():
-            return classification
+    Args:
+        data_field_value (str): The raw invoice type string from the data.
+        params (dict): A dictionary of parameters, including:
+            - "lookup_data": A dictionary containing lookup tables.
+            - "fuzzy_threshold_invoice_classification": The minimum fuzzy matching score.
+
+    Returns:
+        str or None: The standardized invoice type if a match is found, otherwise None.
+    """
+    lookup_data = params["lookup_data"]["invoice_classification"]
+    keywords = list(lookup_data.keys())
+
+    best_match = process.extractOne(
+        data_field_value.lower(),
+        keywords,
+    )
+    if best_match:
+        best_match_key, score, _ = best_match
+        if score >= params["fuzzy_threshold_invoice_classification"]:
+            return lookup_data[best_match_key]
     return None
 
 
-def clean_item_description(lineitem: str):
+# Function to create KVP dictionary using apply method
+def create_kvp_dictionary(df_raw: pd.DataFrame):
+    """Create a key-value pair dictionary from the given DataFrame.
+
+    Args:
+        df_raw (pd.DataFrame): The input DataFrame containing 'lineitem' and 'Forto SLI' columns.
+
+    return:
+        A key-value pair dictionary with 'Processed Lineitem' as key and 'Forto SLI' as value.
+    """
+    df = df_raw.copy()
+    df["Processed Lineitem"] = df["lineitem"].apply(clean_item_description)
+    kvp_dict = df.set_index("Processed Lineitem")["Forto SLI"].to_dict()
+
+    return kvp_dict
+
+
+def remove_dates(lineitem: str):
+    """
+    This function removes dates in the format "dd Month yyyy" from the given line item string.
+
+    Args:
+        lineitem (str): The input string from which dates will be removed.
+
+    Returns:
+        str: The string with dates removed.
+    """
+    # Remove dates in the format dd.mm.yy or dd.mm.yyyy
+    lineitem = re.sub(r"\b\d{1,2}\.\d{1,2}\.\d{2,4}\b", "", lineitem)
+
+    # Remove dates in the format "dd Month yyyy"
+    lineitem = re.sub(
+        r"\b\d{1,2} (?:january|february|march|april|may|june|july|august|september|october|november|december|januar|"
+        r"februar|märz|mai|juni|juli|oktober|dezember) \d{4}\b",
+        "",
+        lineitem,
+        flags=re.IGNORECASE,
+    )
+
+    # Define a list of month abbreviations in English and German
+    month_abbreviations = [
+        "JAN",
+        "FEB",
+        "MAR",
+        "APR",
+        "MAY",
+        "JUN",
+        "JUL",
+        "AUG",
+        "SEP",
+        "OCT",
+        "NOV",
+        "DEC",
+        "JAN",
+        "FEB",
+        "MRZ",
+        "APR",
+        "MAI",
+        "JUN",
+        "JUL",
+        "AUG",
+        "SEP",
+        "OKT",
+        "NOV",
+        "DEZ",
+    ]
+
+    # Create a regular expression pattern to match month abbreviations
+    pattern = r"\b(?:{})\b".format("|".join(month_abbreviations))
+
+    # Remove month abbreviations
+    lineitem = re.sub(pattern, "", lineitem, flags=re.IGNORECASE)
+
+    return lineitem
+
+
+def remove_unwanted_patterns(lineitem: str):
+    """
+    This function removes dates, month names, and container numbers from the given line item string.
+
+    Args:
+        lineitem (str): The input string from which unwanted patterns will be removed.
+
+    Returns:
+        str: The string with dates, month names, and container numbers removed.
+    """
+    # Remove container numbers (4 letters followed by 7 digits)
+    lineitem = re.sub(r"\b[A-Z]{4}\d{7}\b", "", lineitem)
+
+    # Remove "HIGH CUBE"
+    lineitem = lineitem.replace("HIGH CUBE", "")
+
+    # Remove container size e.g., 20FT, 40HC, etc.
+    pattern = [
+        f"{s}{t}"
+        for s in ("20|22|40|45".split("|"))
+        for t in ("FT|HC|DC|HD|GP|OT|RF|FR|TK|DV".split("|"))
+    ]
+    lineitem = re.sub(r"|".join(pattern), "", lineitem, flags=re.IGNORECASE).strip()
+
+    return lineitem
+
+
+def clean_item_description(lineitem: str, remove_numbers: bool = True):
     """
     This function removes dates, month names, whitespaces, currency patterns and container numbers from the given line item string.  # noqa
 
@@ -677,16 +343,13 @@ def clean_item_description(lineitem: str):
     Returns:
         str: The cleaned string removed.
     """
-    # Patterns
-    date_explicit_patterns = r"\b\d{1,2} (?:january|february|march|april|may|june|july|august|september|october|november|december|januar|februar|märz|mai|juni|juli|oktober|dezember) \d{4}\b"  # noqa
-    date_abbrevation_patterns = r"\b(?:JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC|JAN|FEB|MRZ|APR|MAI|JUN|JUL|AUG|SEP|OKT|NOV|DEZ)\b"  # noqa
     currency_codes_pattern = r"\b(USD|EUR|JPY|GBP|CAD|AUD|CHF|CNY|SEK|NZD|KRW|SGD|INR|BRL|ZAR|RUB|MXN|HKD|NOK|TRY|IDR|MYR|PHP|THB|VND|PLN|CZK|HUF|ILS|AED|SAR|QAR|KWD|EGP|NGN|ARS|CLP|COP|PEN|UYU|VEF|INR|PKR|BDT|LKR|NPR|MMK)\b"  # noqa
-    other_patterns = r"\b(HIGH CUBE|HIGHCUBE|FROM|TO|day s|Free|HQ)\b"
+
+    # Remove stopwords
+    lineitem = remove_stop_words(lineitem)
 
     # remove dates
-    lineitem = re.sub(r"\b\d{1,2}\.\d{1,2}\.\d{2,4}\b", "", lineitem)
-    lineitem = re.sub(date_explicit_patterns, "", lineitem, flags=re.IGNORECASE)
-    lineitem = re.sub(date_abbrevation_patterns, "", lineitem, flags=re.IGNORECASE)
+    lineitem = remove_dates(lineitem)
 
     # remove whitespaces
    lineitem = re.sub(r"\s{2,}", " ", lineitem)
@@ -694,30 +357,50 @@ def clean_item_description(lineitem: str):
     # remove newlines
     lineitem = re.sub(r"\\n|\n", " ", lineitem)
 
-    # remove special chars
-    lineitem = re.sub(r"[^A-Za-z0-9\s]", " ", lineitem).strip()
-
-    # remove other patterns
-    lineitem = re.sub(other_patterns, "", lineitem, flags=re.IGNORECASE)
-
     # Remove the currency codes
     lineitem = re.sub(currency_codes_pattern, "", lineitem, flags=re.IGNORECASE)
 
-    # Remove container numbers (4 letters followed by 7 digits)
-    lineitem = re.sub(r"\b[A-Z]{4}\d{7}\b", "", lineitem)
+    # remove other patterns
+    lineitem = remove_unwanted_patterns(lineitem)
 
     # Remove numbers from the line item
-    lineitem = re.sub(r"\d+", "", lineitem)
+    if (
+        remove_numbers
+    ):  # Do not remove numbers for the reverse charge sentence as it contains Article number
+        lineitem = re.sub(r"\d+", "", lineitem)
+
+    # remove special chars
+    lineitem = re.sub(r"[^A-Za-z0-9\s]", " ", lineitem).strip()
+
+    # Remove x from lineitem like 10 x
+    lineitem = re.sub(r"\b[xX]\b", " ", lineitem).strip()
 
     return re.sub(r"\s{2,}", " ", lineitem).strip()
 
 
 async def format_label(
-    entity_k, entity_value, embed_manager, document_type_code, llm_client
+    entity_k,
+    entity_value,
+    document_type_code,
+    params,
+    mime_type,
+    container_map,
+    terminal_map,
+    depot_map,
 ):
+    llm_client = params["LlmClient"]
     if isinstance(entity_value, dict):  # if it's a nested entity
         format_tasks = [
-            format_label(sub_k, sub_v, embed_manager, document_type_code, llm_client)
+            format_label(
+                sub_k,
+                sub_v,
+                document_type_code,
+                params,
+                mime_type,
+                container_map,
+                terminal_map,
+                depot_map,
+            )
             for sub_k, sub_v in entity_value.items()
        ]
         return entity_k, {k: v for k, v in await asyncio.gather(*format_tasks)}
@@ -725,35 +408,45 @@ async def format_label(
         format_tasks = await asyncio.gather(
             *[
                 format_label(
-                    entity_k, sub_v, embed_manager, document_type_code, llm_client
+                    entity_k,
+                    sub_v,
+                    document_type_code,
+                    params,
+                    mime_type,
+                    container_map,
+                    terminal_map,
+                    depot_map,
                 )
                 for sub_v in entity_value
            ]
        )
        return entity_k, [v for _, v in format_tasks]
+
+    if mime_type == "application/pdf":
+        if isinstance(entity_value, tuple):
+            page = entity_value[1]
+            entity_value = entity_value[0]
+        else:
+            page = -1
+
     entity_key = entity_k.lower()
-    embeddings_dict = embed_manager.embeddings_dict
     formatted_value = None
+
     if entity_key.startswith("port"):
         formatted_value = await get_port_code_ai(
-            entity_value, llm_client, embed_manager, *embeddings_dict["ports"]
+            entity_value, llm_client, doc_type=document_type_code
        )
+
     elif (entity_key == "containertype") or (entity_key == "containersize"):
-        formatted_value = embed_manager._find_most_similar_option(
-            "container type " + entity_value,
-            *embeddings_dict["container_types"],
-        )
+        formatted_value = container_map.get(entity_value)
+
     elif check_formatting_rule(entity_k, document_type_code, "terminal"):
-        formatted_value = embed_manager._find_most_similar_option(
-            "shipping terminal " + str(entity_value),
-            *embeddings_dict["terminals"],
-        )
+        formatted_value = terminal_map.get(entity_value)
+
     elif check_formatting_rule(entity_k, document_type_code, "depot"):
-        formatted_value = embed_manager._find_most_similar_option(
-            "depot " + str(entity_value),
-            *embeddings_dict["depots"],
-        )
-    elif entity_key.startswith(("eta", "etd", "duedate", "issuedate")):
+        formatted_value = depot_map.get(entity_value)
+
+    elif entity_key.startswith(("eta", "etd", "duedate", "issuedate", "servicedate")):
         try:
             cleaned_data_field_value = clean_date_string(entity_value)
             dt_obj = extract_date(cleaned_data_field_value)
@@ -771,28 +464,27 @@ async def format_label(
             formatted_value = dt_obj.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
         except ValueError as e:
             logger.info(f"ParserError: {e}")
+
+    elif (
+        entity_key in ["invoicenumber", "creditnoteinvoicenumber"]
+        and document_type_code == "bundeskasse"
+    ):
+        formatted_value = clean_invoice_number(entity_value)
+
+    elif entity_key in ("shipmentid", "partnerreference"):
+        # Clean the shipment ID to match Forto's standard (starts with 'S' followed by 7 or 8 digits)
+        formatted_value = clean_shipment_id(entity_value)
+
     elif entity_key == "containernumber":
         # Remove all non-alphanumeric characters like ' ', '-', etc.
         formatted_value = convert_container_number(entity_value)
+
     elif any(
         numeric_indicator in entity_key
-        for numeric_indicator in [
-            "weight",
-            "quantity",
-            "value",
-            "amount",
-            "totalamount",
-            "totalamounteuro",
-            "vatamount",
-            "vatApplicableAmount",
-            "grandTotal",
-        ]
-    ):
-        formatted_value = extract_number(entity_value)
-    elif (
-        document_type_code in ["finalMbL", "draftMbl"] and entity_key == "measurements"
+        for numeric_indicator in ["measurements", "weight"]
     ):
         formatted_value = extract_number(entity_value)
+
     elif any(
         packaging_type in entity_key
         for packaging_type in ["packagingtype", "packagetype", "currency"]
@@ -802,27 +494,57 @@ async def format_label(
     elif "lineitemdescription" in entity_key:
         formatted_value = clean_item_description(entity_value)
     elif "documenttype" in entity_key:
-        formatted_value = convert_invoice_type(entity_value)
+        formatted_value = convert_invoice_type(entity_value, params)
+
+    # Handle reverseChargeSentence
+    elif "reversechargesentence" in entity_key:
+        formatted_value = clean_item_description(entity_value, remove_numbers=False)
+
+    elif "quantity" in entity_key:
+        if document_type_code in ["partnerInvoice", "customsInvoice", "bundeskasse"]:
+            # For partner invoice, quantity can be mentioned as whole number
+            # Apply decimal convertor for 46,45 --> 46.45 but not for 1.000 --> 1000
+            formatted_value = decimal_convertor(
+                extract_number(entity_value), quantity=True
+            )
+        else:
+            formatted_value = extract_number(entity_value)
+
+    elif any(
+        numeric_indicator in entity_key
+        for numeric_indicator in [
+            "value",
+            "amount",
+            "price",
+            "totalamount",
+            "totalamounteuro",
+            "vatamount",
+            "vatapplicableamount",
+            "grandtotal",
+        ]
+    ):
+        # Convert EU values to English values (e.g., 4.123,45 -> 4123.45)
+        formatted_value = decimal_convertor(extract_number(entity_value))
 
     result = {
         "documentValue": entity_value,
         "formattedValue": formatted_value,
     }
+    if mime_type == "application/pdf":
+        result["page"] = page
+
     return entity_k, result
 
 
-async def get_port_code_ai(
-    port: str, llm_client, embed_manager, port_ids, port_embeddings
-):
-    port_llm = await get_port_code_llm(port, llm_client)
+async def get_port_code_ai(port: str, llm_client, doc_type=None):
+    """Get port code using AI model."""
+    port_llm = await get_port_code_llm(port, llm_client, doc_type=doc_type)
 
-    if port_llm in port_ids:
-        return port_llm
-    port_text = f"port for shipping {port}"
-    return embed_manager._find_most_similar_option(port_text, port_ids, port_embeddings)
+    result = await get_tms_mappings(port, "ports", port_llm)
+    return result.get(port, None)
 
 
-async def get_port_code_llm(port: str, llm_client):
+async def get_port_code_llm(port: str, llm_client, doc_type=None):
     if (
         "postprocessing" in prompt_library.library.keys()
         and "port_code" in prompt_library.library["postprocessing"].keys()
@@ -849,7 +571,7 @@ async def get_port_code_llm(port: str, llm_client):
     }
 
     response = await llm_client.get_unified_json_genai(
-        prompt, response_schema=response_schema, model="chatgpt"
+        prompt, response_schema=response_schema, model="chatgpt", doc_type=doc_type
    )
     try:
         mapped_port = response["port"]
@@ -859,25 +581,149 @@ async def get_port_code_llm(port: str, llm_client):
         return None
 
 
-async def format_all_entities(result, embed_manager, document_type_code, llm_client):
-    # remove None values from dict
+def decimal_convertor(value, quantity=False):
+    """Convert EU values to English values."""
+    if value is None:
+        return None
+
+    # Remove spaces
+    value = value.strip().replace(" ", "")
+
+    # Check "-" and remove it for processing
+    is_negative, value = (True, value[1:]) if value.startswith("-") else (False, value)
+
+    if not quantity:
+        # Convert comma to dot for decimal point (e.g., 4.123,45 -> 4123.45)
+        if re.match(r"^\d{1,3}(\.\d{3})*,\d{1,2}$", value):
+            value = value.replace(".", "").replace(",", ".")
+
+        # European style integer with thousand separator: 2.500
+        elif re.match(r"^\d{1,3}(\.\d{3})+$", value):
+            value = value.replace(".", "")
+
+        # Format english values as well for consistency (e.g., 4,123.45 -> 4123.45)
+        elif re.match(r"^\d{1,3}(,\d{3})*\.\d{1,2}$", value):
+            value = value.replace(",", "")
+
+        # English style integer with thousand separator: 2,500
+        elif re.match(r"^\d{1,3}(,\d{3})+$", value):
+            value = value.replace(",", "")
+
+        # Just replace comma decimals with dot (e.g., 65,45 -> 65.45)
+        if re.match(r"^\d+,\d{1,2}$", value):
+            value = value.replace(",", ".")
+
+        # If there are more than 3 0s after decimal point, consider only 2 decimal points (e.g., 8.500000 -> 8.50)
+        elif re.match(r"^\d+\.\d{3,}$", value):
+            value = value[: value.index(".") + 3]
+
+    else:  # quantity=True → only last two
+        # Just replace comma decimals with dot (e.g., 65,45 -> 65.45)
+        if re.match(r"^\d+,\d{1,2}$", value):
+            value = value.replace(",", ".")
+
+        # If there are more than 3 0s after decimal point, consider only 2 decimal points (e.g., 8.500000 -> 8.50)
+        elif re.match(r"^\d+\.\d{3,}$", value):
+            value = value[: value.index(".") + 3]
+
+    # Re-add negative sign if applicable
+    value = "-" + value if is_negative else value
+
+    return value
+
+
+async def collect_mapping_requests(entity_value, document_type_code):
+    """Collect all unique container types, terminals, and depots from the entity value."""
+    # Sets to store unique values
+    container_types = set()
+    terminals = set()
+    depots = set()
+
+    def walk(key, value):
+        key_lower = key.lower()
+
+        # nested dict
+        if isinstance(value, dict):
+            for k, v in value.items():
+                walk(k, v)
+
+        # list of values
+        elif isinstance(value, list):
+            for item in value:
+                walk(key, item)
+
+        # leaf node
+        else:
+            if key_lower in ("containertype", "containersize"):
+                # Take only "20DV" from ('20DV', 0) if it's a tuple
+                container_types.add(value[0]) if isinstance(
+                    value, tuple
+                ) else container_types.add(value)
+
+            elif check_formatting_rule(key, document_type_code, "terminal"):
+                terminals.add(value[0]) if isinstance(value, tuple) else terminals.add(
+                    value
+                )
+
+            elif check_formatting_rule(key, document_type_code, "depot"):
+                depots.add(value[0]) if isinstance(value, tuple) else depots.add(value)
+
+    walk("root", entity_value)
+
+    return container_types, terminals, depots
+
+
+async def format_all_labels(entity_data, document_type_code, params, mime_type):
+    """Format all labels in the entity data using cached mappings."""
+    # Collect all mapping values needed
+    container_req, terminal_req, depot_req = await collect_mapping_requests(
+        entity_data, document_type_code
+    )
+
+    # Batch fetch mappings
+    container_map, terminal_map, depot_map = await batch_fetch_all_mappings(
+        container_req, terminal_req, depot_req
+    )
+
+    # Format labels using cached mappings
+    _, result = await format_label(
+        "root",
+        entity_data,
+        document_type_code,
+        params,
+        mime_type,
+        container_map,
+        terminal_map,
+        depot_map,
+    )
+
+    return _, result
+
+
+async def format_all_entities(result, document_type_code, params, mime_type):
+    """Format the entity values in the result dictionary."""
+    # Since we treat `customsInvoice` same as `partnerInvoice`
+    document_type_code = (
+        "partnerInvoice"
+        if document_type_code == "customsInvoice"
+        else document_type_code
+    )
+    # Remove None values from the dictionary
     result = remove_none_values(result)
     if result is None:
         logger.info("No data was extracted.")
         return {}
-    _, aggregated_data = await format_label(
-        None, result, embed_manager, document_type_code, llm_client
+
+    # Format all entities recursively
+    _, aggregated_data = await format_all_labels(
+        result, document_type_code, params, mime_type
    )
-    # TODO: this is hardcoded cuz its an exception
-    # It will stay this way until a better implementation will be found
-    if document_type_code == "partnerInvoice":
-        for line_item in aggregated_data.get("lineItem", []):
-            line_item["itemCode"] = associate_forto_item_code(
-                embed_manager, line_item["lineItemDescription"]["formattedValue"]
-            )
 
-    logger.info("Data Extraction completed successfully")
+    # Process partner invoice on lineitem mapping and reverse charge sentence
+    if document_type_code in ["partnerInvoice", "bundeskasse"]:
+        await process_partner_invoice(params, aggregated_data, document_type_code)
 
+    logger.info("Data Extraction completed successfully")
     return aggregated_data
 
 
@@ -890,32 +736,63 @@ def add_text_without_space(text):
     return text
 
 
-def associate_forto_item_code(embed_manager, line_item_description):
-    """
-    Associates a given line item description with the most similar Forto item code.
-
-    This function utilizes an embedding manager to find the most similar Forto item code
-    from a pre-computed set of embeddings, based on the provided line item description.
+def remove_stop_words(lineitem: str):
+    """Remove stop words in English and German from the given line item string.
 
     Args:
-        embed_manager: An instance of an embedding manager that contains the item code embeddings.
-        line_item_description: The description of the line item to associate with an item code.
+        lineitem (str): The input string from which stop words will be removed.
 
     Returns:
-        A dictionary containing the original line item description and the most similar
-        formatted item code. The dictionary has the following structure:
-        {
-            "documentValue": line_item_description,
-            "formattedValue": formatted_value,
-        }
-        where formatted_value is the most similar item code found.
+        str: The string with stop words removed.
     """
-    embeddings_dict = embed_manager.embeddings_dict
-    formatted_value = embed_manager._find_most_similar_option(
-        line_item_description, *embeddings_dict["item_codes"]
+    stop_words = set(stopwords.words("english") + stopwords.words("german")) - {"off"}
+    return (
+        " ".join(word for word in lineitem.split() if word.lower() not in stop_words)
+        .upper()
+        .strip()
    )
-    result = {
-        "documentValue": line_item_description,
-        "formattedValue": formatted_value,
-    }
-    return result
+
+
+def llm_prediction_to_tuples(llm_prediction, number_of_pages=-1, page_number=None):
+    """Convert LLM prediction dictionary to tuples of (value, page_number)."""
+    # If only 1 page, simply pair each value with page number 0
+    if number_of_pages == 1:
+        effective_page = 0 if page_number is None else page_number
+        if isinstance(llm_prediction, dict):
+            return {
+                k: llm_prediction_to_tuples(
+                    v, number_of_pages, page_number=effective_page
+                )
+                for k, v in llm_prediction.items()
+            }
+        elif isinstance(llm_prediction, list):
+            return [
+                llm_prediction_to_tuples(v, number_of_pages, page_number=effective_page)
+                for v in llm_prediction
+            ]
+        else:
+            return (llm_prediction, effective_page) if llm_prediction else None
+
+    # logic for multi-page predictions
+    if isinstance(llm_prediction, dict):
+        if "page_number" in llm_prediction.keys() and "value" in llm_prediction.keys():
+            if llm_prediction["value"]:
+                try:
+                    _page_number = int(llm_prediction["page_number"])
+                except:  # noqa: E722
+                    _page_number = -1
+                return (llm_prediction["value"], _page_number)
+            return None
+
+        for key, value in llm_prediction.items():
+            llm_prediction[key] = llm_prediction_to_tuples(
+                llm_prediction.get(key, value), number_of_pages, page_number
+            )
+
    elif isinstance(llm_prediction, list):
+        for i, item in enumerate(llm_prediction):
+            llm_prediction[i] = llm_prediction_to_tuples(
+                item, number_of_pages, page_number
+            )
+
+    return llm_prediction
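
For orientation on the two helpers introduced at the end of this hunk, decimal_convertor and llm_prediction_to_tuples, here is a minimal usage sketch of the behavior documented in their inline comments. It is not part of the package: the import path assumes the hunk above is src/postprocessing/common.py (file 12 in the list), and it has to run inside the package environment (TMS_DOMAIN set, nltk stopword data available) because the module pulls those in at import time.

# Usage sketch -- import path and environment are assumptions, see note above.
from src.postprocessing.common import decimal_convertor, llm_prediction_to_tuples

# EU-style amounts are normalized to English decimal notation.
assert decimal_convertor("4.123,45") == "4123.45"  # dot thousands, comma decimal
assert decimal_convertor("4,123.45") == "4123.45"  # English thousands stripped too
assert decimal_convertor("2.500") == "2500"        # EU integer thousand separator
assert decimal_convertor("-65,45") == "-65.45"     # sign survives the rewrite
assert decimal_convertor("8.500000") == "8.50"     # long decimal tails cut to two places

# With quantity=True only the last two rules apply, so thousand separators
# in whole-number quantities are not stripped.
assert decimal_convertor("46,45", quantity=True) == "46.45"

# Multi-page LLM predictions ({"value": ..., "page_number": ...}) collapse
# recursively into (value, page_number) tuples.
pred = {"containerNumber": {"value": "FFAU2932130", "page_number": "2"}}
assert llm_prediction_to_tuples(pred) == {"containerNumber": ("FFAU2932130", 2)}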