nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (418)
  1. migrations/0003_allfields_key.py +1 -35
  2. migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
  3. migrations/0010_fix_corrupt_indexes.py +10 -10
  4. migrations/0011_materialize_labelset_ids.py +1 -16
  5. migrations/0012_rollover_shards.py +5 -10
  6. migrations/0014_rollover_shards.py +4 -5
  7. migrations/0015_targeted_rollover.py +5 -10
  8. migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
  9. migrations/0017_multiple_writable_shards.py +2 -4
  10. migrations/0018_purge_orphan_kbslugs.py +5 -7
  11. migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
  12. migrations/0020_drain_nodes_from_cluster.py +3 -3
  13. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
  14. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  15. migrations/0023_backfill_pg_catalog.py +80 -0
  16. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  17. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  18. migrations/0027_rollover_texts3.py +73 -0
  19. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  20. migrations/pg/0002_catalog.py +42 -0
  21. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  22. nucliadb/common/cluster/base.py +30 -16
  23. nucliadb/common/cluster/discovery/base.py +6 -14
  24. nucliadb/common/cluster/discovery/k8s.py +9 -19
  25. nucliadb/common/cluster/discovery/manual.py +1 -3
  26. nucliadb/common/cluster/discovery/utils.py +1 -3
  27. nucliadb/common/cluster/grpc_node_dummy.py +3 -11
  28. nucliadb/common/cluster/index_node.py +10 -19
  29. nucliadb/common/cluster/manager.py +174 -59
  30. nucliadb/common/cluster/rebalance.py +27 -29
  31. nucliadb/common/cluster/rollover.py +353 -194
  32. nucliadb/common/cluster/settings.py +6 -0
  33. nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
  34. nucliadb/common/cluster/standalone/index_node.py +4 -11
  35. nucliadb/common/cluster/standalone/service.py +2 -6
  36. nucliadb/common/cluster/standalone/utils.py +2 -6
  37. nucliadb/common/cluster/utils.py +29 -22
  38. nucliadb/common/constants.py +20 -0
  39. nucliadb/common/context/__init__.py +3 -0
  40. nucliadb/common/context/fastapi.py +8 -5
  41. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  42. nucliadb/common/datamanagers/__init__.py +7 -1
  43. nucliadb/common/datamanagers/atomic.py +22 -4
  44. nucliadb/common/datamanagers/cluster.py +5 -5
  45. nucliadb/common/datamanagers/entities.py +6 -16
  46. nucliadb/common/datamanagers/fields.py +84 -0
  47. nucliadb/common/datamanagers/kb.py +83 -37
  48. nucliadb/common/datamanagers/labels.py +26 -56
  49. nucliadb/common/datamanagers/processing.py +2 -6
  50. nucliadb/common/datamanagers/resources.py +41 -103
  51. nucliadb/common/datamanagers/rollover.py +76 -15
  52. nucliadb/common/datamanagers/synonyms.py +1 -1
  53. nucliadb/common/datamanagers/utils.py +15 -6
  54. nucliadb/common/datamanagers/vectorsets.py +110 -0
  55. nucliadb/common/external_index_providers/base.py +257 -0
  56. nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
  57. nucliadb/common/external_index_providers/manager.py +101 -0
  58. nucliadb/common/external_index_providers/pinecone.py +933 -0
  59. nucliadb/common/external_index_providers/settings.py +52 -0
  60. nucliadb/common/http_clients/auth.py +3 -6
  61. nucliadb/common/http_clients/processing.py +6 -11
  62. nucliadb/common/http_clients/utils.py +1 -3
  63. nucliadb/common/ids.py +240 -0
  64. nucliadb/common/locking.py +29 -7
  65. nucliadb/common/maindb/driver.py +11 -35
  66. nucliadb/common/maindb/exceptions.py +3 -0
  67. nucliadb/common/maindb/local.py +22 -9
  68. nucliadb/common/maindb/pg.py +206 -111
  69. nucliadb/common/maindb/utils.py +11 -42
  70. nucliadb/common/models_utils/from_proto.py +479 -0
  71. nucliadb/common/models_utils/to_proto.py +60 -0
  72. nucliadb/common/nidx.py +260 -0
  73. nucliadb/export_import/datamanager.py +25 -19
  74. nucliadb/export_import/exporter.py +5 -11
  75. nucliadb/export_import/importer.py +5 -7
  76. nucliadb/export_import/models.py +3 -3
  77. nucliadb/export_import/tasks.py +4 -4
  78. nucliadb/export_import/utils.py +25 -37
  79. nucliadb/health.py +1 -3
  80. nucliadb/ingest/app.py +15 -11
  81. nucliadb/ingest/consumer/auditing.py +21 -19
  82. nucliadb/ingest/consumer/consumer.py +82 -47
  83. nucliadb/ingest/consumer/materializer.py +5 -12
  84. nucliadb/ingest/consumer/pull.py +12 -27
  85. nucliadb/ingest/consumer/service.py +19 -17
  86. nucliadb/ingest/consumer/shard_creator.py +2 -4
  87. nucliadb/ingest/consumer/utils.py +1 -3
  88. nucliadb/ingest/fields/base.py +137 -105
  89. nucliadb/ingest/fields/conversation.py +18 -5
  90. nucliadb/ingest/fields/exceptions.py +1 -4
  91. nucliadb/ingest/fields/file.py +7 -16
  92. nucliadb/ingest/fields/link.py +5 -10
  93. nucliadb/ingest/fields/text.py +9 -4
  94. nucliadb/ingest/orm/brain.py +200 -213
  95. nucliadb/ingest/orm/broker_message.py +181 -0
  96. nucliadb/ingest/orm/entities.py +36 -51
  97. nucliadb/ingest/orm/exceptions.py +12 -0
  98. nucliadb/ingest/orm/knowledgebox.py +322 -197
  99. nucliadb/ingest/orm/processor/__init__.py +2 -700
  100. nucliadb/ingest/orm/processor/auditing.py +4 -23
  101. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  102. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  103. nucliadb/ingest/orm/processor/processor.py +752 -0
  104. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  105. nucliadb/ingest/orm/resource.py +249 -402
  106. nucliadb/ingest/orm/utils.py +4 -4
  107. nucliadb/ingest/partitions.py +3 -9
  108. nucliadb/ingest/processing.py +64 -73
  109. nucliadb/ingest/py.typed +0 -0
  110. nucliadb/ingest/serialize.py +37 -167
  111. nucliadb/ingest/service/__init__.py +1 -3
  112. nucliadb/ingest/service/writer.py +185 -412
  113. nucliadb/ingest/settings.py +10 -20
  114. nucliadb/ingest/utils.py +3 -6
  115. nucliadb/learning_proxy.py +242 -55
  116. nucliadb/metrics_exporter.py +30 -19
  117. nucliadb/middleware/__init__.py +1 -3
  118. nucliadb/migrator/command.py +1 -3
  119. nucliadb/migrator/datamanager.py +13 -13
  120. nucliadb/migrator/migrator.py +47 -30
  121. nucliadb/migrator/utils.py +18 -10
  122. nucliadb/purge/__init__.py +139 -33
  123. nucliadb/purge/orphan_shards.py +7 -13
  124. nucliadb/reader/__init__.py +1 -3
  125. nucliadb/reader/api/models.py +1 -12
  126. nucliadb/reader/api/v1/__init__.py +0 -1
  127. nucliadb/reader/api/v1/download.py +21 -88
  128. nucliadb/reader/api/v1/export_import.py +1 -1
  129. nucliadb/reader/api/v1/knowledgebox.py +10 -10
  130. nucliadb/reader/api/v1/learning_config.py +2 -6
  131. nucliadb/reader/api/v1/resource.py +62 -88
  132. nucliadb/reader/api/v1/services.py +64 -83
  133. nucliadb/reader/app.py +12 -29
  134. nucliadb/reader/lifecycle.py +18 -4
  135. nucliadb/reader/py.typed +0 -0
  136. nucliadb/reader/reader/notifications.py +10 -28
  137. nucliadb/search/__init__.py +1 -3
  138. nucliadb/search/api/v1/__init__.py +1 -2
  139. nucliadb/search/api/v1/ask.py +17 -10
  140. nucliadb/search/api/v1/catalog.py +184 -0
  141. nucliadb/search/api/v1/feedback.py +16 -24
  142. nucliadb/search/api/v1/find.py +36 -36
  143. nucliadb/search/api/v1/knowledgebox.py +89 -60
  144. nucliadb/search/api/v1/resource/ask.py +2 -8
  145. nucliadb/search/api/v1/resource/search.py +49 -70
  146. nucliadb/search/api/v1/search.py +44 -210
  147. nucliadb/search/api/v1/suggest.py +39 -54
  148. nucliadb/search/app.py +12 -32
  149. nucliadb/search/lifecycle.py +10 -3
  150. nucliadb/search/predict.py +136 -187
  151. nucliadb/search/py.typed +0 -0
  152. nucliadb/search/requesters/utils.py +25 -58
  153. nucliadb/search/search/cache.py +149 -20
  154. nucliadb/search/search/chat/ask.py +571 -123
  155. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
  156. nucliadb/search/search/chat/images.py +41 -17
  157. nucliadb/search/search/chat/prompt.py +817 -266
  158. nucliadb/search/search/chat/query.py +213 -309
  159. nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
  160. nucliadb/search/search/fetch.py +43 -36
  161. nucliadb/search/search/filters.py +9 -15
  162. nucliadb/search/search/find.py +214 -53
  163. nucliadb/search/search/find_merge.py +408 -391
  164. nucliadb/search/search/hydrator.py +191 -0
  165. nucliadb/search/search/merge.py +187 -223
  166. nucliadb/search/search/metrics.py +73 -2
  167. nucliadb/search/search/paragraphs.py +64 -106
  168. nucliadb/search/search/pgcatalog.py +233 -0
  169. nucliadb/search/search/predict_proxy.py +1 -1
  170. nucliadb/search/search/query.py +305 -150
  171. nucliadb/search/search/query_parser/exceptions.py +22 -0
  172. nucliadb/search/search/query_parser/models.py +101 -0
  173. nucliadb/search/search/query_parser/parser.py +183 -0
  174. nucliadb/search/search/rank_fusion.py +204 -0
  175. nucliadb/search/search/rerankers.py +270 -0
  176. nucliadb/search/search/shards.py +3 -32
  177. nucliadb/search/search/summarize.py +7 -18
  178. nucliadb/search/search/utils.py +27 -4
  179. nucliadb/search/settings.py +15 -1
  180. nucliadb/standalone/api_router.py +4 -10
  181. nucliadb/standalone/app.py +8 -14
  182. nucliadb/standalone/auth.py +7 -21
  183. nucliadb/standalone/config.py +7 -10
  184. nucliadb/standalone/lifecycle.py +26 -25
  185. nucliadb/standalone/migrations.py +1 -3
  186. nucliadb/standalone/purge.py +1 -1
  187. nucliadb/standalone/py.typed +0 -0
  188. nucliadb/standalone/run.py +3 -6
  189. nucliadb/standalone/settings.py +9 -16
  190. nucliadb/standalone/versions.py +15 -5
  191. nucliadb/tasks/consumer.py +8 -12
  192. nucliadb/tasks/producer.py +7 -6
  193. nucliadb/tests/config.py +53 -0
  194. nucliadb/train/__init__.py +1 -3
  195. nucliadb/train/api/utils.py +1 -2
  196. nucliadb/train/api/v1/shards.py +1 -1
  197. nucliadb/train/api/v1/trainset.py +2 -4
  198. nucliadb/train/app.py +10 -31
  199. nucliadb/train/generator.py +10 -19
  200. nucliadb/train/generators/field_classifier.py +7 -19
  201. nucliadb/train/generators/field_streaming.py +156 -0
  202. nucliadb/train/generators/image_classifier.py +12 -18
  203. nucliadb/train/generators/paragraph_classifier.py +5 -9
  204. nucliadb/train/generators/paragraph_streaming.py +6 -9
  205. nucliadb/train/generators/question_answer_streaming.py +19 -20
  206. nucliadb/train/generators/sentence_classifier.py +9 -15
  207. nucliadb/train/generators/token_classifier.py +48 -39
  208. nucliadb/train/generators/utils.py +14 -18
  209. nucliadb/train/lifecycle.py +7 -3
  210. nucliadb/train/nodes.py +23 -32
  211. nucliadb/train/py.typed +0 -0
  212. nucliadb/train/servicer.py +13 -21
  213. nucliadb/train/settings.py +2 -6
  214. nucliadb/train/types.py +13 -10
  215. nucliadb/train/upload.py +3 -6
  216. nucliadb/train/uploader.py +19 -23
  217. nucliadb/train/utils.py +1 -1
  218. nucliadb/writer/__init__.py +1 -3
  219. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  220. nucliadb/writer/api/v1/export_import.py +67 -14
  221. nucliadb/writer/api/v1/field.py +16 -269
  222. nucliadb/writer/api/v1/knowledgebox.py +218 -68
  223. nucliadb/writer/api/v1/resource.py +68 -88
  224. nucliadb/writer/api/v1/services.py +51 -70
  225. nucliadb/writer/api/v1/slug.py +61 -0
  226. nucliadb/writer/api/v1/transaction.py +67 -0
  227. nucliadb/writer/api/v1/upload.py +114 -113
  228. nucliadb/writer/app.py +6 -43
  229. nucliadb/writer/back_pressure.py +16 -38
  230. nucliadb/writer/exceptions.py +0 -4
  231. nucliadb/writer/lifecycle.py +21 -15
  232. nucliadb/writer/py.typed +0 -0
  233. nucliadb/writer/resource/audit.py +2 -1
  234. nucliadb/writer/resource/basic.py +48 -46
  235. nucliadb/writer/resource/field.py +25 -127
  236. nucliadb/writer/resource/origin.py +1 -2
  237. nucliadb/writer/settings.py +6 -2
  238. nucliadb/writer/tus/__init__.py +17 -15
  239. nucliadb/writer/tus/azure.py +111 -0
  240. nucliadb/writer/tus/dm.py +17 -5
  241. nucliadb/writer/tus/exceptions.py +1 -3
  242. nucliadb/writer/tus/gcs.py +49 -84
  243. nucliadb/writer/tus/local.py +21 -37
  244. nucliadb/writer/tus/s3.py +28 -68
  245. nucliadb/writer/tus/storage.py +5 -56
  246. nucliadb/writer/vectorsets.py +125 -0
  247. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  248. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  249. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  250. nucliadb/common/maindb/redis.py +0 -194
  251. nucliadb/common/maindb/tikv.py +0 -433
  252. nucliadb/ingest/fields/layout.py +0 -58
  253. nucliadb/ingest/tests/conftest.py +0 -30
  254. nucliadb/ingest/tests/fixtures.py +0 -764
  255. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  256. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
  257. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
  258. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  259. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  260. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  261. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
  262. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  263. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  264. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  265. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
  266. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  267. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  268. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
  269. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  270. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  271. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  272. nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
  273. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  274. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
  275. nucliadb/ingest/tests/unit/test_cache.py +0 -31
  276. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  277. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  278. nucliadb/middleware/transaction.py +0 -117
  279. nucliadb/reader/api/v1/learning_collector.py +0 -63
  280. nucliadb/reader/tests/__init__.py +0 -19
  281. nucliadb/reader/tests/conftest.py +0 -31
  282. nucliadb/reader/tests/fixtures.py +0 -136
  283. nucliadb/reader/tests/test_list_resources.py +0 -75
  284. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  285. nucliadb/reader/tests/test_reader_resource.py +0 -353
  286. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  287. nucliadb/search/api/v1/chat.py +0 -263
  288. nucliadb/search/api/v1/resource/chat.py +0 -174
  289. nucliadb/search/tests/__init__.py +0 -19
  290. nucliadb/search/tests/conftest.py +0 -33
  291. nucliadb/search/tests/fixtures.py +0 -199
  292. nucliadb/search/tests/node.py +0 -466
  293. nucliadb/search/tests/unit/__init__.py +0 -18
  294. nucliadb/search/tests/unit/api/__init__.py +0 -19
  295. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  296. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  297. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
  298. nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
  299. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  300. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  301. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
  302. nucliadb/search/tests/unit/search/__init__.py +0 -18
  303. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  304. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
  305. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  306. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  307. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  308. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
  309. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  310. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  311. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  312. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  313. nucliadb/search/tests/unit/search/test_query.py +0 -153
  314. nucliadb/search/tests/unit/test_app.py +0 -79
  315. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  316. nucliadb/search/tests/unit/test_merge.py +0 -34
  317. nucliadb/search/tests/unit/test_predict.py +0 -525
  318. nucliadb/standalone/tests/__init__.py +0 -19
  319. nucliadb/standalone/tests/conftest.py +0 -33
  320. nucliadb/standalone/tests/fixtures.py +0 -38
  321. nucliadb/standalone/tests/unit/__init__.py +0 -18
  322. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  323. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  324. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  325. nucliadb/standalone/tests/unit/test_migrations.py +0 -63
  326. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  327. nucliadb/tests/benchmarks/__init__.py +0 -19
  328. nucliadb/tests/benchmarks/test_search.py +0 -99
  329. nucliadb/tests/conftest.py +0 -32
  330. nucliadb/tests/fixtures.py +0 -735
  331. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
  332. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
  333. nucliadb/tests/migrations/test_migration_0017.py +0 -76
  334. nucliadb/tests/migrations/test_migration_0018.py +0 -95
  335. nucliadb/tests/tikv.py +0 -240
  336. nucliadb/tests/unit/__init__.py +0 -19
  337. nucliadb/tests/unit/common/__init__.py +0 -19
  338. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  339. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  340. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
  341. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  342. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
  343. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
  344. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
  345. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
  346. nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
  347. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
  348. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  349. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  350. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  351. nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
  352. nucliadb/tests/unit/common/test_context.py +0 -36
  353. nucliadb/tests/unit/export_import/__init__.py +0 -19
  354. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  355. nucliadb/tests/unit/export_import/test_utils.py +0 -301
  356. nucliadb/tests/unit/migrator/__init__.py +0 -19
  357. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  358. nucliadb/tests/unit/tasks/__init__.py +0 -19
  359. nucliadb/tests/unit/tasks/conftest.py +0 -42
  360. nucliadb/tests/unit/tasks/test_consumer.py +0 -92
  361. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  362. nucliadb/tests/unit/tasks/test_tasks.py +0 -58
  363. nucliadb/tests/unit/test_field_ids.py +0 -49
  364. nucliadb/tests/unit/test_health.py +0 -86
  365. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  366. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  367. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  368. nucliadb/tests/unit/test_purge.py +0 -136
  369. nucliadb/tests/utils/__init__.py +0 -74
  370. nucliadb/tests/utils/aiohttp_session.py +0 -44
  371. nucliadb/tests/utils/broker_messages/__init__.py +0 -171
  372. nucliadb/tests/utils/broker_messages/fields.py +0 -197
  373. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  374. nucliadb/tests/utils/entities.py +0 -78
  375. nucliadb/train/api/v1/check.py +0 -60
  376. nucliadb/train/tests/__init__.py +0 -19
  377. nucliadb/train/tests/conftest.py +0 -29
  378. nucliadb/train/tests/fixtures.py +0 -342
  379. nucliadb/train/tests/test_field_classification.py +0 -122
  380. nucliadb/train/tests/test_get_entities.py +0 -80
  381. nucliadb/train/tests/test_get_info.py +0 -51
  382. nucliadb/train/tests/test_get_ontology.py +0 -34
  383. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  384. nucliadb/train/tests/test_image_classification.py +0 -221
  385. nucliadb/train/tests/test_list_fields.py +0 -39
  386. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  387. nucliadb/train/tests/test_list_resources.py +0 -39
  388. nucliadb/train/tests/test_list_sentences.py +0 -71
  389. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  390. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  391. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  392. nucliadb/train/tests/test_sentence_classification.py +0 -143
  393. nucliadb/train/tests/test_token_classification.py +0 -136
  394. nucliadb/train/tests/utils.py +0 -101
  395. nucliadb/writer/layouts/__init__.py +0 -51
  396. nucliadb/writer/layouts/v1.py +0 -59
  397. nucliadb/writer/tests/__init__.py +0 -19
  398. nucliadb/writer/tests/conftest.py +0 -31
  399. nucliadb/writer/tests/fixtures.py +0 -191
  400. nucliadb/writer/tests/test_fields.py +0 -475
  401. nucliadb/writer/tests/test_files.py +0 -740
  402. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  403. nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
  404. nucliadb/writer/tests/test_resources.py +0 -476
  405. nucliadb/writer/tests/test_service.py +0 -137
  406. nucliadb/writer/tests/test_tus.py +0 -203
  407. nucliadb/writer/tests/utils.py +0 -35
  408. nucliadb/writer/tus/pg.py +0 -125
  409. nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
  410. nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
  411. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  412. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  413. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  414. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  415. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  416. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  417. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  418. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -21,6 +21,10 @@
21
21
  from collections import OrderedDict
22
22
  from typing import AsyncGenerator, cast
23
23
 
24
+ from nucliadb.common.cluster.base import AbstractIndexNode
25
+ from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
26
+ from nucliadb.train import logger
27
+ from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
24
28
  from nucliadb_protos.dataset_pb2 import (
25
29
  TokenClassificationBatch,
26
30
  TokensClassification,
@@ -28,11 +32,6 @@ from nucliadb_protos.dataset_pb2 import (
28
32
  )
29
33
  from nucliadb_protos.nodereader_pb2 import StreamFilter, StreamRequest
30
34
 
31
- from nucliadb.common.cluster.base import AbstractIndexNode
32
- from nucliadb.ingest.orm.resource import KB_REVERSE
33
- from nucliadb.train import logger
34
- from nucliadb.train.generators.utils import batchify, get_resource_from_cache_or_db
35
-
36
35
  NERS_DICT = dict[str, dict[str, list[tuple[int, int]]]]
37
36
  POSITION_DICT = OrderedDict[tuple[int, int], tuple[str, str]]
38
37
  MAIN = "__main__"
@@ -44,9 +43,7 @@ def token_classification_batch_generator(
44
43
  node: AbstractIndexNode,
45
44
  shard_replica_id: str,
46
45
  ) -> AsyncGenerator[TokenClassificationBatch, None]:
47
- generator = generate_token_classification_payloads(
48
- kbid, trainset, node, shard_replica_id
49
- )
46
+ generator = generate_token_classification_payloads(kbid, trainset, node, shard_replica_id)
50
47
  batch_generator = batchify(generator, trainset.batch_size, TokenClassificationBatch)
51
48
  return batch_generator
52
49
 
@@ -97,21 +94,19 @@ async def get_field_text(
97
94
  logger.error(f"{rid} does not exist on DB")
98
95
  return {}, {}, {}
99
96
 
100
- field_type_int = KB_REVERSE[field_type]
97
+ field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
101
98
  field_obj = await orm_resource.get_field(field, field_type_int, load=False)
102
99
  extracted_text = await field_obj.get_extracted_text()
103
100
  if extracted_text is None:
104
- logger.warning(
105
- f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
106
- )
101
+ logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
107
102
  return {}, {}, {}
108
103
 
109
104
  split_text: dict[str, str] = extracted_text.split_text
110
105
  split_text[MAIN] = extracted_text.text
111
106
 
112
- split_ners: dict[str, NERS_DICT] = (
113
- {}
114
- ) # Dict of entity group , with entity and list of positions in field
107
+ split_ners: dict[
108
+ str, NERS_DICT
109
+ ] = {} # Dict of entity group , with entity and list of positions in field
115
110
  split_ners[MAIN] = {}
116
111
 
117
112
  basic_data = await orm_resource.get_basic()
@@ -138,16 +133,42 @@ async def get_field_text(
138
133
  split = MAIN
139
134
  else:
140
135
  split = token.split
141
- split_ners[split].setdefault(token.klass, {}).setdefault(
142
- token.token, []
143
- )
144
- split_ners[split][token.klass][token.token].append(
145
- (token.start, token.end)
146
- )
136
+ split_ners[split].setdefault(token.klass, {}).setdefault(token.token, [])
137
+ split_ners[split][token.klass][token.token].append((token.start, token.end))
147
138
 
148
139
  field_metadata = await field_obj.get_field_metadata()
149
140
  # Check computed definition of entities
150
141
  if field_metadata is not None:
142
+ # Data Augmentation + Processor entities
143
+ for data_augmentation_task_id, entities in field_metadata.metadata.entities.items():
144
+ for entity in entities.entities:
145
+ entity_text = entity.text
146
+ entity_label = entity.label
147
+ entity_positions = entity.positions
148
+ if entity_label in valid_entity_groups:
149
+ split_ners[MAIN].setdefault(entity_label, {}).setdefault(entity_text, [])
150
+ for position in entity_positions:
151
+ split_ners[MAIN][entity_label][entity_text].append(
152
+ (position.start, position.end)
153
+ )
154
+
155
+ for split, split_metadata in field_metadata.split_metadata.items():
156
+ for data_augmentation_task_id, entities in split_metadata.entities.items():
157
+ for entity in entities.entities:
158
+ entity_text = entity.text
159
+ entity_label = entity.label
160
+ entity_positions = entity.positions
161
+ if entity_label in valid_entity_groups:
162
+ split_ners.setdefault(split, {}).setdefault(entity_label, {}).setdefault(
163
+ entity_text, []
164
+ )
165
+ for position in entity_positions:
166
+ split_ners[split][entity_label][entity_text].append(
167
+ (position.start, position.end)
168
+ )
169
+
170
+ # Legacy processor entities
171
+ # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
151
172
  for entity_key, positions in field_metadata.metadata.positions.items():
152
173
  entities = entity_key.split("/")
153
174
  entity_group = entities[0]
@@ -156,9 +177,7 @@ async def get_field_text(
156
177
  if entity_group in valid_entity_groups:
157
178
  split_ners[MAIN].setdefault(entity_group, {}).setdefault(entity, [])
158
179
  for position in positions.position:
159
- split_ners[MAIN][entity_group][entity].append(
160
- (position.start, position.end)
161
- )
180
+ split_ners[MAIN][entity_group][entity].append((position.start, position.end))
162
181
 
163
182
  for split, split_metadata in field_metadata.split_metadata.items():
164
183
  for entity_key, positions in split_metadata.positions.items():
@@ -166,24 +185,16 @@ async def get_field_text(
166
185
  entity_group = entities[0]
167
186
  entity = "/".join(entities[1:])
168
187
  if entity_group in valid_entity_groups:
169
- split_ners.setdefault(split, {}).setdefault(
170
- entity_group, {}
171
- ).setdefault(entity, [])
188
+ split_ners.setdefault(split, {}).setdefault(entity_group, {}).setdefault(entity, [])
172
189
  for position in positions.position:
173
- split_ners[split][entity_group][entity].append(
174
- (position.start, position.end)
175
- )
190
+ split_ners[split][entity_group][entity].append((position.start, position.end))
176
191
 
177
192
  for split, invalid_tokens in invalid_tokens_split.items():
178
193
  for token.klass, token.token, token.start, token.end in invalid_tokens:
179
194
  if token.klass in split_ners.get(split, {}):
180
195
  if token.token in split_ners.get(split, {}).get(token.klass, {}):
181
- if (token.start, token.end) in split_ners[split][token.klass][
182
- token.token
183
- ]:
184
- split_ners[split][token.klass][token.token].remove(
185
- (token.start, token.end)
186
- )
196
+ if (token.start, token.end) in split_ners[split][token.klass][token.token]:
197
+ split_ners[split][token.klass][token.token].remove((token.start, token.end))
187
198
  if len(split_ners[split][token.klass][token.token]) == 0:
188
199
  del split_ners[split][token.klass][token.token]
189
200
  if len(split_ners[split][token.klass]) == 0:
@@ -197,9 +208,7 @@ async def get_field_text(
197
208
  for position in positions:
198
209
  split_positions[position] = (entity_group, entity)
199
210
 
200
- ordered_positions[split] = OrderedDict(
201
- sorted(split_positions.items(), key=lambda x: x[0])
202
- )
211
+ ordered_positions[split] = OrderedDict(sorted(split_positions.items(), key=lambda x: x[0]))
203
212
 
204
213
  split_paragraphs: dict[str, list[tuple[int, int]]] = {}
205
214
  if field_metadata is not None:
@@ -19,19 +19,17 @@
19
19
  #
20
20
 
21
21
  from contextvars import ContextVar
22
- from typing import Any, AsyncIterator, Optional
22
+ from typing import Any, AsyncGenerator, AsyncIterator, Optional, Type
23
23
 
24
+ from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
25
+ from nucliadb.common.maindb.utils import get_driver
24
26
  from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
25
- from nucliadb.ingest.orm.resource import KB_REVERSE
26
27
  from nucliadb.ingest.orm.resource import Resource as ResourceORM
27
- from nucliadb.middleware.transaction import get_read_only_transaction
28
28
  from nucliadb.train import SERVICE_NAME, logger
29
- from nucliadb.train.types import TrainBatchType
29
+ from nucliadb.train.types import T
30
30
  from nucliadb_utils.utilities import get_storage
31
31
 
32
- rcache: ContextVar[Optional[dict[str, ResourceORM]]] = ContextVar(
33
- "rcache", default=None
34
- )
32
+ rcache: ContextVar[Optional[dict[str, ResourceORM]]] = ContextVar("rcache", default=None)
35
33
 
36
34
 
37
35
  def get_resource_cache(clear: bool = False) -> dict[str, ResourceORM]:
@@ -46,12 +44,12 @@ async def get_resource_from_cache_or_db(kbid: str, uuid: str) -> Optional[Resour
46
44
  resouce_cache = get_resource_cache()
47
45
  orm_resource: Optional[ResourceORM] = None
48
46
  if uuid not in resouce_cache:
49
- transaction = await get_read_only_transaction()
50
47
  storage = await get_storage(service_name=SERVICE_NAME)
51
- kb = KnowledgeBoxORM(transaction, storage, kbid)
52
- orm_resource = await kb.get(uuid)
53
- if orm_resource is not None:
54
- resouce_cache[uuid] = orm_resource
48
+ async with get_driver().transaction(read_only=True) as transaction:
49
+ kb = KnowledgeBoxORM(transaction, storage, kbid)
50
+ orm_resource = await kb.get(uuid)
51
+ if orm_resource is not None:
52
+ resouce_cache[uuid] = orm_resource
55
53
  else:
56
54
  orm_resource = resouce_cache.get(uuid)
57
55
  return orm_resource
@@ -75,13 +73,11 @@ async def get_paragraph(kbid: str, paragraph_id: str) -> str:
75
73
  logger.error(f"{rid} does not exist on DB")
76
74
  return ""
77
75
 
78
- field_type_int = KB_REVERSE[field_type]
76
+ field_type_int = FIELD_TYPE_STR_TO_PB[field_type]
79
77
  field_obj = await orm_resource.get_field(field, field_type_int, load=False)
80
78
  extracted_text = await field_obj.get_extracted_text()
81
79
  if extracted_text is None:
82
- logger.warning(
83
- f"{rid} {field} {field_type_int} extracted_text does not exist on DB"
84
- )
80
+ logger.warning(f"{rid} {field} {field_type_int} extracted_text does not exist on DB")
85
81
  return ""
86
82
 
87
83
  if split is not None:
@@ -94,8 +90,8 @@ async def get_paragraph(kbid: str, paragraph_id: str) -> str:
94
90
 
95
91
 
96
92
  async def batchify(
97
- producer: AsyncIterator[Any], size: int, batch_klass: TrainBatchType
98
- ):
93
+ producer: AsyncIterator[Any], size: int, batch_klass: Type[T]
94
+ ) -> AsyncGenerator[T, None]:
99
95
  # NOTE: we are supposing all protobuffers have a data field
100
96
  batch = []
101
97
  async for item in producer:
@@ -18,6 +18,10 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
 
21
+ from contextlib import asynccontextmanager
22
+
23
+ from fastapi import FastAPI
24
+
21
25
  from nucliadb.common.cluster.discovery.utils import (
22
26
  setup_cluster_discovery,
23
27
  teardown_cluster_discovery,
@@ -33,16 +37,16 @@ from nucliadb_telemetry.utils import clean_telemetry, setup_telemetry
33
37
  from nucliadb_utils.utilities import start_audit_utility, stop_audit_utility
34
38
 
35
39
 
36
- async def initialize() -> None:
40
+ @asynccontextmanager
41
+ async def lifespan(app: FastAPI):
37
42
  await setup_telemetry(SERVICE_NAME)
38
-
39
43
  await setup_cluster_discovery()
40
44
  await start_shard_manager()
41
45
  await start_train_grpc(SERVICE_NAME)
42
46
  await start_audit_utility(SERVICE_NAME)
43
47
 
48
+ yield
44
49
 
45
- async def finalize() -> None:
46
50
  await stop_audit_utility()
47
51
  await stop_train_grpc()
48
52
  await stop_shard_manager()
nucliadb/train/nodes.py CHANGED
@@ -19,6 +19,15 @@
19
19
  #
20
20
  from typing import AsyncIterator, Optional
21
21
 
22
+ from nucliadb.common import datamanagers
23
+ from nucliadb.common.cluster import manager
24
+ from nucliadb.common.cluster.base import AbstractIndexNode
25
+
26
+ # XXX: this keys shouldn't be exposed outside datamanagers
27
+ from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG_BASE
28
+ from nucliadb.common.maindb.driver import Driver, Transaction
29
+ from nucliadb.ingest.orm.entities import EntitiesManager
30
+ from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
22
31
  from nucliadb_protos.train_pb2 import (
23
32
  GetFieldsRequest,
24
33
  GetParagraphsRequest,
@@ -30,15 +39,9 @@ from nucliadb_protos.train_pb2 import (
30
39
  TrainSentence,
31
40
  )
32
41
  from nucliadb_protos.writer_pb2 import ShardObject
33
-
34
- from nucliadb.common import datamanagers
35
- from nucliadb.common.cluster import manager
36
- from nucliadb.common.cluster.base import AbstractIndexNode
37
- from nucliadb.common.maindb.driver import Driver, Transaction
38
- from nucliadb.ingest.orm.entities import EntitiesManager
39
- from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
40
- from nucliadb.ingest.orm.resource import KB_RESOURCE_SLUG_BASE
42
+ from nucliadb_utils import const
41
43
  from nucliadb_utils.storages.storage import Storage
44
+ from nucliadb_utils.utilities import has_feature
42
45
 
43
46
 
44
47
  class TrainShardManager(manager.KBShardManager):
@@ -50,13 +53,13 @@ class TrainShardManager(manager.KBShardManager):
50
53
  async def get_reader(self, kbid: str, shard: str) -> tuple[AbstractIndexNode, str]:
51
54
  shards = await self.get_shards_by_kbid_inner(kbid)
52
55
  try:
53
- shard_object: ShardObject = next(
54
- filter(lambda x: x.shard == shard, shards.shards)
55
- )
56
+ shard_object: ShardObject = next(filter(lambda x: x.shard == shard, shards.shards))
56
57
  except StopIteration:
57
58
  raise KeyError("Shard not found")
58
59
 
59
- node_obj, shard_id = manager.choose_node(shard_object)
60
+ node_obj, shard_id = manager.choose_node(
61
+ shard_object, use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": kbid})
62
+ )
60
63
  return node_obj, shard_id
61
64
 
62
65
  async def get_kb_obj(self, txn: Transaction, kbid: str) -> Optional[KnowledgeBox]:
@@ -69,9 +72,7 @@ class TrainShardManager(manager.KBShardManager):
69
72
  kbobj = KnowledgeBox(txn, self.storage, kbid)
70
73
  return kbobj
71
74
 
72
- async def get_kb_entities_manager(
73
- self, txn: Transaction, kbid: str
74
- ) -> Optional[EntitiesManager]:
75
+ async def get_kb_entities_manager(self, txn: Transaction, kbid: str) -> Optional[EntitiesManager]:
75
76
  kbobj = await self.get_kb_obj(txn, kbid)
76
77
  if kbobj is None:
77
78
  return None
@@ -79,9 +80,7 @@ class TrainShardManager(manager.KBShardManager):
79
80
  manager = EntitiesManager(kbobj, txn)
80
81
  return manager
81
82
 
82
- async def kb_sentences(
83
- self, request: GetSentencesRequest
84
- ) -> AsyncIterator[TrainSentence]:
83
+ async def kb_sentences(self, request: GetSentencesRequest) -> AsyncIterator[TrainSentence]:
85
84
  async with self.driver.transaction() as txn:
86
85
  kb = KnowledgeBox(txn, self.storage, request.kb.uuid)
87
86
  if request.uuid != "":
@@ -95,24 +94,18 @@ class TrainShardManager(manager.KBShardManager):
95
94
  async for sentence in resource.iterate_sentences(request.metadata):
96
95
  yield sentence
97
96
 
98
- async def kb_paragraphs(
99
- self, request: GetParagraphsRequest
100
- ) -> AsyncIterator[TrainParagraph]:
97
+ async def kb_paragraphs(self, request: GetParagraphsRequest) -> AsyncIterator[TrainParagraph]:
101
98
  async with self.driver.transaction() as txn:
102
99
  kb = KnowledgeBox(txn, self.storage, request.kb.uuid)
103
100
  if request.uuid != "":
104
101
  # Filter by uuid
105
102
  resource = await kb.get(request.uuid)
106
103
  if resource:
107
- async for paragraph in resource.iterate_paragraphs(
108
- request.metadata
109
- ):
104
+ async for paragraph in resource.iterate_paragraphs(request.metadata):
110
105
  yield paragraph
111
106
  else:
112
107
  async for resource in kb.iterate_resources():
113
- async for paragraph in resource.iterate_paragraphs(
114
- request.metadata
115
- ):
108
+ async for paragraph in resource.iterate_paragraphs(request.metadata):
116
109
  yield paragraph
117
110
 
118
111
  async def kb_fields(self, request: GetFieldsRequest) -> AsyncIterator[TrainField]:
@@ -129,15 +122,13 @@ class TrainShardManager(manager.KBShardManager):
129
122
  async for field in resource.iterate_fields(request.metadata):
130
123
  yield field
131
124
 
132
- async def kb_resources(
133
- self, request: GetResourcesRequest
134
- ) -> AsyncIterator[TrainResource]:
125
+ async def kb_resources(self, request: GetResourcesRequest) -> AsyncIterator[TrainResource]:
135
126
  async with self.driver.transaction() as txn:
136
127
  kb = KnowledgeBox(txn, self.storage, request.kb.uuid)
137
128
  base = KB_RESOURCE_SLUG_BASE.format(kbid=request.kb.uuid)
138
- async for key in txn.keys(match=base, count=-1):
129
+ async for key in txn.keys(match=base):
139
130
  # Fetch and Add wanted item
140
- rid = await txn.get(key)
131
+ rid = await txn.get(key, for_update=False)
141
132
  if rid is not None:
142
133
  resource = await kb.get(rid.decode())
143
134
  if resource is not None:
File without changes
@@ -18,10 +18,13 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import traceback
21
- from typing import Optional
22
21
 
23
22
  import aiohttp
24
- from nucliadb_protos.knowledgebox_pb2 import Labels
23
+
24
+ from nucliadb.common import datamanagers
25
+ from nucliadb.train.settings import settings
26
+ from nucliadb.train.utils import get_shard_manager
27
+ from nucliadb_protos import train_pb2_grpc
25
28
  from nucliadb_protos.train_pb2 import (
26
29
  GetFieldsRequest,
27
30
  GetInfoRequest,
@@ -38,10 +41,6 @@ from nucliadb_protos.writer_pb2 import (
38
41
  GetLabelsRequest,
39
42
  GetLabelsResponse,
40
43
  )
41
-
42
- from nucliadb.train.settings import settings
43
- from nucliadb.train.utils import get_shard_manager
44
- from nucliadb_protos import train_pb2_grpc
45
44
  from nucliadb_telemetry import errors
46
45
 
47
46
 
@@ -111,20 +110,15 @@ class TrainServicer(train_pb2_grpc.TrainServicer):
111
110
  async def GetOntology( # type: ignore
112
111
  self, request: GetLabelsRequest, context=None
113
112
  ) -> GetLabelsResponse:
114
- async with self.proc.driver.transaction() as txn:
115
- kbobj = await self.proc.get_kb_obj(txn, request.kb.uuid)
116
- labels: Optional[Labels] = None
117
- if kbobj is not None:
118
- labels = await kbobj.get_labels()
119
-
120
113
  response = GetLabelsResponse()
121
- if kbobj is None:
122
- response.status = GetLabelsResponse.Status.NOTFOUND
114
+ kbid = request.kb.uuid
115
+ labels = await datamanagers.atomic.labelset.get_all(kbid=kbid)
116
+ if labels is not None:
117
+ response.kb.uuid = kbid
118
+ response.status = GetLabelsResponse.Status.OK
119
+ response.labels.CopyFrom(labels)
123
120
  else:
124
- response.kb.uuid = kbobj.kbid
125
- if labels is not None:
126
- response.labels.CopyFrom(labels)
127
-
121
+ response.status = GetLabelsResponse.Status.NOTFOUND
128
122
  return response
129
123
 
130
124
  async def GetOntologyCount( # type: ignore
@@ -132,9 +126,7 @@ class TrainServicer(train_pb2_grpc.TrainServicer):
132
126
  ) -> LabelsetsCount:
133
127
  url = settings.internal_search_api.format(kbid=request.kb.uuid)
134
128
  facets = [f"faceted=/p/{labelset}" for labelset in request.paragraph_labelsets]
135
- facets.extend(
136
- [f"faceted=/l/{labelset}" for labelset in request.resource_labelsets]
137
- )
129
+ facets.extend([f"faceted=/l/{labelset}" for labelset in request.resource_labelsets])
138
130
  query = "&".join(facets)
139
131
  headers = {"X-NUCLIADB-ROLES": "READER"}
140
132
  async with aiohttp.ClientSession() as sess:
@@ -29,13 +29,9 @@ class Settings(DriverSettings):
29
29
  nuclia_learning_url: Optional[str] = "https://nuclia.cloud/api/v1/learning/"
30
30
  nuclia_learning_apikey: Optional[str] = None
31
31
 
32
- internal_counter_api: str = (
33
- "http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/counters"
34
- )
32
+ internal_counter_api: str = "http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/counters"
35
33
 
36
- internal_search_api: str = (
37
- "http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/search"
38
- )
34
+ internal_search_api: str = "http://search.nuclia.svc.cluster.local:8030/api/v1/kb/{kbid}/search"
39
35
 
40
36
 
41
37
  settings = Settings()
nucliadb/train/types.py CHANGED
@@ -17,7 +17,7 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- from typing import Union
20
+ from typing import TypeVar, Union
21
21
 
22
22
  from nucliadb_protos import dataset_pb2 as dpb
23
23
 
@@ -29,14 +29,17 @@ TrainBatch = Union[
29
29
  dpb.QuestionAnswerStreamingBatch,
30
30
  dpb.SentenceClassificationBatch,
31
31
  dpb.TokenClassificationBatch,
32
+ dpb.FieldStreamingBatch,
32
33
  ]
33
34
 
34
- TrainBatchType = Union[
35
- type[dpb.FieldClassificationBatch],
36
- type[dpb.ImageClassificationBatch],
37
- type[dpb.ParagraphClassificationBatch],
38
- type[dpb.ParagraphStreamingBatch],
39
- type[dpb.QuestionAnswerStreamingBatch],
40
- type[dpb.SentenceClassificationBatch],
41
- type[dpb.TokenClassificationBatch],
42
- ]
35
+ T = TypeVar(
36
+ "T",
37
+ dpb.FieldClassificationBatch,
38
+ dpb.ImageClassificationBatch,
39
+ dpb.ParagraphClassificationBatch,
40
+ dpb.ParagraphStreamingBatch,
41
+ dpb.QuestionAnswerStreamingBatch,
42
+ dpb.SentenceClassificationBatch,
43
+ dpb.TokenClassificationBatch,
44
+ dpb.FieldStreamingBatch,
45
+ )
nucliadb/train/upload.py CHANGED
@@ -19,11 +19,10 @@
19
19
  #
20
20
  import argparse
21
21
  import asyncio
22
+ import importlib.metadata
22
23
  from asyncio import tasks
23
24
  from typing import Callable
24
25
 
25
- import pkg_resources
26
-
27
26
  from nucliadb.train.uploader import start_upload
28
27
  from nucliadb_telemetry import errors
29
28
  from nucliadb_telemetry.logs import setup_logging
@@ -33,9 +32,7 @@ from nucliadb_utils.settings import running_settings
33
32
  def arg_parse():
34
33
  parser = argparse.ArgumentParser(description="Upload data to Nuclia Learning API.")
35
34
 
36
- parser.add_argument(
37
- "-r", "--request", dest="request", help="Request UUID", required=True
38
- )
35
+ parser.add_argument("-r", "--request", dest="request", help="Request UUID", required=True)
39
36
 
40
37
  parser.add_argument("-k", "--kb", dest="kb", help="Knowledge Box", required=True)
41
38
 
@@ -75,7 +72,7 @@ def _cancel_all_tasks(loop):
75
72
  def run() -> None:
76
73
  setup_logging()
77
74
 
78
- errors.setup_error_handling(pkg_resources.get_distribution("nucliadb").version)
75
+ errors.setup_error_handling(importlib.metadata.distribution("nucliadb").version)
79
76
 
80
77
  if asyncio._get_running_loop() is not None:
81
78
  raise RuntimeError("cannot be called from a running event loop")
@@ -20,6 +20,14 @@
20
20
  from typing import Optional
21
21
 
22
22
  import aiohttp
23
+
24
+ from nucliadb.common import datamanagers
25
+ from nucliadb.common.maindb.utils import setup_driver
26
+ from nucliadb.ingest.orm.entities import EntitiesManager
27
+ from nucliadb.ingest.orm.processor import Processor
28
+ from nucliadb.train import SERVICE_NAME
29
+ from nucliadb.train.models import RequestData
30
+ from nucliadb.train.settings import settings
23
31
  from nucliadb_protos.knowledgebox_pb2 import Labels
24
32
  from nucliadb_protos.train_pb2 import (
25
33
  EnabledMetadata,
@@ -34,13 +42,6 @@ from nucliadb_protos.writer_pb2 import (
34
42
  GetLabelsRequest,
35
43
  GetLabelsResponse,
36
44
  )
37
-
38
- from nucliadb.common.maindb.utils import setup_driver
39
- from nucliadb.ingest.orm.entities import EntitiesManager
40
- from nucliadb.ingest.orm.processor import Processor
41
- from nucliadb.train import SERVICE_NAME
42
- from nucliadb.train.models import RequestData
43
- from nucliadb.train.settings import settings
44
45
  from nucliadb_utils.utilities import get_pubsub, get_storage
45
46
 
46
47
 
@@ -74,9 +75,8 @@ class UploadServicer:
74
75
  ) -> GetEntitiesResponse:
75
76
  kbid = request.kb.uuid
76
77
  response = GetEntitiesResponse()
77
- async with self.proc.driver.transaction() as txn:
78
+ async with self.proc.driver.transaction(read_only=True) as txn:
78
79
  kbobj = await self.proc.get_kb_obj(txn, request.kb)
79
-
80
80
  if kbobj is None:
81
81
  response.status = GetEntitiesResponse.Status.NOTFOUND
82
82
  return response
@@ -90,20 +90,16 @@ class UploadServicer:
90
90
  async def GetOntology( # type: ignore
91
91
  self, request: GetLabelsRequest, context=None
92
92
  ) -> GetLabelsResponse:
93
- async with self.proc.driver.transaction() as txn:
94
- kbobj = await self.proc.get_kb_obj(txn, request.kb)
95
- labels: Optional[Labels] = None
96
- if kbobj is not None:
97
- labels = await kbobj.get_labels()
98
-
93
+ kbid = request.kb.uuid
99
94
  response = GetLabelsResponse()
100
- if kbobj is None:
95
+ kb_exists = await datamanagers.atomic.kb.exists_kb(kbid=kbid)
96
+ if not kb_exists:
101
97
  response.status = GetLabelsResponse.Status.NOTFOUND
102
- else:
103
- response.kb.uuid = kbobj.kbid
104
- if labels is not None:
105
- response.labels.CopyFrom(labels)
106
-
98
+ return response
99
+ response.kb.uuid = kbid
100
+ labels: Optional[Labels] = await datamanagers.atomic.labelset.get_all(kbid=kbid)
101
+ if labels is not None:
102
+ response.labels.CopyFrom(labels)
107
103
  return response
108
104
 
109
105
 
@@ -123,9 +119,9 @@ async def start_upload(request: str, kb: str):
123
119
  }
124
120
  ) as sess:
125
121
  req = await sess.get(f"{url}/request")
126
- request_data = RequestData.parse_raw(await req.read())
122
+ request_data = RequestData.model_validate_json(await req.read())
127
123
 
128
- metadata = EnabledMetadata(**request_data.metadata.dict())
124
+ metadata = EnabledMetadata(**request_data.metadata.model_dump())
129
125
 
130
126
  if request_data.sentences:
131
127
  pbsr = GetSentencesRequest()
nucliadb/train/utils.py CHANGED
@@ -23,7 +23,7 @@ from grpc import aio
23
23
  from grpc_health.v1 import health, health_pb2_grpc
24
24
 
25
25
  from nucliadb.common.maindb.utils import setup_driver, teardown_driver
26
- from nucliadb.train.nodes import TrainShardManager # type: ignore
26
+ from nucliadb.train.nodes import TrainShardManager
27
27
  from nucliadb.train.settings import settings
28
28
  from nucliadb_protos import train_pb2_grpc
29
29
  from nucliadb_telemetry.utils import setup_telemetry
@@ -29,9 +29,7 @@ logger = logging.getLogger(SERVICE_NAME)
29
29
  class EndpointFilter(logging.Filter):
30
30
  def filter(self, record: logging.LogRecord) -> bool:
31
31
  return (
32
- record.args is not None
33
- and len(record.args) >= 3
34
- and record.args[2] not in ("/", "/metrics") # type: ignore
32
+ record.args is not None and len(record.args) >= 3 and record.args[2] not in ("/", "/metrics") # type: ignore
35
33
  )
36
34
 
37
35
 
@@ -17,18 +17,21 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- from nucliadb_protos.resources_pb2 import FieldKeywordset
20
+ from functools import wraps
21
21
 
22
- from nucliadb.ingest.fields.base import Field
22
+ from fastapi import HTTPException
23
23
 
24
+ from nucliadb_utils.settings import is_onprem_nucliadb
24
25
 
25
- class Keywordset(Field):
26
- pbklass = FieldKeywordset
27
- value: FieldKeywordset
28
- type: str = "k"
29
26
 
30
- async def set_value(self, payload: FieldKeywordset):
31
- await self.db_set_value(payload)
27
+ def only_for_onprem(fun):
28
+ @wraps(fun)
29
+ async def endpoint_wrapper(*args, **kwargs):
30
+ if not is_onprem_nucliadb():
31
+ raise HTTPException(
32
+ status_code=403,
33
+ detail="This endpoint is only available for onprem NucliaDB",
34
+ )
35
+ return await fun(*args, **kwargs)
32
36
 
33
- async def get_value(self) -> FieldKeywordset:
34
- return await self.db_get_value()
37
+ return endpoint_wrapper