nucliadb 6.4.2.post4376__py3-none-any.whl → 6.4.2.post4378__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,14 +18,16 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
 
21
- from typing import AsyncGenerator, Optional
21
+ import asyncio
22
+ from typing import AsyncGenerator, AsyncIterable, Optional
22
23
 
23
- from nidx_protos.nodereader_pb2 import StreamRequest
24
+ from nidx_protos.nodereader_pb2 import DocumentItem, StreamRequest
24
25
 
25
26
  from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
26
27
  from nucliadb.common.nidx import get_nidx_searcher_client
27
28
  from nucliadb.train import logger
28
29
  from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
30
+ from nucliadb.train.settings import settings
29
31
  from nucliadb_protos.dataset_pb2 import (
30
32
  FieldSplitData,
31
33
  FieldStreamingBatch,
@@ -50,82 +52,38 @@ async def generate_field_streaming_payloads(
50
52
  trainset: TrainSet,
51
53
  shard_replica_id: str,
52
54
  ) -> AsyncGenerator[FieldSplitData, None]:
53
- # Query how many resources has each label
54
55
  request = StreamRequest()
55
56
  request.shard_id.id = shard_replica_id
56
57
 
57
58
  for label in trainset.filter.labels:
58
59
  request.filter.labels.append(f"/l/{label}")
59
-
60
60
  for path in trainset.filter.paths:
61
61
  request.filter.labels.append(f"/p/{path}")
62
-
63
62
  for metadata in trainset.filter.metadata:
64
63
  request.filter.labels.append(f"/m/{metadata}")
65
-
66
64
  for entity in trainset.filter.entities:
67
65
  request.filter.labels.append(f"/e/{entity}")
68
-
69
66
  for field in trainset.filter.fields:
70
67
  request.filter.labels.append(f"/f/{field}")
71
-
72
68
  for status in trainset.filter.status:
73
69
  request.filter.labels.append(f"/n/s/{status}")
74
70
 
75
71
  resources = set()
76
72
  fields = set()
77
73
 
78
- async for document_item in get_nidx_searcher_client().Documents(request):
79
- text_labels = []
80
- for label in document_item.labels:
81
- text_labels.append(label)
82
-
83
- field_id = f"{document_item.uuid}{document_item.field}"
84
- resources.add(document_item.uuid)
85
-
86
- field_parts = document_item.field.split("/")
87
- if len(field_parts) == 3:
88
- _, field_type, field = field_parts
89
- split = "0"
90
- elif len(field_parts) == 4:
91
- _, field_type, field, split = field_parts
92
- else:
93
- raise Exception(f"Invalid field definition {document_item.field}")
94
-
95
- tl = FieldSplitData()
96
- rid, field_type, field = field_id.split("/")
97
- tl.rid = document_item.uuid
98
- tl.field = field
99
- tl.field_type = field_type
100
- tl.split = split
101
-
102
- field_unique_key = f"{rid}/{field_type}/{field}/{split}"
74
+ async for fsd in iter_field_split_data(
75
+ request, kbid, trainset, max_parallel=settings.field_streaming_parallelisation
76
+ ):
77
+ resources.add(fsd.rid)
78
+ field_unique_key = f"{fsd.rid}/{fsd.field_type}/{fsd.field}/{fsd.split}"
103
79
  if field_unique_key in fields:
104
80
  # This field has already been yielded. This can happen as we are streaming directly from nidx
105
81
  # and field deletions may not be reflected immediately in the index.
106
82
  logger.warning(f"Duplicated field found {field_unique_key}. Skipping.", extra={"kbid": kbid})
107
83
  continue
108
-
109
84
  fields.add(field_unique_key)
110
85
 
111
- if trainset.exclude_text:
112
- tl.text.text = ""
113
- else:
114
- extracted = await get_field_text(kbid, rid, field, field_type)
115
- if extracted is not None:
116
- tl.text.CopyFrom(extracted)
117
-
118
- metadata_obj = await get_field_metadata(kbid, rid, field, field_type)
119
- if metadata_obj is not None:
120
- tl.metadata.CopyFrom(metadata_obj)
121
-
122
- basic = await get_field_basic(kbid, rid, field, field_type)
123
- if basic is not None:
124
- tl.basic.CopyFrom(basic)
125
-
126
- tl.labels.extend(text_labels)
127
-
128
- yield tl
86
+ yield fsd
129
87
 
130
88
  if len(fields) % 1000 == 0:
131
89
  logger.info(
@@ -149,6 +107,72 @@ async def generate_field_streaming_payloads(
149
107
  )
150
108
 
151
109
 
110
async def iter_field_split_data(
    request: StreamRequest, kbid: str, trainset: TrainSet, max_parallel: int = 5
) -> AsyncIterable[FieldSplitData]:
    """Stream documents from nidx and yield a FieldSplitData for each one.

    Every document item returned by the searcher's ``Documents`` stream is
    handed to ``fetch_field_split_data`` in its own asyncio task. Tasks are
    accumulated until ``max_parallel`` of them exist, then awaited together
    and their results yielded in the order the tasks were created. Whatever
    is left when the stream ends is flushed the same way.

    Args:
        request: stream request sent to the nidx searcher.
        kbid: knowledge box id, forwarded to ``fetch_field_split_data``.
        trainset: trainset definition, forwarded to ``fetch_field_split_data``.
        max_parallel: number of fetch tasks batched before awaiting them.

    Raises:
        Any exception raised by the document stream or by a fetch task.
    """
    tasks: list[asyncio.Task] = []
    try:
        async for document_item in get_nidx_searcher_client().Documents(request):
            if len(tasks) >= max_parallel:
                # Batch is full: drain it before scheduling more work.
                for fsd in await asyncio.gather(*tasks):
                    yield fsd
                tasks.clear()
            tasks.append(asyncio.create_task(fetch_field_split_data(document_item, kbid, trainset)))
        if tasks:
            # Flush the last, possibly partial, batch.
            for fsd in await asyncio.gather(*tasks):
                yield fsd
            tasks.clear()
    finally:
        # `tasks` is only non-empty here when we are leaving abnormally (a
        # fetch failed, the stream failed, or the consumer stopped early /
        # was cancelled). gather() does not cancel sibling tasks on failure,
        # so do it explicitly and await them so that nothing keeps running in
        # the background and no task exception is left unretrieved.
        for task in tasks:
            if not task.done():
                task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
126
+
127
+
128
async def fetch_field_split_data(
    document_item: DocumentItem, kbid: str, trainset: TrainSet
) -> FieldSplitData:
    """Build the FieldSplitData payload for a single streamed document item.

    ``document_item.field`` is split on ``/`` and must produce either three
    parts (``<ignored>/<field_type>/<field>``, in which case the split
    defaults to ``"0"``) or four parts (the fourth being the split).

    The extracted text (unless ``trainset.exclude_text`` is set), the field
    metadata and the basic information are fetched concurrently and copied
    into the payload; the labels of the document item are appended last.

    Args:
        document_item: item returned by the nidx ``Documents`` stream.
        kbid: knowledge box id used to load the resource data.
        trainset: trainset definition; only ``exclude_text`` is read here.

    Returns:
        The populated FieldSplitData message.

    Raises:
        Exception: if ``document_item.field`` has neither 3 nor 4 parts.
    """
    field_parts = document_item.field.split("/")
    if len(field_parts) == 3:
        _, field_type, field = field_parts
        split = "0"
    elif len(field_parts) == 4:
        _, field_type, field, split = field_parts
    else:
        raise Exception(f"Invalid field definition {document_item.field}")

    # NOTE: field_type and field come straight from the parse above. They
    # used to be re-derived by splitting "<uuid><field>" into exactly three
    # parts, which raised ValueError for every field that carries a split.
    fsd = FieldSplitData()
    fsd.rid = document_item.uuid
    fsd.field = field
    fsd.field_type = field_type
    fsd.split = split

    fetches = []
    if trainset.exclude_text:
        fsd.text.text = ""
    else:
        fetches.append(_fetch_field_extracted_text(kbid, fsd))
    fetches.append(_fetch_field_metadata(kbid, fsd))
    fetches.append(_fetch_basic(kbid, fsd))
    # Each helper fills a different sub-message of `fsd`, so they can run
    # concurrently.
    await asyncio.gather(*fetches)

    fsd.labels.extend(document_item.labels)
    return fsd
156
+
157
+
158
async def _fetch_field_extracted_text(kbid: str, fsd: FieldSplitData):
    """Copy the field's extracted text into ``fsd.text`` when there is one."""
    text = await get_field_text(kbid, fsd.rid, fsd.field, fsd.field_type)
    if text is None:
        # Nothing to copy: leave fsd.text untouched.
        return
    fsd.text.CopyFrom(text)
162
+
163
+
164
async def _fetch_field_metadata(kbid: str, fsd: FieldSplitData):
    """Copy the field's metadata into ``fsd.metadata`` when there is one."""
    field_metadata = await get_field_metadata(kbid, fsd.rid, fsd.field, fsd.field_type)
    if field_metadata is None:
        # Nothing to copy: leave fsd.metadata untouched.
        return
    fsd.metadata.CopyFrom(field_metadata)
168
+
169
+
170
async def _fetch_basic(kbid: str, fsd: FieldSplitData):
    """Copy the basic information into ``fsd.basic`` when there is one."""
    basic_info = await get_field_basic(kbid, fsd.rid, fsd.field, fsd.field_type)
    if basic_info is None:
        # Nothing to copy: leave fsd.basic untouched.
        return
    fsd.basic.CopyFrom(basic_info)
174
+
175
+
152
176
  async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
153
177
  orm_resource = await get_resource_from_cache_or_db(kbid, rid)
154
178
 
@@ -34,6 +34,7 @@ class Settings(DriverSettings):
34
34
  internal_search_api: str = "http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/search"
35
35
 
36
36
  resource_cache_size: int = 2
37
+ field_streaming_parallelisation: int = 5
37
38
 
38
39
 
39
40
  settings = Settings()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nucliadb
3
- Version: 6.4.2.post4376
3
+ Version: 6.4.2.post4378
4
4
  Summary: NucliaDB
5
5
  Author-email: Nuclia <nucliadb@nuclia.com>
6
6
  License-Expression: AGPL-3.0-or-later
@@ -19,11 +19,11 @@ Classifier: Programming Language :: Python :: 3.12
19
19
  Classifier: Programming Language :: Python :: 3 :: Only
20
20
  Requires-Python: <4,>=3.9
21
21
  Description-Content-Type: text/markdown
22
- Requires-Dist: nucliadb-telemetry[all]>=6.4.2.post4376
23
- Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.2.post4376
24
- Requires-Dist: nucliadb-protos>=6.4.2.post4376
25
- Requires-Dist: nucliadb-models>=6.4.2.post4376
26
- Requires-Dist: nidx-protos>=6.4.2.post4376
22
+ Requires-Dist: nucliadb-telemetry[all]>=6.4.2.post4378
23
+ Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.2.post4378
24
+ Requires-Dist: nucliadb-protos>=6.4.2.post4378
25
+ Requires-Dist: nucliadb-models>=6.4.2.post4378
26
+ Requires-Dist: nidx-protos>=6.4.2.post4378
27
27
  Requires-Dist: nucliadb-admin-assets>=1.0.0.post1224
28
28
  Requires-Dist: nuclia-models>=0.24.2
29
29
  Requires-Dist: uvicorn[standard]
@@ -309,7 +309,7 @@ nucliadb/train/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
309
309
  nucliadb/train/resource.py,sha256=3qQ_9Zdt5JAbtD-wpmt7OeDGRNKS-fQdKAuIQfznZm0,16219
310
310
  nucliadb/train/run.py,sha256=evz6CKVfJOzkbHMoaYz2mTMlKjJnNOb1O8zBBWMpeBw,1400
311
311
  nucliadb/train/servicer.py,sha256=scbmq8FriKsJGkOcoZB2Fg_IyIExn9Ux4W30mGDlkJQ,5728
312
- nucliadb/train/settings.py,sha256=8_-XCO_nrE98cMJpe0fYkXeF2bkWKff1VX-2qdwcsjs,1417
312
+ nucliadb/train/settings.py,sha256=Vz-bQxwxYg6Qhc8Vnap95AwlYyCE1LF7NCPlLBfToXI,1462
313
313
  nucliadb/train/types.py,sha256=xyVYy8kHipAWoDb7Pn7dCYQ_efHPzDW_3AXg5M-aV28,1519
314
314
  nucliadb/train/upload.py,sha256=fTjH1KEL-0ogf3LV0T6ODO0QdPGwdZShSUtFUCAcUlA,3256
315
315
  nucliadb/train/uploader.py,sha256=xdLGz1ToDue9Q_M8A-_KYkO-V6fWKYOZQ6IGM4FuwWA,6424
@@ -322,7 +322,7 @@ nucliadb/train/api/v1/shards.py,sha256=GJRnQe8P-7_VTIN1oxVmxlrDA08qVN7opEZdbF4Wx
322
322
  nucliadb/train/api/v1/trainset.py,sha256=kpnpDgiMWr1FKHZJgwH7hue5kzilA8-i9X0YHlNeHuU,2113
323
323
  nucliadb/train/generators/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
324
324
  nucliadb/train/generators/field_classifier.py,sha256=xUA10o9CtBtilbP3uc-8Wn_zQ0oK3BrqYGqZgxh4ZLk,3428
325
- nucliadb/train/generators/field_streaming.py,sha256=nje317SutX8QmHq-xwUphzUiozmzpCRfPXxhF_jFzdg,6441
325
+ nucliadb/train/generators/field_streaming.py,sha256=fq4XxHf5fPLccPjO722tA9Xcw6smmceVqSme0fY2_NA,7631
326
326
  nucliadb/train/generators/image_classifier.py,sha256=BDXgyd5TGZRnzDnVRvp-qsRCuoTbTYwui3JiDIjuiDc,1736
327
327
  nucliadb/train/generators/paragraph_classifier.py,sha256=4sH3IQc7yJrlDs1C76SxFzL9N5mXWRZzJzoiF7y4dSQ,2703
328
328
  nucliadb/train/generators/paragraph_streaming.py,sha256=1xsc_IqP-1M0TzYTqu5qCvWBNp_J3Kyvnx8HVbToXmQ,3532
@@ -368,8 +368,8 @@ nucliadb/writer/tus/local.py,sha256=7jYa_w9b-N90jWgN2sQKkNcomqn6JMVBOVeDOVYJHto,
368
368
  nucliadb/writer/tus/s3.py,sha256=vF0NkFTXiXhXq3bCVXXVV-ED38ECVoUeeYViP8uMqcU,8357
369
369
  nucliadb/writer/tus/storage.py,sha256=ToqwjoYnjI4oIcwzkhha_MPxi-k4Jk3Lt55zRwaC1SM,2903
370
370
  nucliadb/writer/tus/utils.py,sha256=MSdVbRsRSZVdkaum69_0wku7X3p5wlZf4nr6E0GMKbw,2556
371
- nucliadb-6.4.2.post4376.dist-info/METADATA,sha256=FWD9EeIPDvZWgRX4R33CaHbOrlk2YYzGCtrnMGaqR0s,4152
372
- nucliadb-6.4.2.post4376.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
373
- nucliadb-6.4.2.post4376.dist-info/entry_points.txt,sha256=XqGfgFDuY3zXQc8ewXM2TRVjTModIq851zOsgrmaXx4,1268
374
- nucliadb-6.4.2.post4376.dist-info/top_level.txt,sha256=hwYhTVnX7jkQ9gJCkVrbqEG1M4lT2F_iPQND1fCzF80,20
375
- nucliadb-6.4.2.post4376.dist-info/RECORD,,
371
+ nucliadb-6.4.2.post4378.dist-info/METADATA,sha256=W_SC_iY4NnfaM04p0SUiwwdGvfgm4eDKNVNaWXU2mp8,4152
372
+ nucliadb-6.4.2.post4378.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
373
+ nucliadb-6.4.2.post4378.dist-info/entry_points.txt,sha256=XqGfgFDuY3zXQc8ewXM2TRVjTModIq851zOsgrmaXx4,1268
374
+ nucliadb-6.4.2.post4378.dist-info/top_level.txt,sha256=hwYhTVnX7jkQ9gJCkVrbqEG1M4lT2F_iPQND1fCzF80,20
375
+ nucliadb-6.4.2.post4378.dist-info/RECORD,,