data-science-document-ai 1.38.0__py3-none-any.whl → 1.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
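data_science_document_ai-1.38.0.dist-info/METADATA → data_science_document_ai-1.39.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@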
1
1
  Metadata-Version: 2.4
2
2
  Name: data-science-document-ai
3
- Version: 1.38.0
3
+ Version: 1.39.0
4
4
  Summary: "Document AI repo for data science"
5
5
  Author: Naomi Nguyen
6
6
  Author-email: naomi.nguyen@forto.com
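data_science_document_ai-1.38.0.dist-info/RECORD → data_science_document_ai-1.39.0.dist-info/RECORD CHANGED
@@ -6,11 +6,11 @@ src/excel_processing.py,sha256=ZUlZ5zgWObmQfAWHoSrEEITKwr-xXxuOiPC3qDnGjtQ,2459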
6
6
  src/io.py,sha256=IXz4wWqiHa9mnHNgtrC6X9M2lItYp9eu6rHCThUIh5c,3585
7
7
  src/llm.py,sha256=aEK3rL8XvY7CakvkOJQmcHpEKwZRd8PPrLrzHiO-GFk,7827
8
8
  src/log_setup.py,sha256=RhHnpXqcl-ii4EJzRt47CF2R-Q3YPF68tepg_Kg7tkw,2895
9
- src/pdf_processing.py,sha256=g-WVrI6J2lbVR3eOoDJnuy4buWh7bTmO-3aezoTN3i4,15527
10
- src/postprocessing/common.py,sha256=UxwmnXH7saggxDMs9Ssx_Bp3-O9NeUrcKFWRI_QYuZ0,39583
9
+ src/pdf_processing.py,sha256=yB0FpIdSRqxeEbZAIK_bPFypWrSSMb8uwCRxTTFfmxc,15493
10
+ src/postprocessing/common.py,sha256=OR9O73gUP4tevIZMnorbiUgzviEJlVr46ArTWMXrYVA,19316
11
11
  src/postprocessing/postprocess_booking_confirmation.py,sha256=nK32eDiBNbauyQz0oCa9eraysku8aqzrcoRFoWVumDU,4827
12
12
  src/postprocessing/postprocess_commercial_invoice.py,sha256=3I8ijluTZcOs_sMnFZxfkAPle0UFQ239EMuvZfDZVPg,1028
13
- src/postprocessing/postprocess_partner_invoice.py,sha256=oCT-l31DTosUf0cz0d5IWOF6erw6rD3rQfR58koSUeM,11760
13
+ src/postprocessing/postprocess_partner_invoice.py,sha256=bWm3Miaq_mtX62xSs14vNQCWPHOj2895Bt6TuOVZWZU,11742
14
14
  src/prompts/library/bookingConfirmation/evergreen/placeholders.json,sha256=Re2wBgZoaJ5yImUUAwZOZxFcKXHxi83TCZwTuqd2v2k,1405
15
15
  src/prompts/library/bookingConfirmation/evergreen/prompt.txt,sha256=qlBMFDHy-gwr2PVeuHrfMEg_8Ibdym243DnaCgINa7g,2614
16
16
  src/prompts/library/bookingConfirmation/hapag-lloyd/placeholders.json,sha256=Re2wBgZoaJ5yImUUAwZOZxFcKXHxi83TCZwTuqd2v2k,1405
@@ -51,9 +51,9 @@ src/prompts/library/preprocessing/carrier/placeholders.json,sha256=1UmrQNqBEsjLI
51
51
  src/prompts/library/preprocessing/carrier/prompt.txt,sha256=NLvRZQCZ6aWC1yTr7Q93jK5z7Vi_b4HBaiFYYnIsO-w,134
52
52
  src/prompts/library/shippingInstruction/other/prompt.txt,sha256=fyC24ig4FyRNnLuQM69s4ZVajsK-LHIl2dvaaEXr-6Q,1327
53
53
  src/prompts/prompt_library.py,sha256=VJWHeXN-s501C2GiidIIvQQuZdU6T1R27hE2dKBiI40,2555
54
- src/setup.py,sha256=TJu68mXS6Dx90Il8A_pHDnrIOiLD3q9f7FWgW0c1HOM,8352
54
+ src/setup.py,sha256=kPSZosrICfaGZeDaajr40Ha7Ok4XK4fo_uq35Omiwr0,7128
55
55
  src/tms.py,sha256=UXbIo1QE--hIX6NZi5Qyp2R_CP338syrY9pCTPrfgnE,1741
56
- src/utils.py,sha256=68x3hakQ8aDfq7967XoTRe_vsneWnLbWp_jz8q_FrBA,12189
57
- data_science_document_ai-1.38.0.dist-info/METADATA,sha256=fDQUs1Zi2ZteWsX6vm-0pNEZA04KJAels_OjC4Job6o,2153
58
- data_science_document_ai-1.38.0.dist-info/WHEEL,sha256=M5asmiAlL6HEcOq52Yi5mmk9KmTVjY2RDPtO4p9DMrc,88
59
- data_science_document_ai-1.38.0.dist-info/RECORD,,
56
+ src/utils.py,sha256=fTzXSeXejSmeGH0thOEdtDTAVhmeALVg44ORFarBKOk,13826
57
+ data_science_document_ai-1.39.0.dist-info/METADATA,sha256=odM1ywura_FepOjD6iVQXc4QXVsmrtRQILScEsyZTM0,2153
58
+ data_science_document_ai-1.39.0.dist-info/WHEEL,sha256=M5asmiAlL6HEcOq52Yi5mmk9KmTVjY2RDPtO4p9DMrc,88
59
+ data_science_document_ai-1.39.0.dist-info/RECORD,,
src/pdf_processing.py CHANGED
@@ -355,7 +355,6 @@ async def data_extraction_manual_flow(
355
355
  meta,
356
356
  processor_client,
357
357
  schema_client,
358
- embed_manager,
359
358
  ):
360
359
  """
361
360
  Process a PDF file and extract data from it.
@@ -418,7 +417,7 @@ async def data_extraction_manual_flow(
418
417
  )
419
418
  # Create the result dictionary with the extracted data
420
419
  extracted_data = await format_all_entities(
421
- extracted_data, embed_manager, meta.documentTypeCode, params
420
+ extracted_data, meta.documentTypeCode, params
422
421
  )
423
422
  result = {
424
423
  "id": meta.id,
src/postprocessing/common.py CHANGED
@@ -1,426 +1,22 @@
1
1
  import asyncio
2
- import datetime
3
2
  import json
4
3
  import os
5
4
  import re
6
5
  from datetime import timezone
7
6
 
8
- import numpy as np
9
7
  import pandas as pd
10
- import requests
11
8
  from nltk.corpus import stopwords
12
9
  from rapidfuzz import process
13
- from vertexai.preview.language_models import TextEmbeddingModel
14
10
 
15
11
  from src.constants import formatting_rules
16
- from src.io import get_storage_client, logger
12
+ from src.io import logger
17
13
  from src.postprocessing.postprocess_partner_invoice import process_partner_invoice
18
14
  from src.prompts.prompt_library import prompt_library
19
- from src.tms import call_tms, set_tms_service_token
15
+ from src.utils import get_tms_mappings
20
16
 
21
17
  tms_domain = os.environ["TMS_DOMAIN"]
22
18
 
23
19
 
24
- class EmbeddingsManager: # noqa: D101
25
- def __init__(self, params): # noqa: D107
26
- self.params = params
27
- self.embeddings_dict = {}
28
- self.embed_model = setup_embed_model()
29
- self.bucket = self.get_bucket_storage()
30
- self.embedding_folder = self.embed_model._model_id
31
- self.embedding_dimension = 768 # TODO: to be reduced
32
-
33
- def get_bucket_storage(self):
34
- """
35
- Retrieve the bucket storage object.
36
-
37
- Returns:
38
- The bucket storage object.
39
- """
40
- params = self.params
41
- storage_client = get_storage_client(params)
42
- bucket = storage_client.bucket(params["doc_ai_bucket_name"])
43
- return bucket
44
-
45
- def _find_most_similar_option(self, input_string, option_ids, option_embeddings):
46
- """
47
- Find the most similar option to the given input string based on embeddings.
48
-
49
- Args:
50
- model: The model used for generating embeddings.
51
- input_string (str): The input string to find the most similar option for.
52
- option_ids (list): The list of option IDs.
53
- option_embeddings (np.ndarray): The embeddings of the options.
54
-
55
- Returns:
56
- The ID of the most similar option.
57
- """
58
- try:
59
- input_embedding = self.embed_model.get_embeddings(
60
- [input_string], output_dimensionality=self.embedding_dimension
61
- )[0].values
62
- similarities = np.dot(option_embeddings, input_embedding)
63
- idx = np.argmax(similarities)
64
- return option_ids[idx]
65
- except Exception as e:
66
- logger.error(f"Embeddings error: {e}")
67
- return None
68
-
69
- def load_embeddings(self):
70
- """
71
- Load embeddings for container types, ports, and terminals.
72
-
73
- Returns:
74
- None
75
- """
76
- for data_field in [
77
- "container_types",
78
- "ports",
79
- "terminals",
80
- "depots",
81
- "item_codes_label",
82
- ]:
83
- self.embeddings_dict[data_field] = load_embed_by_data_field(
84
- self.bucket,
85
- f"{self.embedding_folder}/{data_field}/output",
86
- self.embedding_dimension,
87
- )
88
-
89
- async def update_embeddings(self):
90
- """
91
- Update the embeddings dictionary.
92
-
93
- Returns:
94
- dict: The updated embeddings dictionary with the following keys:
95
- - "container_types": A tuple containing the container types and their embeddings.
96
- - "ports": A tuple containing the ports and their embeddings.
97
- - "terminals": A tuple containing the terminal IDs and their embeddings.
98
- """
99
- # Update embeddings dict here.
100
- # Ensure this method is async if you're calling async operations.
101
- set_tms_service_token()
102
- (
103
- container_types,
104
- container_type_embeddings,
105
- ) = self.setup_container_type_embeddings(
106
- *self.embeddings_dict.get("container_types", ([], []))
107
- )
108
-
109
- ports, port_embeddings = self.setup_ports_embeddings(
110
- *self.embeddings_dict.get("ports", ([], []))
111
- )
112
-
113
- # Setup terminal embeddings
114
- # Since retrieving terminal attributes requires calling TMS' api to extract terminals by each port,
115
- # we only do it for new ports.
116
- prev_port_ids, _ = self.embeddings_dict.get("ports", ([], []))
117
- added_port_ids = [port for port in ports if port not in prev_port_ids]
118
- if added_port_ids:
119
- terminal_ids, terminal_embeddings = self.setup_terminal_embeddings(
120
- added_port_ids
121
- )
122
- else:
123
- terminal_ids, terminal_embeddings = self.embeddings_dict["terminals"]
124
-
125
- depot_names, depot_embeddings = self.setup_depot_embeddings(
126
- *self.embeddings_dict.get("depots", ([], []))
127
- )
128
-
129
- item_code_names, item_code_embeddings = self.setup_item_code_embeddings(
130
- *self.embeddings_dict.get("item_codes_label", ([], []))
131
- )
132
-
133
- self.embeddings_dict = {
134
- "container_types": (container_types, container_type_embeddings),
135
- "ports": (ports, port_embeddings),
136
- "terminals": (terminal_ids, terminal_embeddings),
137
- "depots": (depot_names, depot_embeddings),
138
- "item_codes_label": (item_code_names, item_code_embeddings),
139
- }
140
- return self.embeddings_dict
141
-
142
- def batch_embed(self, option_strings: list[dict], suffix: str):
143
- """
144
- Compute embeddings for a batch of option strings and uploads them to a cloud storage bucket.
145
-
146
- Args:
147
- option_strings (list): A list of option strings to compute embeddings for.
148
- suffix (str): A suffix to be used in the storage path for the embeddings:
149
- input & output will be stored under "{bucket}/{parent_folder}/{suffix}/"
150
-
151
- Returns:
152
- tuple: A tuple containing the option IDs and embeddings.
153
- """
154
- now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
155
- input_path = f"{self.embedding_folder}/{suffix}/input/{now}.jsonl"
156
- blob = self.bucket.blob(input_path)
157
-
158
- # Convert each dictionary to a JSON string and join them with newlines
159
- option_strings = [
160
- {**option, "task_type": "SEMANTIC_SIMILARITY", "output_dimensionality": 256}
161
- for option in option_strings
162
- ]
163
- jsonl_string = "\n".join(json.dumps(d) for d in option_strings)
164
-
165
- # Convert the combined string to bytes
166
- jsonl_bytes = jsonl_string.encode("utf-8")
167
-
168
- # Upload the bytes to the blob
169
- blob.upload_from_string(jsonl_bytes, content_type="text/plain")
170
-
171
- # Compute embeddings for the options
172
- embedding_path = f"{self.embedding_folder}/{suffix}/output"
173
- assert len(option_strings) <= 30000 # Limit for batch embedding
174
- batch_resp = self.embed_model.batch_predict(
175
- dataset=f"gs://{self.bucket.name}/{input_path}", # noqa
176
- destination_uri_prefix=f"gs://{self.bucket.name}/{embedding_path}", # noqa
177
- )
178
-
179
- if batch_resp.state.name != "JOB_STATE_SUCCEEDED":
180
- logger.warning(
181
- f"Batch prediction job failed with state {batch_resp.state.name}"
182
- )
183
- else:
184
- logger.info(f"Embeddings for {suffix} computed successfully.")
185
-
186
- option_ids, option_embeddings = load_embed_by_data_field(
187
- self.bucket, embedding_path, self.embedding_dimension
188
- )
189
- return option_ids, option_embeddings
190
-
191
- def setup_container_type_embeddings(
192
- self, computed_container_type_ids, computed_container_type_embeddings
193
- ):
194
- """
195
- Set up container type embeddings.
196
-
197
- Args:
198
- computed_container_type_ids (list): The list of already computed container type IDs.
199
- computed_container_type_embeddings (list): The list of already computed container type embeddings.
200
-
201
- Returns:
202
- tuple: A tuple containing the updated container type IDs and embeddings.
203
- """
204
- url = (
205
- f"https://tms.forto.{tms_domain}/api/transport-units/api/types/list" # noqa
206
- )
207
- resp = call_tms(requests.get, url)
208
- container_types = resp.json()
209
-
210
- container_attribute_strings = [
211
- dict(
212
- title=container_type["code"],
213
- content=" | ".join(
214
- [container_type["code"]]
215
- + [
216
- f"{v}"
217
- for k, v in container_type["containerAttributes"].items()
218
- if k in ["isoSizeType", "isoTypeGroup", "containerCategory"]
219
- ]
220
- + [container_type.get(k, "") for k in ["displayName", "notes"]]
221
- ),
222
- )
223
- for container_type in container_types
224
- if container_type["type"] == "Container"
225
- and container_type["code"] not in computed_container_type_ids
226
- and container_type.get("containerAttributes") is not None
227
- ]
228
- if not container_attribute_strings:
229
- logger.info("No new container types found.")
230
- return computed_container_type_ids, computed_container_type_embeddings
231
-
232
- logger.info("Computing embeddings for container types...")
233
- container_type_ids, container_type_embeddings = self.batch_embed(
234
- container_attribute_strings, "container_types"
235
- )
236
- return container_type_ids, container_type_embeddings
237
-
238
- def setup_ports_embeddings(self, computed_port_ids, computed_port_embeddings):
239
- """
240
- Set up port embeddings.
241
-
242
- Steps:
243
- - Retrieve active ports from the TMS API
244
- - Compute embeddings for new tradelane-enabled ports
245
- - Return ALL port IDs and embeddings.
246
-
247
- Args:
248
- computed_port_ids (list): The list of previously computed port IDs.
249
- computed_port_embeddings (list): The list of previously computed port embeddings.
250
-
251
- Returns:
252
- tuple: A tuple containing ALL port IDs and embeddings.
253
- """
254
- url = f"https://tms.forto.{tms_domain}/api/transport-network/api/ports?pageSize=1000000&status=active" # noqa
255
- resp = call_tms(requests.get, url)
256
- resp_json = resp.json()
257
- if len(resp_json["data"]) != resp_json["_paging"]["totalRecords"]:
258
- logger.error("Not all ports were returned.")
259
-
260
- new_sea_ports = [
261
- port
262
- for port in resp_json["data"]
263
- if "sea" in port["modes"] and port["id"] not in computed_port_ids
264
- ]
265
- if not new_sea_ports:
266
- logger.info("No new ports found.")
267
- return computed_port_ids, computed_port_embeddings
268
-
269
- port_attribute_strings = [
270
- dict(
271
- title=port["id"],
272
- content=" ".join(
273
- [
274
- "port for shipping",
275
- add_text_without_space(
276
- port["name"]
277
- ), # for cases like QUINHON - Quinhon
278
- port["id"],
279
- ]
280
- ),
281
- )
282
- for port in new_sea_ports
283
- ]
284
-
285
- logger.info("Computing embeddings for ports.")
286
- port_ids, port_embeddings = self.batch_embed(port_attribute_strings, "ports")
287
- return port_ids, port_embeddings
288
-
289
- def setup_depot_embeddings(self, computed_depot_names, computed_depot_embeddings):
290
- """
291
- Set up depot embeddings.
292
-
293
- Steps:
294
- - Retrieve active depot from the TMS API
295
- - Compute embeddings for new tdepot
296
- - Return ALL depot names and embeddings.
297
-
298
- Args:
299
- computed_depot_names (list): The list of previously computed depot names.
300
- computed_depot_embeddings (list): The list of previously computed depot embeddings.
301
-
302
- Returns:
303
- tuple: A tuple containing ALL depot names and embeddings.
304
- """
305
- url = f"https://tms.forto.{tms_domain}/api/transport-network/api/depots?pageSize=1000000" # noqa
306
- resp = call_tms(requests.get, url)
307
- resp_json = resp.json()
308
-
309
- new_depots = [
310
- depot
311
- for depot in resp_json["data"]
312
- if depot["name"] not in computed_depot_names
313
- ]
314
- if not new_depots:
315
- logger.info("No new depots found.")
316
- return computed_depot_names, computed_depot_embeddings
317
-
318
- depot_attribute_strings = [
319
- dict(
320
- title=depot["name"],
321
- content=" | ".join(
322
- [
323
- "depot",
324
- "name - " + depot["name"],
325
- "address - " + depot["address"]["fullAddress"],
326
- ]
327
- ),
328
- )
329
- for depot in resp_json["data"]
330
- ]
331
-
332
- logger.info("Computing embeddings for depots.")
333
- depot_names, depot_embeddings = self.batch_embed(
334
- depot_attribute_strings, "depots"
335
- )
336
- return depot_names, depot_embeddings
337
-
338
- def setup_terminal_embeddings(self, added_port_ids):
339
- """
340
- Set up terminal embeddings for `added_port_ids`, using `model`, uploaded to `bucket`.
341
-
342
- Args:
343
- added_port_ids (list): A list of added port IDs.
344
-
345
- Returns:
346
- tuple: A tuple containing the ALL terminal IDs and terminal embeddings.
347
- Not just for the added port IDs.
348
- """
349
- terminal_attibute_strings = [
350
- setup_terminal_attributes(port_id) for port_id in added_port_ids
351
- ]
352
- terminal_attibute_strings = sum(terminal_attibute_strings, [])
353
- if not terminal_attibute_strings:
354
- logger.info("No new terminals found.")
355
- return [], np.array([])
356
-
357
- terminal_ids, terminal_embeddings = self.batch_embed(
358
- terminal_attibute_strings, "terminals"
359
- )
360
- return terminal_ids, terminal_embeddings
361
-
362
- def setup_item_code_embeddings(
363
- self, computed_item_code_names, computed_item_code_embeddings
364
- ):
365
- """
366
- Set up item_code embeddings.
367
-
368
- Steps:
369
- - Retrieve active item_code from the TMS API
370
- - Compute embeddings for new titem_code
371
- - Return ALL item_code names and embeddings.
372
-
373
- Args:
374
- computed_item_code_names (list): The list of previously computed item_code names.
375
- computed_item_code_embeddings (list): The list of previously computed item_code embeddings.
376
-
377
- Returns:
378
- tuple: A tuple containing ALL item_code names and embeddings.
379
- """
380
- url = f"https://tms.forto.{tms_domain}/api/catalog/item-codes?transportTypes=fcl&pageSize=1000000" # noqa
381
- resp = call_tms(requests.get, url)
382
- resp_json = resp.json()
383
-
384
- new_item_codes = [
385
- item_code
386
- for item_code in resp_json["results"]
387
- if item_code["id"] not in computed_item_code_names
388
- ]
389
- if not new_item_codes:
390
- logger.info("No new item_codes found.")
391
- return computed_item_code_names, computed_item_code_embeddings
392
-
393
- item_code_attribute_strings = [
394
- dict(
395
- title=item_code["id"],
396
- content=" | ".join(
397
- [
398
- item_code["id"],
399
- item_code["label"],
400
- ]
401
- ),
402
- )
403
- for item_code in resp_json["results"]
404
- ]
405
-
406
- logger.info("Computing embeddings for item_codes.")
407
- item_code_names, item_code_embeddings = self.batch_embed(
408
- item_code_attribute_strings, "item_codes_label"
409
- )
410
- return item_code_names, item_code_embeddings
411
-
412
-
413
- def setup_embed_model():
414
- """
415
- Set up and return a text embedding model.
416
-
417
- Returns:
418
- TextEmbeddingModel: The initialized text embedding model.
419
- """
420
- model = TextEmbeddingModel.from_pretrained("text-multilingual-embedding-002")
421
- return model
422
-
423
-
424
20
  def convert_container_number(container_number):
425
21
  """
426
22
  Convert a container number to ISO standard.
@@ -570,106 +166,6 @@ def extract_string(data_field_value):
570
166
  return formatted_value if formatted_value not in ["''", ""] else None
571
167
 
572
168
 
573
- def extract_google_embed_resp(prediction_string, embedding_dimension):
574
- """
575
- Extract relevant information from the Google Embed API response.
576
-
577
- Args:
578
- prediction_string (str): The prediction string returned by the Google Embed API.
579
-
580
- Returns:
581
- dict: A dictionary containing the extracted information.
582
- - _id (str): The title of the instance.
583
- - attr_text (str): The content of the instance.
584
- - embedding (list): The embeddings values from the predictions.
585
-
586
- """
587
- res = json.loads(prediction_string)
588
- return dict(
589
- _id=res["instance"]["title"],
590
- attr_text=res["instance"]["content"],
591
- embedding=res["predictions"][0]["embeddings"]["values"][:embedding_dimension],
592
- )
593
-
594
-
595
- def load_embed_by_data_field(bucket, embedding_path, embedding_dimension):
596
- """
597
- Load embeddings by data field from the specified bucket and embedding path.
598
-
599
- Args:
600
- bucket (Bucket): The bucket object representing the storage bucket.
601
- embedding_path (str): The path to the embeddings in the bucket (different by data_field).
602
-
603
- Returns:
604
- tuple: A tuple containing the option IDs and option embeddings.
605
- - option_ids (list): A list of option IDs.
606
- - option_embeddings (ndarray): An array of option embeddings.
607
- """
608
- # Retrieve the embeddings from the output files
609
- blobs = bucket.list_blobs(prefix=embedding_path)
610
- all_blob_data = []
611
- for blob in blobs:
612
- blob_data = blob.download_as_bytes().decode("utf-8").splitlines()
613
- embeddings = [
614
- extract_google_embed_resp(data, embedding_dimension) for data in blob_data
615
- ]
616
- all_blob_data.extend(embeddings)
617
- option_ids = [embed["_id"] for embed in all_blob_data]
618
- option_embeddings = np.stack([embed["embedding"] for embed in all_blob_data])
619
- return option_ids, option_embeddings
620
-
621
-
622
- def setup_terminal_attributes(port_id: str):
623
- """
624
- Retrieve and format the attributes of active terminals at a given port.
625
-
626
- Args:
627
- port_id (str): The ID of the port.
628
-
629
- Returns:
630
- list: A list of dictionaries containing the formatted attributes of active terminals.
631
- Each dictionary has the following keys:
632
- - title: The terminal's short code.
633
- - content: A string representation of the terminal's attributes, including its name,
634
- searchable name, and full address.
635
- """
636
- url = f"https://gateway.forto.{tms_domain}/api/transport-network/api/ports/{port_id}/terminals/list" # noqa
637
- resp = call_tms(requests.get, url)
638
- terminals = resp.json()
639
- if len(terminals) == 0:
640
- return []
641
- active_terminals = [term for term in terminals if term["isActive"]]
642
- if len(active_terminals) == 0:
643
- logger.warning(f"No active terminals found at port {port_id}.")
644
- return []
645
-
646
- terminal_attibute_strings = [
647
- dict(
648
- title=term["name"],
649
- content=" | ".join(
650
- [
651
- "shipping terminal",
652
- "code - " + term["terminalShortCode"],
653
- "name - " + modify_terminal_name(term["searchableName"]),
654
- "country - " + term["address"]["country"],
655
- ]
656
- ),
657
- )
658
- for term in active_terminals
659
- ]
660
- return terminal_attibute_strings
661
-
662
-
663
- def modify_terminal_name(text):
664
- # Find the first occurrence of a word starting with 'K' followed by a number
665
- # and replace it with 'KAAI' - meaning Quay in Dutch
666
- match = re.search(r"K(\d+)", text)
667
- if match:
668
- # Append "KAAI" followed by the number if a match is found
669
- text += f" KAAI {match.group(1)}"
670
- return text
671
-
672
-
673
169
  def remove_none_values(d):
674
170
  if isinstance(d, dict):
675
171
  # Create a new dictionary to store non-None values
@@ -731,25 +227,6 @@ def convert_invoice_type(data_field_value, params):
731
227
  return None
732
228
 
733
229
 
734
- def validate_reverse_charge_value(reverse_charge_sentence_value):
735
- """
736
- Validates the reverseChargeSentence value before assigning to line items.
737
-
738
- Args:
739
- reverse_charge_sentence_value (bool): The formatted value of reverseChargeSentence (True or False).
740
-
741
- Returns:
742
- bool: The validated reverseChargeSentence value.
743
- """
744
- if isinstance(reverse_charge_sentence_value, bool):
745
- return reverse_charge_sentence_value
746
- else:
747
- logger.warning(
748
- f"Invalid reverseChargeSentence value: {reverse_charge_sentence_value}. Defaulting to False."
749
- )
750
- return False
751
-
752
-
753
230
  # Function to create KVP dictionary using apply method
754
231
  def create_kvp_dictionary(df_raw: pd.DataFrame):
755
232
  """Create a key-value pair dictionary from the given DataFrame.
@@ -887,47 +364,37 @@ def clean_item_description(lineitem: str, remove_numbers: bool = True):
887
364
  return re.sub(r"\s{2,}", " ", lineitem).strip()
888
365
 
889
366
 
890
- async def format_label(
891
- entity_k, entity_value, embed_manager, document_type_code, params
892
- ):
367
+ async def format_label(entity_k, entity_value, document_type_code, params):
893
368
  llm_client = params["LlmClient"]
894
369
  if isinstance(entity_value, dict): # if it's a nested entity
895
370
  format_tasks = [
896
- format_label(sub_k, sub_v, embed_manager, document_type_code, params)
371
+ format_label(sub_k, sub_v, document_type_code, params)
897
372
  for sub_k, sub_v in entity_value.items()
898
373
  ]
899
374
  return entity_k, {k: v for k, v in await asyncio.gather(*format_tasks)}
900
375
  if isinstance(entity_value, list):
901
376
  format_tasks = await asyncio.gather(
902
377
  *[
903
- format_label(entity_k, sub_v, embed_manager, document_type_code, params)
378
+ format_label(entity_k, sub_v, document_type_code, params)
904
379
  for sub_v in entity_value
905
380
  ]
906
381
  )
907
382
  return entity_k, [v for _, v in format_tasks]
908
383
  entity_key = entity_k.lower()
909
- embeddings_dict = embed_manager.embeddings_dict
910
384
  formatted_value = None
911
385
 
912
386
  if entity_key.startswith("port"):
913
- formatted_value = await get_port_code_ai(
914
- entity_value, llm_client, embed_manager, *embeddings_dict["ports"]
915
- )
387
+ formatted_value = await get_port_code_ai(entity_value, llm_client)
388
+
916
389
  elif (entity_key == "containertype") or (entity_key == "containersize"):
917
- formatted_value = embed_manager._find_most_similar_option(
918
- "container type " + entity_value,
919
- *embeddings_dict["container_types"],
920
- )
390
+ formatted_value = get_tms_mappings(entity_value, "container_types")
391
+
921
392
  elif check_formatting_rule(entity_k, document_type_code, "terminal"):
922
- formatted_value = embed_manager._find_most_similar_option(
923
- "shipping terminal " + str(entity_value),
924
- *embeddings_dict["terminals"],
925
- )
393
+ formatted_value = get_tms_mappings(entity_value, "terminals")
394
+
926
395
  elif check_formatting_rule(entity_k, document_type_code, "depot"):
927
- formatted_value = embed_manager._find_most_similar_option(
928
- "depot " + str(entity_value),
929
- *embeddings_dict["depots"],
930
- )
396
+ formatted_value = get_tms_mappings(entity_value, "depots")
397
+
931
398
  elif entity_key.startswith(("eta", "etd", "duedate", "issuedate", "servicedate")):
932
399
  try:
933
400
  cleaned_data_field_value = clean_date_string(entity_value)
@@ -1002,15 +469,11 @@ async def format_label(
1002
469
  return entity_k, result
1003
470
 
1004
471
 
1005
- async def get_port_code_ai(
1006
- port: str, llm_client, embed_manager, port_ids, port_embeddings
1007
- ):
472
+ async def get_port_code_ai(port: str, llm_client):
473
+ """Get port code using AI model."""
1008
474
  port_llm = await get_port_code_llm(port, llm_client)
1009
475
 
1010
- if port_llm in port_ids:
1011
- return port_llm
1012
- port_text = f"port for shipping {port}"
1013
- return embed_manager._find_most_similar_option(port_text, port_ids, port_embeddings)
476
+ return get_tms_mappings(port, "ports", port_llm)
1014
477
 
1015
478
 
1016
479
  async def get_port_code_llm(port: str, llm_client):
@@ -1081,7 +544,7 @@ def decimal_convertor(value):
1081
544
  return value
1082
545
 
1083
546
 
1084
- async def format_all_entities(result, embed_manager, document_type_code, params):
547
+ async def format_all_entities(result, document_type_code, params):
1085
548
  """Format the entity values in the result dictionary."""
1086
549
  # Since we treat `customsInvoice` same as `partnerInvoice`
1087
550
  document_type_code = (
@@ -1096,15 +559,11 @@ async def format_all_entities(result, embed_manager, document_type_code, params)
1096
559
  return {}
1097
560
 
1098
561
  # Format all entities recursively
1099
- _, aggregated_data = await format_label(
1100
- None, result, embed_manager, document_type_code, params
1101
- )
562
+ _, aggregated_data = await format_label(None, result, document_type_code, params)
1102
563
 
1103
564
  # Process partner invoice on lineitem mapping and reverse charge sentence
1104
565
  if document_type_code in ["partnerInvoice", "bundeskasse"]:
1105
- process_partner_invoice(
1106
- params, aggregated_data, embed_manager, document_type_code
1107
- )
566
+ process_partner_invoice(params, aggregated_data, document_type_code)
1108
567
 
1109
568
  logger.info("Data Extraction completed successfully")
1110
569
  return aggregated_data
src/postprocessing/postprocess_partner_invoice.py CHANGED
@@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
4
4
  from fuzzywuzzy import fuzz
5
5
 
6
6
  from src.io import logger
7
+ from src.utils import get_tms_mappings
7
8
 
8
9
 
9
10
  def postprocessing_partner_invoice(partner_invoice):
@@ -135,7 +136,7 @@ def update_recipient_and_vendor(aggregated_data, is_recipient_forto):
135
136
  ] = "Dasbachstraße 15, 54292 Trier, Germany"
136
137
 
137
138
 
138
- def process_partner_invoice(params, aggregated_data, embed_manager, document_type_code):
139
+ def process_partner_invoice(params, aggregated_data, document_type_code):
139
140
  """Process the partner invoice data."""
140
141
  # Post process containerNumber.
141
142
  # TODO: Remove this block of code after migrating to LLM completely and update the placeholder in the prompt library
@@ -192,7 +193,6 @@ def process_partner_invoice(params, aggregated_data, embed_manager, document_typ
192
193
  for line_item in line_items:
193
194
  if line_item.get("lineItemDescription", None) is not None:
194
195
  line_item["itemCode"] = associate_forto_item_code(
195
- embed_manager,
196
196
  line_item["lineItemDescription"]["formattedValue"],
197
197
  params,
198
198
  )
@@ -275,7 +275,7 @@ def find_matching_lineitem(new_lineitem: str, kvp_dict: dict, threshold=90):
275
275
  return kvp_dict.get(best_match, None)
276
276
 
277
277
 
278
- def associate_forto_item_code(embed_manager, input_string, params):
278
+ def associate_forto_item_code(input_string, params):
279
279
  """
280
280
  Finds a match for the input string using fuzzy matching first, then embedding fallback.
281
281
 
@@ -286,7 +286,6 @@ def associate_forto_item_code(embed_manager, input_string, params):
286
286
 
287
287
  Args:
288
288
  input_string: The string to find a match for.
289
- embed_manager: The embedding manager instance to use for fallback.
290
289
  params: Parameters containing the lookup data and fuzzy threshold.
291
290
 
292
291
  Returns:
@@ -301,10 +300,11 @@ def associate_forto_item_code(embed_manager, input_string, params):
301
300
 
302
301
  if forto_item_code is None:
303
302
  # 2. Fallback to embedding function if no good fuzzy match
304
- embeddings_dict = embed_manager.embeddings_dict
305
- forto_item_code = embed_manager._find_most_similar_option(
306
- input_string, *embeddings_dict["item_codes_label"]
307
- )
303
+ forto_item_code = get_tms_mappings(input_string, "line_items")
304
+ # embeddings_dict = embed_manager.embeddings_dict
305
+ # forto_item_code = embed_manager._find_most_similar_option(
306
+ # input_string, *embeddings_dict["item_codes_label"]
307
+ # )
308
308
 
309
309
  result = {"documentValue": input_string, "formattedValue": forto_item_code}
310
310
  return result
src/setup.py CHANGED
@@ -1,5 +1,4 @@
1
1
  """Contains project setup parameters and initialization functions."""
2
- import argparse
3
2
  import json
4
3
 
5
4
  # import streamlit as st
@@ -69,50 +68,6 @@ def get_docai_schema_client(params, async_=True):
69
68
  return client
70
69
 
71
70
 
72
- def parse_input():
73
- """Manage input parameters."""
74
- parser = argparse.ArgumentParser(description="", add_help=False)
75
- parser.add_argument(
76
- "--scope",
77
- type=str,
78
- dest="scope",
79
- required=False,
80
- help="Whether the function should 'upload' or 'download' documents",
81
- )
82
- parser.add_argument(
83
- "--document_name",
84
- type=str,
85
- dest="document_name",
86
- required=False,
87
- help="Category of the document (e.g., 'commercialInvoice', 'packingList')",
88
- )
89
- parser.add_argument(
90
- "--for_combinations",
91
- type=bool,
92
- default=False,
93
- dest="for_combinations",
94
- required=False,
95
- help="A flag to download documents into a special subfolder",
96
- )
97
- parser.add_argument(
98
- "--n_samples",
99
- type=int,
100
- default=50,
101
- dest="n_samples",
102
- required=False,
103
- help="A number of samples to download",
104
- )
105
-
106
- # Remove declared missing arguments (e.g. model_type)
107
- args = vars(parser.parse_args())
108
- args_no_null = {
109
- k: v.split(",") if isinstance(v, str) else v
110
- for k, v in args.items()
111
- if v is not None
112
- }
113
- return args_no_null
114
-
115
-
116
71
  def setup_params(args=None):
117
72
  """
118
73
  Set up the application parameters.
src/utils.py CHANGED
@@ -10,6 +10,7 @@ from typing import Literal
10
10
 
11
11
  import openpyxl
12
12
  import pandas as pd
13
+ import requests
13
14
  from google.cloud import documentai_v1beta3 as docu_ai_beta
14
15
  from PyPDF2 import PdfReader, PdfWriter
15
16
 
@@ -364,3 +365,50 @@ def extract_top_pages(pdf_bytes, num_pages=4):
364
365
  writer.write(output)
365
366
 
366
367
  return output.getvalue()
368
+
369
+
370
+ def get_tms_mappings(
371
+ input_list: list[str], embedding_type: str, llm_ports: list[str] = None
372
+ ):
373
+ """Get TMS mappings for the given values.
374
+
375
+ Args:
376
+ input_list (list[str]): List of strings to get embeddings for.
377
+ embedding_type (str): Type of embedding to use
378
+ (e.g., "container_types", "ports", "depots", "lineitems", "terminals").
379
+ llm_ports (list[str], optional): List of LLM ports to use. Defaults to None.
380
+
381
+ Returns:
382
+ dict: A dictionary with the mapping results.
383
+ """
384
+ # To test the API locally, port-forward the embedding service in the sandbox to 8080:80
385
+ # If you want to launch uvicorn from the tms-embedding repo, then use --port 8080 in the config file
386
+ base_url = (
387
+ "http://0.0.0.0:8080/"
388
+ if os.getenv("CLUSTER") is None
389
+ else "http://tms-embedding:80/api/v1/mappings"
390
+ )
391
+
392
+ # Ensure input_list is a list
393
+ if not isinstance(input_list, list):
394
+ input_list = [input_list]
395
+
396
+ # Always send a dict with named keys
397
+ payload = {embedding_type: input_list}
398
+ if llm_ports:
399
+ payload["llm_ports"] = llm_ports if isinstance(llm_ports, list) else [llm_ports]
400
+
401
+ # Make the POST request to the TMS mappings API
402
+ url = f"{base_url}/{embedding_type}"
403
+ response = requests.post(url=url, json=payload)
404
+
405
+ if response.status_code != 200:
406
+ logger.error(
407
+ f"Error from TMS mappings API: {response.status_code} - {response.text}"
408
+ )
409
+
410
+ formatted_values = (
411
+ response.json().get("response", {}).get("data", {}).get(input_list[0], None)
412
+ )
413
+
414
+ return formatted_values