nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (418)
  1. migrations/0003_allfields_key.py +1 -35
  2. migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
  3. migrations/0010_fix_corrupt_indexes.py +10 -10
  4. migrations/0011_materialize_labelset_ids.py +1 -16
  5. migrations/0012_rollover_shards.py +5 -10
  6. migrations/0014_rollover_shards.py +4 -5
  7. migrations/0015_targeted_rollover.py +5 -10
  8. migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
  9. migrations/0017_multiple_writable_shards.py +2 -4
  10. migrations/0018_purge_orphan_kbslugs.py +5 -7
  11. migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
  12. migrations/0020_drain_nodes_from_cluster.py +3 -3
  13. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
  14. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  15. migrations/0023_backfill_pg_catalog.py +80 -0
  16. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  17. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  18. migrations/0027_rollover_texts3.py +73 -0
  19. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  20. migrations/pg/0002_catalog.py +42 -0
  21. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  22. nucliadb/common/cluster/base.py +30 -16
  23. nucliadb/common/cluster/discovery/base.py +6 -14
  24. nucliadb/common/cluster/discovery/k8s.py +9 -19
  25. nucliadb/common/cluster/discovery/manual.py +1 -3
  26. nucliadb/common/cluster/discovery/utils.py +1 -3
  27. nucliadb/common/cluster/grpc_node_dummy.py +3 -11
  28. nucliadb/common/cluster/index_node.py +10 -19
  29. nucliadb/common/cluster/manager.py +174 -59
  30. nucliadb/common/cluster/rebalance.py +27 -29
  31. nucliadb/common/cluster/rollover.py +353 -194
  32. nucliadb/common/cluster/settings.py +6 -0
  33. nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
  34. nucliadb/common/cluster/standalone/index_node.py +4 -11
  35. nucliadb/common/cluster/standalone/service.py +2 -6
  36. nucliadb/common/cluster/standalone/utils.py +2 -6
  37. nucliadb/common/cluster/utils.py +29 -22
  38. nucliadb/common/constants.py +20 -0
  39. nucliadb/common/context/__init__.py +3 -0
  40. nucliadb/common/context/fastapi.py +8 -5
  41. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  42. nucliadb/common/datamanagers/__init__.py +7 -1
  43. nucliadb/common/datamanagers/atomic.py +22 -4
  44. nucliadb/common/datamanagers/cluster.py +5 -5
  45. nucliadb/common/datamanagers/entities.py +6 -16
  46. nucliadb/common/datamanagers/fields.py +84 -0
  47. nucliadb/common/datamanagers/kb.py +83 -37
  48. nucliadb/common/datamanagers/labels.py +26 -56
  49. nucliadb/common/datamanagers/processing.py +2 -6
  50. nucliadb/common/datamanagers/resources.py +41 -103
  51. nucliadb/common/datamanagers/rollover.py +76 -15
  52. nucliadb/common/datamanagers/synonyms.py +1 -1
  53. nucliadb/common/datamanagers/utils.py +15 -6
  54. nucliadb/common/datamanagers/vectorsets.py +110 -0
  55. nucliadb/common/external_index_providers/base.py +257 -0
  56. nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
  57. nucliadb/common/external_index_providers/manager.py +101 -0
  58. nucliadb/common/external_index_providers/pinecone.py +933 -0
  59. nucliadb/common/external_index_providers/settings.py +52 -0
  60. nucliadb/common/http_clients/auth.py +3 -6
  61. nucliadb/common/http_clients/processing.py +6 -11
  62. nucliadb/common/http_clients/utils.py +1 -3
  63. nucliadb/common/ids.py +240 -0
  64. nucliadb/common/locking.py +29 -7
  65. nucliadb/common/maindb/driver.py +11 -35
  66. nucliadb/common/maindb/exceptions.py +3 -0
  67. nucliadb/common/maindb/local.py +22 -9
  68. nucliadb/common/maindb/pg.py +206 -111
  69. nucliadb/common/maindb/utils.py +11 -42
  70. nucliadb/common/models_utils/from_proto.py +479 -0
  71. nucliadb/common/models_utils/to_proto.py +60 -0
  72. nucliadb/common/nidx.py +260 -0
  73. nucliadb/export_import/datamanager.py +25 -19
  74. nucliadb/export_import/exporter.py +5 -11
  75. nucliadb/export_import/importer.py +5 -7
  76. nucliadb/export_import/models.py +3 -3
  77. nucliadb/export_import/tasks.py +4 -4
  78. nucliadb/export_import/utils.py +25 -37
  79. nucliadb/health.py +1 -3
  80. nucliadb/ingest/app.py +15 -11
  81. nucliadb/ingest/consumer/auditing.py +21 -19
  82. nucliadb/ingest/consumer/consumer.py +82 -47
  83. nucliadb/ingest/consumer/materializer.py +5 -12
  84. nucliadb/ingest/consumer/pull.py +12 -27
  85. nucliadb/ingest/consumer/service.py +19 -17
  86. nucliadb/ingest/consumer/shard_creator.py +2 -4
  87. nucliadb/ingest/consumer/utils.py +1 -3
  88. nucliadb/ingest/fields/base.py +137 -105
  89. nucliadb/ingest/fields/conversation.py +18 -5
  90. nucliadb/ingest/fields/exceptions.py +1 -4
  91. nucliadb/ingest/fields/file.py +7 -16
  92. nucliadb/ingest/fields/link.py +5 -10
  93. nucliadb/ingest/fields/text.py +9 -4
  94. nucliadb/ingest/orm/brain.py +200 -213
  95. nucliadb/ingest/orm/broker_message.py +181 -0
  96. nucliadb/ingest/orm/entities.py +36 -51
  97. nucliadb/ingest/orm/exceptions.py +12 -0
  98. nucliadb/ingest/orm/knowledgebox.py +322 -197
  99. nucliadb/ingest/orm/processor/__init__.py +2 -700
  100. nucliadb/ingest/orm/processor/auditing.py +4 -23
  101. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  102. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  103. nucliadb/ingest/orm/processor/processor.py +752 -0
  104. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  105. nucliadb/ingest/orm/resource.py +249 -402
  106. nucliadb/ingest/orm/utils.py +4 -4
  107. nucliadb/ingest/partitions.py +3 -9
  108. nucliadb/ingest/processing.py +64 -73
  109. nucliadb/ingest/py.typed +0 -0
  110. nucliadb/ingest/serialize.py +37 -167
  111. nucliadb/ingest/service/__init__.py +1 -3
  112. nucliadb/ingest/service/writer.py +185 -412
  113. nucliadb/ingest/settings.py +10 -20
  114. nucliadb/ingest/utils.py +3 -6
  115. nucliadb/learning_proxy.py +242 -55
  116. nucliadb/metrics_exporter.py +30 -19
  117. nucliadb/middleware/__init__.py +1 -3
  118. nucliadb/migrator/command.py +1 -3
  119. nucliadb/migrator/datamanager.py +13 -13
  120. nucliadb/migrator/migrator.py +47 -30
  121. nucliadb/migrator/utils.py +18 -10
  122. nucliadb/purge/__init__.py +139 -33
  123. nucliadb/purge/orphan_shards.py +7 -13
  124. nucliadb/reader/__init__.py +1 -3
  125. nucliadb/reader/api/models.py +1 -12
  126. nucliadb/reader/api/v1/__init__.py +0 -1
  127. nucliadb/reader/api/v1/download.py +21 -88
  128. nucliadb/reader/api/v1/export_import.py +1 -1
  129. nucliadb/reader/api/v1/knowledgebox.py +10 -10
  130. nucliadb/reader/api/v1/learning_config.py +2 -6
  131. nucliadb/reader/api/v1/resource.py +62 -88
  132. nucliadb/reader/api/v1/services.py +64 -83
  133. nucliadb/reader/app.py +12 -29
  134. nucliadb/reader/lifecycle.py +18 -4
  135. nucliadb/reader/py.typed +0 -0
  136. nucliadb/reader/reader/notifications.py +10 -28
  137. nucliadb/search/__init__.py +1 -3
  138. nucliadb/search/api/v1/__init__.py +1 -2
  139. nucliadb/search/api/v1/ask.py +17 -10
  140. nucliadb/search/api/v1/catalog.py +184 -0
  141. nucliadb/search/api/v1/feedback.py +16 -24
  142. nucliadb/search/api/v1/find.py +36 -36
  143. nucliadb/search/api/v1/knowledgebox.py +89 -60
  144. nucliadb/search/api/v1/resource/ask.py +2 -8
  145. nucliadb/search/api/v1/resource/search.py +49 -70
  146. nucliadb/search/api/v1/search.py +44 -210
  147. nucliadb/search/api/v1/suggest.py +39 -54
  148. nucliadb/search/app.py +12 -32
  149. nucliadb/search/lifecycle.py +10 -3
  150. nucliadb/search/predict.py +136 -187
  151. nucliadb/search/py.typed +0 -0
  152. nucliadb/search/requesters/utils.py +25 -58
  153. nucliadb/search/search/cache.py +149 -20
  154. nucliadb/search/search/chat/ask.py +571 -123
  155. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
  156. nucliadb/search/search/chat/images.py +41 -17
  157. nucliadb/search/search/chat/prompt.py +817 -266
  158. nucliadb/search/search/chat/query.py +213 -309
  159. nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
  160. nucliadb/search/search/fetch.py +43 -36
  161. nucliadb/search/search/filters.py +9 -15
  162. nucliadb/search/search/find.py +214 -53
  163. nucliadb/search/search/find_merge.py +408 -391
  164. nucliadb/search/search/hydrator.py +191 -0
  165. nucliadb/search/search/merge.py +187 -223
  166. nucliadb/search/search/metrics.py +73 -2
  167. nucliadb/search/search/paragraphs.py +64 -106
  168. nucliadb/search/search/pgcatalog.py +233 -0
  169. nucliadb/search/search/predict_proxy.py +1 -1
  170. nucliadb/search/search/query.py +305 -150
  171. nucliadb/search/search/query_parser/exceptions.py +22 -0
  172. nucliadb/search/search/query_parser/models.py +101 -0
  173. nucliadb/search/search/query_parser/parser.py +183 -0
  174. nucliadb/search/search/rank_fusion.py +204 -0
  175. nucliadb/search/search/rerankers.py +270 -0
  176. nucliadb/search/search/shards.py +3 -32
  177. nucliadb/search/search/summarize.py +7 -18
  178. nucliadb/search/search/utils.py +27 -4
  179. nucliadb/search/settings.py +15 -1
  180. nucliadb/standalone/api_router.py +4 -10
  181. nucliadb/standalone/app.py +8 -14
  182. nucliadb/standalone/auth.py +7 -21
  183. nucliadb/standalone/config.py +7 -10
  184. nucliadb/standalone/lifecycle.py +26 -25
  185. nucliadb/standalone/migrations.py +1 -3
  186. nucliadb/standalone/purge.py +1 -1
  187. nucliadb/standalone/py.typed +0 -0
  188. nucliadb/standalone/run.py +3 -6
  189. nucliadb/standalone/settings.py +9 -16
  190. nucliadb/standalone/versions.py +15 -5
  191. nucliadb/tasks/consumer.py +8 -12
  192. nucliadb/tasks/producer.py +7 -6
  193. nucliadb/tests/config.py +53 -0
  194. nucliadb/train/__init__.py +1 -3
  195. nucliadb/train/api/utils.py +1 -2
  196. nucliadb/train/api/v1/shards.py +1 -1
  197. nucliadb/train/api/v1/trainset.py +2 -4
  198. nucliadb/train/app.py +10 -31
  199. nucliadb/train/generator.py +10 -19
  200. nucliadb/train/generators/field_classifier.py +7 -19
  201. nucliadb/train/generators/field_streaming.py +156 -0
  202. nucliadb/train/generators/image_classifier.py +12 -18
  203. nucliadb/train/generators/paragraph_classifier.py +5 -9
  204. nucliadb/train/generators/paragraph_streaming.py +6 -9
  205. nucliadb/train/generators/question_answer_streaming.py +19 -20
  206. nucliadb/train/generators/sentence_classifier.py +9 -15
  207. nucliadb/train/generators/token_classifier.py +48 -39
  208. nucliadb/train/generators/utils.py +14 -18
  209. nucliadb/train/lifecycle.py +7 -3
  210. nucliadb/train/nodes.py +23 -32
  211. nucliadb/train/py.typed +0 -0
  212. nucliadb/train/servicer.py +13 -21
  213. nucliadb/train/settings.py +2 -6
  214. nucliadb/train/types.py +13 -10
  215. nucliadb/train/upload.py +3 -6
  216. nucliadb/train/uploader.py +19 -23
  217. nucliadb/train/utils.py +1 -1
  218. nucliadb/writer/__init__.py +1 -3
  219. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  220. nucliadb/writer/api/v1/export_import.py +67 -14
  221. nucliadb/writer/api/v1/field.py +16 -269
  222. nucliadb/writer/api/v1/knowledgebox.py +218 -68
  223. nucliadb/writer/api/v1/resource.py +68 -88
  224. nucliadb/writer/api/v1/services.py +51 -70
  225. nucliadb/writer/api/v1/slug.py +61 -0
  226. nucliadb/writer/api/v1/transaction.py +67 -0
  227. nucliadb/writer/api/v1/upload.py +114 -113
  228. nucliadb/writer/app.py +6 -43
  229. nucliadb/writer/back_pressure.py +16 -38
  230. nucliadb/writer/exceptions.py +0 -4
  231. nucliadb/writer/lifecycle.py +21 -15
  232. nucliadb/writer/py.typed +0 -0
  233. nucliadb/writer/resource/audit.py +2 -1
  234. nucliadb/writer/resource/basic.py +48 -46
  235. nucliadb/writer/resource/field.py +25 -127
  236. nucliadb/writer/resource/origin.py +1 -2
  237. nucliadb/writer/settings.py +6 -2
  238. nucliadb/writer/tus/__init__.py +17 -15
  239. nucliadb/writer/tus/azure.py +111 -0
  240. nucliadb/writer/tus/dm.py +17 -5
  241. nucliadb/writer/tus/exceptions.py +1 -3
  242. nucliadb/writer/tus/gcs.py +49 -84
  243. nucliadb/writer/tus/local.py +21 -37
  244. nucliadb/writer/tus/s3.py +28 -68
  245. nucliadb/writer/tus/storage.py +5 -56
  246. nucliadb/writer/vectorsets.py +125 -0
  247. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  248. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  249. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  250. nucliadb/common/maindb/redis.py +0 -194
  251. nucliadb/common/maindb/tikv.py +0 -433
  252. nucliadb/ingest/fields/layout.py +0 -58
  253. nucliadb/ingest/tests/conftest.py +0 -30
  254. nucliadb/ingest/tests/fixtures.py +0 -764
  255. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  256. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
  257. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
  258. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  259. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  260. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  261. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
  262. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  263. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  264. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  265. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
  266. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  267. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  268. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
  269. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  270. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  271. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  272. nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
  273. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  274. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
  275. nucliadb/ingest/tests/unit/test_cache.py +0 -31
  276. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  277. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  278. nucliadb/middleware/transaction.py +0 -117
  279. nucliadb/reader/api/v1/learning_collector.py +0 -63
  280. nucliadb/reader/tests/__init__.py +0 -19
  281. nucliadb/reader/tests/conftest.py +0 -31
  282. nucliadb/reader/tests/fixtures.py +0 -136
  283. nucliadb/reader/tests/test_list_resources.py +0 -75
  284. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  285. nucliadb/reader/tests/test_reader_resource.py +0 -353
  286. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  287. nucliadb/search/api/v1/chat.py +0 -263
  288. nucliadb/search/api/v1/resource/chat.py +0 -174
  289. nucliadb/search/tests/__init__.py +0 -19
  290. nucliadb/search/tests/conftest.py +0 -33
  291. nucliadb/search/tests/fixtures.py +0 -199
  292. nucliadb/search/tests/node.py +0 -466
  293. nucliadb/search/tests/unit/__init__.py +0 -18
  294. nucliadb/search/tests/unit/api/__init__.py +0 -19
  295. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  296. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  297. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
  298. nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
  299. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  300. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  301. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
  302. nucliadb/search/tests/unit/search/__init__.py +0 -18
  303. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  304. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
  305. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  306. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  307. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  308. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
  309. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  310. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  311. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  312. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  313. nucliadb/search/tests/unit/search/test_query.py +0 -153
  314. nucliadb/search/tests/unit/test_app.py +0 -79
  315. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  316. nucliadb/search/tests/unit/test_merge.py +0 -34
  317. nucliadb/search/tests/unit/test_predict.py +0 -525
  318. nucliadb/standalone/tests/__init__.py +0 -19
  319. nucliadb/standalone/tests/conftest.py +0 -33
  320. nucliadb/standalone/tests/fixtures.py +0 -38
  321. nucliadb/standalone/tests/unit/__init__.py +0 -18
  322. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  323. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  324. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  325. nucliadb/standalone/tests/unit/test_migrations.py +0 -63
  326. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  327. nucliadb/tests/benchmarks/__init__.py +0 -19
  328. nucliadb/tests/benchmarks/test_search.py +0 -99
  329. nucliadb/tests/conftest.py +0 -32
  330. nucliadb/tests/fixtures.py +0 -735
  331. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
  332. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
  333. nucliadb/tests/migrations/test_migration_0017.py +0 -76
  334. nucliadb/tests/migrations/test_migration_0018.py +0 -95
  335. nucliadb/tests/tikv.py +0 -240
  336. nucliadb/tests/unit/__init__.py +0 -19
  337. nucliadb/tests/unit/common/__init__.py +0 -19
  338. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  339. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  340. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
  341. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  342. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
  343. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
  344. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
  345. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
  346. nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
  347. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
  348. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  349. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  350. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  351. nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
  352. nucliadb/tests/unit/common/test_context.py +0 -36
  353. nucliadb/tests/unit/export_import/__init__.py +0 -19
  354. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  355. nucliadb/tests/unit/export_import/test_utils.py +0 -301
  356. nucliadb/tests/unit/migrator/__init__.py +0 -19
  357. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  358. nucliadb/tests/unit/tasks/__init__.py +0 -19
  359. nucliadb/tests/unit/tasks/conftest.py +0 -42
  360. nucliadb/tests/unit/tasks/test_consumer.py +0 -92
  361. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  362. nucliadb/tests/unit/tasks/test_tasks.py +0 -58
  363. nucliadb/tests/unit/test_field_ids.py +0 -49
  364. nucliadb/tests/unit/test_health.py +0 -86
  365. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  366. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  367. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  368. nucliadb/tests/unit/test_purge.py +0 -136
  369. nucliadb/tests/utils/__init__.py +0 -74
  370. nucliadb/tests/utils/aiohttp_session.py +0 -44
  371. nucliadb/tests/utils/broker_messages/__init__.py +0 -171
  372. nucliadb/tests/utils/broker_messages/fields.py +0 -197
  373. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  374. nucliadb/tests/utils/entities.py +0 -78
  375. nucliadb/train/api/v1/check.py +0 -60
  376. nucliadb/train/tests/__init__.py +0 -19
  377. nucliadb/train/tests/conftest.py +0 -29
  378. nucliadb/train/tests/fixtures.py +0 -342
  379. nucliadb/train/tests/test_field_classification.py +0 -122
  380. nucliadb/train/tests/test_get_entities.py +0 -80
  381. nucliadb/train/tests/test_get_info.py +0 -51
  382. nucliadb/train/tests/test_get_ontology.py +0 -34
  383. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  384. nucliadb/train/tests/test_image_classification.py +0 -221
  385. nucliadb/train/tests/test_list_fields.py +0 -39
  386. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  387. nucliadb/train/tests/test_list_resources.py +0 -39
  388. nucliadb/train/tests/test_list_sentences.py +0 -71
  389. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  390. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  391. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  392. nucliadb/train/tests/test_sentence_classification.py +0 -143
  393. nucliadb/train/tests/test_token_classification.py +0 -136
  394. nucliadb/train/tests/utils.py +0 -101
  395. nucliadb/writer/layouts/__init__.py +0 -51
  396. nucliadb/writer/layouts/v1.py +0 -59
  397. nucliadb/writer/tests/__init__.py +0 -19
  398. nucliadb/writer/tests/conftest.py +0 -31
  399. nucliadb/writer/tests/fixtures.py +0 -191
  400. nucliadb/writer/tests/test_fields.py +0 -475
  401. nucliadb/writer/tests/test_files.py +0 -740
  402. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  403. nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
  404. nucliadb/writer/tests/test_resources.py +0 -476
  405. nucliadb/writer/tests/test_service.py +0 -137
  406. nucliadb/writer/tests/test_tus.py +0 -203
  407. nucliadb/writer/tests/utils.py +0 -35
  408. nucliadb/writer/tus/pg.py +0 -125
  409. nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
  410. nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
  411. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  412. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  413. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  414. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  415. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  416. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  417. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  418. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -21,11 +21,11 @@
21
21
  from typing import AsyncIterator, Optional
22
22
 
23
23
  from fastapi import HTTPException
24
- from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
25
24
 
26
25
  from nucliadb.train.generators.field_classifier import (
27
26
  field_classification_batch_generator,
28
27
  )
28
+ from nucliadb.train.generators.field_streaming import field_streaming_batch_generator
29
29
  from nucliadb.train.generators.image_classifier import (
30
30
  image_classification_batch_generator,
31
31
  )
@@ -46,6 +46,7 @@ from nucliadb.train.generators.token_classifier import (
46
46
  )
47
47
  from nucliadb.train.types import TrainBatch
48
48
  from nucliadb.train.utils import get_shard_manager
49
+ from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
49
50
 
50
51
 
51
52
  async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
@@ -59,34 +60,24 @@ async def generate_train_data(kbid: str, shard: str, trainset: TrainSet):
59
60
  batch_generator: Optional[AsyncIterator[TrainBatch]] = None
60
61
 
61
62
  if trainset.type == TaskType.FIELD_CLASSIFICATION:
62
- batch_generator = field_classification_batch_generator(
63
- kbid, trainset, node, shard_replica_id
64
- )
63
+ batch_generator = field_classification_batch_generator(kbid, trainset, node, shard_replica_id)
65
64
  elif trainset.type == TaskType.IMAGE_CLASSIFICATION:
66
- batch_generator = image_classification_batch_generator(
67
- kbid, trainset, node, shard_replica_id
68
- )
65
+ batch_generator = image_classification_batch_generator(kbid, trainset, node, shard_replica_id)
69
66
  elif trainset.type == TaskType.PARAGRAPH_CLASSIFICATION:
70
67
  batch_generator = paragraph_classification_batch_generator(
71
68
  kbid, trainset, node, shard_replica_id
72
69
  )
73
70
  elif trainset.type == TaskType.TOKEN_CLASSIFICATION:
74
- batch_generator = token_classification_batch_generator(
75
- kbid, trainset, node, shard_replica_id
76
- )
71
+ batch_generator = token_classification_batch_generator(kbid, trainset, node, shard_replica_id)
77
72
  elif trainset.type == TaskType.SENTENCE_CLASSIFICATION:
78
- batch_generator = sentence_classification_batch_generator(
79
- kbid, trainset, node, shard_replica_id
80
- )
73
+ batch_generator = sentence_classification_batch_generator(kbid, trainset, node, shard_replica_id)
81
74
  elif trainset.type == TaskType.PARAGRAPH_STREAMING:
82
- batch_generator = paragraph_streaming_batch_generator(
83
- kbid, trainset, node, shard_replica_id
84
- )
75
+ batch_generator = paragraph_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
85
76
 
86
77
  elif trainset.type == TaskType.QUESTION_ANSWER_STREAMING:
87
- batch_generator = question_answer_batch_generator(
88
- kbid, trainset, node, shard_replica_id
89
- )
78
+ batch_generator = question_answer_batch_generator(kbid, trainset, node, shard_replica_id)
79
+ elif trainset.type == TaskType.FIELD_STREAMING:
80
+ batch_generator = field_streaming_batch_generator(kbid, trainset, node, shard_replica_id)
90
81
 
91
82
  if batch_generator is None:
92
83
  raise HTTPException(
@@ -20,7 +20,10 @@
20
20
 
21
21
  from typing import AsyncGenerator
22
22
 
23
- from fastapi import HTTPException
23
+ from nucliadb.common.cluster.base import AbstractIndexNode
24
+ from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
25
+ from nucliadb.train import logger
26
+ from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
24
27
  from nucliadb_protos.dataset_pb2 import (
25
28
  FieldClassificationBatch,
26
29
  Label,
@@ -29,11 +32,6 @@ from nucliadb_protos.dataset_pb2 import (
29
32
  )
30
33
  from nucliadb_protos.nodereader_pb2 import StreamRequest
31
34
 
32
- from nucliadb.common.cluster.base import AbstractIndexNode
33
- from nucliadb.ingest.orm.resource import KB_REVERSE
34
- from nucliadb.train import logger
35
- from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
36
-
37
35
 
38
36
  def field_classification_batch_generator(
39
37
  kbid: str,
@@ -41,15 +39,7 @@ def field_classification_batch_generator(
41
39
  node: AbstractIndexNode,
42
40
  shard_replica_id: str,
43
41
  ) -> AsyncGenerator[FieldClassificationBatch, None]:
44
- if len(trainset.filter.labels) != 1:
45
- raise HTTPException(
46
- status_code=422,
47
- detail="Paragraph Classification should be of 1 labelset",
48
- )
49
-
50
- generator = generate_field_classification_payloads(
51
- kbid, trainset, node, shard_replica_id
52
- )
42
+ generator = generate_field_classification_payloads(kbid, trainset, node, shard_replica_id)
53
43
  batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)
54
44
  return batch_generator
55
45
 
@@ -95,13 +85,11 @@ async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> st
95
85
  logger.error(f"{rid} does not exist on DB")
96
86
  return ""
97
87
 
98
- field_type_int = KB_REVERSE[field_type]
88
+ field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
99
89
  field_obj = await orm_resource.get_field(field, field_type_int, load=False)
100
90
  extracted_text = await field_obj.get_extracted_text()
101
91
  if extracted_text is None:
102
- logger.warning(
103
- f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
104
- )
92
+ logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
105
93
  return ""
106
94
 
107
95
  text = ""
@@ -0,0 +1,156 @@
1
+ # Copyright (C) 2021 Bosutech XXI S.L.
2
+ #
3
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
+ # For commercial licensing, contact us at info@nuclia.com.
5
+ #
6
+ # AGPL:
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU Affero General Public License as
9
+ # published by the Free Software Foundation, either version 3 of the
10
+ # License, or (at your option) any later version.
11
+ #
12
+ # This program is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU Affero General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU Affero General Public License
18
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
+ #
20
+
21
+ from typing import AsyncGenerator, Optional
22
+
23
+ from nucliadb.common.cluster.base import AbstractIndexNode
24
+ from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
25
+ from nucliadb.train import logger
26
+ from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
27
+ from nucliadb_protos.dataset_pb2 import (
28
+ FieldSplitData,
29
+ FieldStreamingBatch,
30
+ TrainSet,
31
+ )
32
+ from nucliadb_protos.nodereader_pb2 import StreamRequest
33
+ from nucliadb_protos.resources_pb2 import Basic, FieldComputedMetadata
34
+ from nucliadb_protos.utils_pb2 import ExtractedText
35
+
36
+
37
def field_streaming_batch_generator(
    kbid: str,
    trainset: TrainSet,
    node: AbstractIndexNode,
    shard_replica_id: str,
) -> AsyncGenerator[FieldStreamingBatch, None]:
    """Return an async generator of FieldStreamingBatch for a shard replica.

    Individual field payloads come from generate_field_streaming_payloads and
    are grouped into batches of trainset.batch_size items.
    """
    payloads = generate_field_streaming_payloads(kbid, trainset, node, shard_replica_id)
    return batchify(payloads, trainset.batch_size, FieldStreamingBatch)
46
+
47
+
48
async def generate_field_streaming_payloads(
    kbid: str,
    trainset: TrainSet,
    node: AbstractIndexNode,
    shard_replica_id: str,
) -> AsyncGenerator[FieldSplitData, None]:
    """Stream the fields of a shard, yielding one FieldSplitData per field split.

    Builds a StreamRequest from the trainset filters, asks the index node for
    the matching fields, and hydrates each one with its extracted text,
    computed metadata and the resource's basic metadata.

    Raises:
        Exception: if a streamed field id does not have the expected
            "/<type>/<name>[/<split>]" shape.
    """
    request = StreamRequest()
    request.shard_id.id = shard_replica_id

    # Each trainset filter facet maps to a label prefix understood by the index.
    for label in trainset.filter.labels:
        request.filter.labels.append(f"/l/{label}")
    for path in trainset.filter.paths:
        request.filter.labels.append(f"/p/{path}")
    for metadata in trainset.filter.metadata:
        request.filter.labels.append(f"/m/{metadata}")
    for entity in trainset.filter.entities:
        request.filter.labels.append(f"/e/{entity}")
    for field in trainset.filter.fields:
        request.filter.labels.append(f"/f/{field}")
    for status in trainset.filter.status:
        request.filter.labels.append(f"/n/s/{status}")

    async for document_item in node.stream_get_fields(request):
        # document_item.field is "/<field_type>/<field>" or
        # "/<field_type>/<field>/<split>"; splits default to "0".
        field_parts = document_item.field.split("/")
        if len(field_parts) == 3:
            _, field_type, field = field_parts
            split = "0"
        elif len(field_parts) == 4:
            _, field_type, field, split = field_parts
        else:
            raise Exception(f"Invalid field definition {document_item.field}")

        # BUGFIX: the previous code re-derived rid/field_type/field by
        # unpacking f"{uuid}{field}".split("/") into exactly three names,
        # which raised ValueError for split fields (four segments) and
        # clobbered the values parsed above. The resource id is just the uuid.
        rid = document_item.uuid

        tl = FieldSplitData()
        tl.rid = rid
        tl.field = field
        tl.field_type = field_type
        tl.split = split

        extracted = await get_field_text(kbid, rid, field, field_type)
        if extracted is not None:
            tl.text.CopyFrom(extracted)

        metadata_obj = await get_field_metadata(kbid, rid, field, field_type)
        if metadata_obj is not None:
            tl.metadata.CopyFrom(metadata_obj)

        basic = await get_field_basic(kbid, rid, field, field_type)
        if basic is not None:
            tl.basic.CopyFrom(basic)

        tl.labels.extend(document_item.labels)

        yield tl
115
+
116
+
117
async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
    """Fetch a field's extracted text, or None if the resource is missing."""
    resource = await get_resource_from_cache_or_db(kbid, rid)
    if resource is None:
        logger.error(f"{rid} does not exist on DB")
        return None

    pb_field_type = FIELD_TYPE_STR_TO_PB[field_type]
    field_obj = await resource.get_field(field, pb_field_type, load=False)
    return await field_obj.get_extracted_text()
129
+
130
+
131
async def get_field_metadata(
    kbid: str, rid: str, field: str, field_type: str
) -> Optional[FieldComputedMetadata]:
    """Fetch a field's computed metadata, or None if the resource is missing."""
    resource = await get_resource_from_cache_or_db(kbid, rid)
    if resource is None:
        logger.error(f"{rid} does not exist on DB")
        return None

    pb_field_type = FIELD_TYPE_STR_TO_PB[field_type]
    field_obj = await resource.get_field(field, pb_field_type, load=False)
    return await field_obj.get_field_metadata()
145
+
146
+
147
async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Optional[Basic]:
    """Fetch the resource's Basic metadata, or None if the resource is missing.

    NOTE(review): `field` and `field_type` are accepted for signature symmetry
    with the other getters but are not used here — basic metadata is
    per-resource, not per-field.
    """
    resource = await get_resource_from_cache_or_db(kbid, rid)
    if resource is None:
        logger.error(f"{rid} does not exist on DB")
        return None
    return await resource.get_basic()
@@ -21,6 +21,12 @@
21
21
  import json
22
22
  from typing import Any, AsyncGenerator
23
23
 
24
+ from nucliadb.common.cluster.base import AbstractIndexNode
25
+ from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
26
+ from nucliadb.ingest.fields.base import Field
27
+ from nucliadb.ingest.orm.resource import Resource
28
+ from nucliadb.train import logger
29
+ from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
24
30
  from nucliadb_protos.dataset_pb2 import (
25
31
  ImageClassification,
26
32
  ImageClassificationBatch,
@@ -29,12 +35,6 @@ from nucliadb_protos.dataset_pb2 import (
29
35
  from nucliadb_protos.nodereader_pb2 import StreamRequest
30
36
  from nucliadb_protos.resources_pb2 import FieldType, PageStructure, VisualSelection
31
37
 
32
- from nucliadb.common.cluster.base import AbstractIndexNode
33
- from nucliadb.ingest.fields.base import Field
34
- from nucliadb.ingest.orm.resource import KB_REVERSE, Resource
35
- from nucliadb.train import logger
36
- from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
37
-
38
38
  VISUALLY_ANNOTABLE_FIELDS = {FieldType.FILE, FieldType.LINK}
39
39
 
40
40
  # PAWLS JSON format
@@ -47,9 +47,7 @@ def image_classification_batch_generator(
47
47
  node: AbstractIndexNode,
48
48
  shard_replica_id: str,
49
49
  ) -> AsyncGenerator[ImageClassificationBatch, None]:
50
- generator = generate_image_classification_payloads(
51
- kbid, trainset, node, shard_replica_id
52
- )
50
+ generator = generate_image_classification_payloads(kbid, trainset, node, shard_replica_id)
53
51
  batch_generator = batchify(generator, trainset.batch_size, ImageClassificationBatch)
54
52
  return batch_generator
55
53
 
@@ -71,7 +69,7 @@ async def generate_image_classification_payloads(
71
69
  return
72
70
 
73
71
  _, field_type_key, field_key = item.field.split("/")
74
- field_type = KB_REVERSE[field_type_key]
72
+ field_type = FIELD_TYPE_STR_TO_PB[field_type_key]
75
73
 
76
74
  if field_type not in VISUALLY_ANNOTABLE_FIELDS:
77
75
  continue
@@ -131,9 +129,7 @@ async def generate_image_classification_payloads(
131
129
  yield ic
132
130
 
133
131
 
134
- async def get_page_selections(
135
- resource: Resource, field: Field
136
- ) -> dict[int, list[VisualSelection]]:
132
+ async def get_page_selections(resource: Resource, field: Field) -> dict[int, list[VisualSelection]]:
137
133
  page_selections: dict[int, list[VisualSelection]] = {}
138
134
  basic = await resource.get_basic()
139
135
  if basic is None or basic.fieldmetadata is None:
@@ -144,7 +140,7 @@ async def get_page_selections(
144
140
  for fieldmetadata in basic.fieldmetadata:
145
141
  if (
146
142
  fieldmetadata.field.field == field.id
147
- and fieldmetadata.field.field_type == KB_REVERSE[field.type]
143
+ and fieldmetadata.field.field_type == FIELD_TYPE_STR_TO_PB[field.type]
148
144
  ):
149
145
  for selection in fieldmetadata.page_selections:
150
146
  page_selections[selection.page] = selection.visual # type: ignore
@@ -155,7 +151,7 @@ async def get_page_selections(
155
151
 
156
152
  async def get_page_structure(field: Field) -> list[tuple[str, PageStructure]]:
157
153
  page_structures: list[tuple[str, PageStructure]] = []
158
- field_type = KB_REVERSE[field.type]
154
+ field_type = FIELD_TYPE_STR_TO_PB[field.type]
159
155
  if field_type == FieldType.FILE:
160
156
  fed = await field.get_file_extracted_data() # type: ignore
161
157
  if fed is None:
@@ -163,9 +159,7 @@ async def get_page_structure(field: Field) -> list[tuple[str, PageStructure]]:
163
159
 
164
160
  fp = fed.file_pages_previews
165
161
  if len(fp.pages) != len(fp.structures):
166
- field_path = (
167
- f"/kb/{field.kbid}/resource/{field.resource.uuid}/file/{field.id}"
168
- )
162
+ field_path = f"/kb/{field.kbid}/resource/{field.resource.uuid}/file/{field.id}"
169
163
  logger.warning(
170
164
  f"File extracted data has a different number of pages and structures! ({field_path})"
171
165
  )
@@ -21,6 +21,9 @@
21
21
  from typing import AsyncGenerator
22
22
 
23
23
  from fastapi import HTTPException
24
+
25
+ from nucliadb.common.cluster.base import AbstractIndexNode
26
+ from nucliadb.train.generators.utils import batchify, get_paragraph
24
27
  from nucliadb_protos.dataset_pb2 import (
25
28
  Label,
26
29
  ParagraphClassificationBatch,
@@ -29,9 +32,6 @@ from nucliadb_protos.dataset_pb2 import (
29
32
  )
30
33
  from nucliadb_protos.nodereader_pb2 import StreamRequest
31
34
 
32
- from nucliadb.common.cluster.base import AbstractIndexNode
33
- from nucliadb.train.generators.utils import batchify, get_paragraph
34
-
35
35
 
36
36
  def paragraph_classification_batch_generator(
37
37
  kbid: str,
@@ -45,12 +45,8 @@ def paragraph_classification_batch_generator(
45
45
  detail="Paragraph Classification should be of 1 labelset",
46
46
  )
47
47
 
48
- generator = generate_paragraph_classification_payloads(
49
- kbid, trainset, node, shard_replica_id
50
- )
51
- batch_generator = batchify(
52
- generator, trainset.batch_size, ParagraphClassificationBatch
53
- )
48
+ generator = generate_paragraph_classification_payloads(kbid, trainset, node, shard_replica_id)
49
+ batch_generator = batchify(generator, trainset.batch_size, ParagraphClassificationBatch)
54
50
  return batch_generator
55
51
 
56
52
 
@@ -20,6 +20,10 @@
20
20
 
21
21
  from typing import AsyncGenerator
22
22
 
23
+ from nucliadb.common.cluster.base import AbstractIndexNode
24
+ from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
25
+ from nucliadb.train import logger
26
+ from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
23
27
  from nucliadb_protos.dataset_pb2 import (
24
28
  ParagraphStreamingBatch,
25
29
  ParagraphStreamItem,
@@ -27,11 +31,6 @@ from nucliadb_protos.dataset_pb2 import (
27
31
  )
28
32
  from nucliadb_protos.nodereader_pb2 import StreamRequest
29
33
 
30
- from nucliadb.common.cluster.base import AbstractIndexNode
31
- from nucliadb.ingest.orm.resource import KB_REVERSE
32
- from nucliadb.train import logger
33
- from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
34
-
35
34
 
36
35
  def paragraph_streaming_batch_generator(
37
36
  kbid: str,
@@ -39,9 +38,7 @@ def paragraph_streaming_batch_generator(
39
38
  node: AbstractIndexNode,
40
39
  shard_replica_id: str,
41
40
  ) -> AsyncGenerator[ParagraphStreamingBatch, None]:
42
- generator = generate_paragraph_streaming_payloads(
43
- kbid, trainset, node, shard_replica_id
44
- )
41
+ generator = generate_paragraph_streaming_payloads(kbid, trainset, node, shard_replica_id)
45
42
  batch_generator = batchify(generator, trainset.batch_size, ParagraphStreamingBatch)
46
43
  return batch_generator
47
44
 
@@ -68,7 +65,7 @@ async def generate_paragraph_streaming_payloads(
68
65
  logger.error(f"{rid} does not exist on DB")
69
66
  continue
70
67
 
71
- field_type_int = KB_REVERSE[field_type]
68
+ field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
72
69
  field_obj = await orm_resource.get_field(field, field_type_int, load=False)
73
70
 
74
71
  extracted_text = await field_obj.get_extracted_text()
@@ -20,6 +20,14 @@
20
20
 
21
21
  from typing import AsyncGenerator
22
22
 
23
+ from nucliadb.common.cluster.base import AbstractIndexNode
24
+ from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR, FIELD_TYPE_STR_TO_PB
25
+ from nucliadb.train import logger
26
+ from nucliadb.train.generators.utils import (
27
+ batchify,
28
+ get_paragraph,
29
+ get_resource_from_cache_or_db,
30
+ )
23
31
  from nucliadb_protos.dataset_pb2 import (
24
32
  QuestionAnswerStreamingBatch,
25
33
  QuestionAnswerStreamItem,
@@ -32,15 +40,6 @@ from nucliadb_protos.resources_pb2 import (
32
40
  QuestionAnswerAnnotation,
33
41
  )
34
42
 
35
- from nucliadb.common.cluster.base import AbstractIndexNode
36
- from nucliadb.ingest.orm.resource import FIELD_TYPE_TO_ID, KB_REVERSE
37
- from nucliadb.train import logger
38
- from nucliadb.train.generators.utils import (
39
- batchify,
40
- get_paragraph,
41
- get_resource_from_cache_or_db,
42
- )
43
-
44
43
 
45
44
  def question_answer_batch_generator(
46
45
  kbid: str,
@@ -48,12 +47,8 @@ def question_answer_batch_generator(
48
47
  node: AbstractIndexNode,
49
48
  shard_replica_id: str,
50
49
  ) -> AsyncGenerator[QuestionAnswerStreamingBatch, None]:
51
- generator = generate_question_answer_streaming_payloads(
52
- kbid, trainset, node, shard_replica_id
53
- )
54
- batch_generator = batchify(
55
- generator, trainset.batch_size, QuestionAnswerStreamingBatch
56
- )
50
+ generator = generate_question_answer_streaming_payloads(kbid, trainset, node, shard_replica_id)
51
+ batch_generator = batchify(generator, trainset.batch_size, QuestionAnswerStreamingBatch)
57
52
  return batch_generator
58
53
 
59
54
 
@@ -90,14 +85,18 @@ async def generate_question_answer_streaming_payloads(
90
85
  item.cancelled_by_user = qa_annotation_pb.cancelled_by_user
91
86
  yield item
92
87
 
93
- field_type_int = KB_REVERSE[field_type]
88
+ field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
94
89
  field_obj = await orm_resource.get_field(field, field_type_int, load=False)
95
90
 
96
91
  question_answers_pb = await field_obj.get_question_answers()
97
92
  if question_answers_pb is not None:
98
- for question_answer_pb in question_answers_pb.question_answer:
93
+ for question_answer_pb in question_answers_pb.question_answers.question_answer:
99
94
  async for item in iter_stream_items(kbid, question_answer_pb):
100
95
  yield item
96
+ for question_answer_pb in question_answers_pb.split_question_answers.values():
97
+ for split_question_answer_pb in question_answer_pb.question_answer:
98
+ async for item in iter_stream_items(kbid, split_question_answer_pb):
99
+ yield item
101
100
 
102
101
 
103
102
  async def iter_stream_items(
@@ -109,7 +108,7 @@ async def iter_stream_items(
109
108
  for paragraph_id in question_pb.ids_paragraphs:
110
109
  try:
111
110
  text = await get_paragraph(kbid, paragraph_id)
112
- except Exception as exc: # pragma: nocover
111
+ except Exception as exc: # pragma: no cover
113
112
  logger.warning(
114
113
  "Question paragraph couldn't be fetched while streaming Q&A",
115
114
  extra={"kbid": kbid, "paragraph_id": paragraph_id},
@@ -128,7 +127,7 @@ async def iter_stream_items(
128
127
  for paragraph_id in answer_pb.ids_paragraphs:
129
128
  try:
130
129
  text = await get_paragraph(kbid, paragraph_id)
131
- except Exception as exc: # pragma: nocover
130
+ except Exception as exc: # pragma: no cover
132
131
  logger.warning(
133
132
  "Answer paragraph couldn't be fetched while streaming Q&A",
134
133
  extra={"kbid": kbid, "paragraph_id": paragraph_id},
@@ -141,4 +140,4 @@ async def iter_stream_items(
141
140
 
142
141
 
143
142
  def is_same_field(field: FieldID, field_id: str, field_type: str) -> bool:
144
- return field.field == field_id and FIELD_TYPE_TO_ID[field.field_type] == field_type
143
+ return field.field == field_id and FIELD_TYPE_PB_TO_STR[field.field_type] == field_type
@@ -21,6 +21,11 @@
21
21
  from typing import AsyncGenerator
22
22
 
23
23
  from fastapi import HTTPException
24
+
25
+ from nucliadb.common.cluster.base import AbstractIndexNode
26
+ from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
27
+ from nucliadb.train import logger
28
+ from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
24
29
  from nucliadb_protos.dataset_pb2 import (
25
30
  Label,
26
31
  MultipleTextSameLabels,
@@ -29,11 +34,6 @@ from nucliadb_protos.dataset_pb2 import (
29
34
  )
30
35
  from nucliadb_protos.nodereader_pb2 import StreamRequest
31
36
 
32
- from nucliadb.common.cluster.base import AbstractIndexNode
33
- from nucliadb.ingest.orm.resource import KB_REVERSE
34
- from nucliadb.train import logger
35
- from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
36
-
37
37
 
38
38
  def sentence_classification_batch_generator(
39
39
  kbid: str,
@@ -47,12 +47,8 @@ def sentence_classification_batch_generator(
47
47
  detail="Sentence Classification should be at least of 1 labelset",
48
48
  )
49
49
 
50
- generator = generate_sentence_classification_payloads(
51
- kbid, trainset, node, shard_replica_id
52
- )
53
- batch_generator = batchify(
54
- generator, trainset.batch_size, SentenceClassificationBatch
55
- )
50
+ generator = generate_sentence_classification_payloads(kbid, trainset, node, shard_replica_id)
51
+ batch_generator = batchify(generator, trainset.batch_size, SentenceClassificationBatch)
56
52
  return batch_generator
57
53
 
58
54
 
@@ -107,14 +103,12 @@ async def get_sentences(kbid: str, result: str) -> list[str]:
107
103
  logger.error(f"{rid} does not exist on DB")
108
104
  return []
109
105
 
110
- field_type_int = KB_REVERSE[field_type]
106
+ field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
111
107
  field_obj = await orm_resource.get_field(field, field_type_int, load=False)
112
108
  extracted_text = await field_obj.get_extracted_text()
113
109
  field_metadata = await field_obj.get_field_metadata()
114
110
  if extracted_text is None:
115
- logger.warning(
116
- f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
117
- )
111
+ logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
118
112
  return []
119
113
 
120
114
  splitted_texts = []