nucliadb 2.46.1.post382__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431) hide show
  1. migrations/0002_rollover_shards.py +1 -2
  2. migrations/0003_allfields_key.py +2 -37
  3. migrations/0004_rollover_shards.py +1 -2
  4. migrations/0005_rollover_shards.py +1 -2
  5. migrations/0006_rollover_shards.py +2 -4
  6. migrations/0008_cleanup_leftover_rollover_metadata.py +1 -2
  7. migrations/0009_upgrade_relations_and_texts_to_v2.py +5 -4
  8. migrations/0010_fix_corrupt_indexes.py +11 -12
  9. migrations/0011_materialize_labelset_ids.py +2 -18
  10. migrations/0012_rollover_shards.py +6 -12
  11. migrations/0013_rollover_shards.py +2 -4
  12. migrations/0014_rollover_shards.py +5 -7
  13. migrations/0015_targeted_rollover.py +6 -12
  14. migrations/0016_upgrade_to_paragraphs_v2.py +27 -32
  15. migrations/0017_multiple_writable_shards.py +3 -6
  16. migrations/0018_purge_orphan_kbslugs.py +59 -0
  17. migrations/0019_upgrade_to_paragraphs_v3.py +66 -0
  18. migrations/0020_drain_nodes_from_cluster.py +83 -0
  19. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +17 -18
  20. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  21. migrations/0023_backfill_pg_catalog.py +80 -0
  22. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  23. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  24. migrations/0027_rollover_texts3.py +73 -0
  25. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  26. migrations/pg/0002_catalog.py +42 -0
  27. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  28. nucliadb/common/cluster/base.py +41 -24
  29. nucliadb/common/cluster/discovery/base.py +6 -14
  30. nucliadb/common/cluster/discovery/k8s.py +9 -19
  31. nucliadb/common/cluster/discovery/manual.py +1 -3
  32. nucliadb/common/cluster/discovery/single.py +1 -2
  33. nucliadb/common/cluster/discovery/utils.py +1 -3
  34. nucliadb/common/cluster/grpc_node_dummy.py +11 -16
  35. nucliadb/common/cluster/index_node.py +10 -19
  36. nucliadb/common/cluster/manager.py +223 -102
  37. nucliadb/common/cluster/rebalance.py +42 -37
  38. nucliadb/common/cluster/rollover.py +377 -204
  39. nucliadb/common/cluster/settings.py +16 -9
  40. nucliadb/common/cluster/standalone/grpc_node_binding.py +24 -76
  41. nucliadb/common/cluster/standalone/index_node.py +4 -11
  42. nucliadb/common/cluster/standalone/service.py +2 -6
  43. nucliadb/common/cluster/standalone/utils.py +9 -6
  44. nucliadb/common/cluster/utils.py +43 -29
  45. nucliadb/common/constants.py +20 -0
  46. nucliadb/common/context/__init__.py +6 -4
  47. nucliadb/common/context/fastapi.py +8 -5
  48. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  49. nucliadb/common/datamanagers/__init__.py +24 -5
  50. nucliadb/common/datamanagers/atomic.py +102 -0
  51. nucliadb/common/datamanagers/cluster.py +5 -5
  52. nucliadb/common/datamanagers/entities.py +6 -16
  53. nucliadb/common/datamanagers/fields.py +84 -0
  54. nucliadb/common/datamanagers/kb.py +101 -24
  55. nucliadb/common/datamanagers/labels.py +26 -56
  56. nucliadb/common/datamanagers/processing.py +2 -6
  57. nucliadb/common/datamanagers/resources.py +214 -117
  58. nucliadb/common/datamanagers/rollover.py +77 -16
  59. nucliadb/{ingest/orm → common/datamanagers}/synonyms.py +16 -28
  60. nucliadb/common/datamanagers/utils.py +19 -11
  61. nucliadb/common/datamanagers/vectorsets.py +110 -0
  62. nucliadb/common/external_index_providers/base.py +257 -0
  63. nucliadb/{ingest/tests/unit/test_cache.py → common/external_index_providers/exceptions.py} +9 -8
  64. nucliadb/common/external_index_providers/manager.py +101 -0
  65. nucliadb/common/external_index_providers/pinecone.py +933 -0
  66. nucliadb/common/external_index_providers/settings.py +52 -0
  67. nucliadb/common/http_clients/auth.py +3 -6
  68. nucliadb/common/http_clients/processing.py +6 -11
  69. nucliadb/common/http_clients/utils.py +1 -3
  70. nucliadb/common/ids.py +240 -0
  71. nucliadb/common/locking.py +43 -13
  72. nucliadb/common/maindb/driver.py +11 -35
  73. nucliadb/common/maindb/exceptions.py +6 -6
  74. nucliadb/common/maindb/local.py +22 -9
  75. nucliadb/common/maindb/pg.py +206 -111
  76. nucliadb/common/maindb/utils.py +13 -44
  77. nucliadb/common/models_utils/from_proto.py +479 -0
  78. nucliadb/common/models_utils/to_proto.py +60 -0
  79. nucliadb/common/nidx.py +260 -0
  80. nucliadb/export_import/datamanager.py +25 -19
  81. nucliadb/export_import/exceptions.py +8 -0
  82. nucliadb/export_import/exporter.py +20 -7
  83. nucliadb/export_import/importer.py +6 -11
  84. nucliadb/export_import/models.py +5 -5
  85. nucliadb/export_import/tasks.py +4 -4
  86. nucliadb/export_import/utils.py +94 -54
  87. nucliadb/health.py +1 -3
  88. nucliadb/ingest/app.py +15 -11
  89. nucliadb/ingest/consumer/auditing.py +30 -147
  90. nucliadb/ingest/consumer/consumer.py +96 -52
  91. nucliadb/ingest/consumer/materializer.py +10 -12
  92. nucliadb/ingest/consumer/pull.py +12 -27
  93. nucliadb/ingest/consumer/service.py +20 -19
  94. nucliadb/ingest/consumer/shard_creator.py +7 -14
  95. nucliadb/ingest/consumer/utils.py +1 -3
  96. nucliadb/ingest/fields/base.py +139 -188
  97. nucliadb/ingest/fields/conversation.py +18 -5
  98. nucliadb/ingest/fields/exceptions.py +1 -4
  99. nucliadb/ingest/fields/file.py +7 -25
  100. nucliadb/ingest/fields/link.py +11 -16
  101. nucliadb/ingest/fields/text.py +9 -4
  102. nucliadb/ingest/orm/brain.py +255 -262
  103. nucliadb/ingest/orm/broker_message.py +181 -0
  104. nucliadb/ingest/orm/entities.py +36 -51
  105. nucliadb/ingest/orm/exceptions.py +12 -0
  106. nucliadb/ingest/orm/knowledgebox.py +334 -278
  107. nucliadb/ingest/orm/processor/__init__.py +2 -697
  108. nucliadb/ingest/orm/processor/auditing.py +117 -0
  109. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  110. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  111. nucliadb/ingest/orm/processor/processor.py +752 -0
  112. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  113. nucliadb/ingest/orm/resource.py +280 -520
  114. nucliadb/ingest/orm/utils.py +25 -31
  115. nucliadb/ingest/partitions.py +3 -9
  116. nucliadb/ingest/processing.py +76 -81
  117. nucliadb/ingest/py.typed +0 -0
  118. nucliadb/ingest/serialize.py +37 -173
  119. nucliadb/ingest/service/__init__.py +1 -3
  120. nucliadb/ingest/service/writer.py +186 -577
  121. nucliadb/ingest/settings.py +13 -22
  122. nucliadb/ingest/utils.py +3 -6
  123. nucliadb/learning_proxy.py +264 -51
  124. nucliadb/metrics_exporter.py +30 -19
  125. nucliadb/middleware/__init__.py +1 -3
  126. nucliadb/migrator/command.py +1 -3
  127. nucliadb/migrator/datamanager.py +13 -13
  128. nucliadb/migrator/migrator.py +57 -37
  129. nucliadb/migrator/settings.py +2 -1
  130. nucliadb/migrator/utils.py +18 -10
  131. nucliadb/purge/__init__.py +139 -33
  132. nucliadb/purge/orphan_shards.py +7 -13
  133. nucliadb/reader/__init__.py +1 -3
  134. nucliadb/reader/api/models.py +3 -14
  135. nucliadb/reader/api/v1/__init__.py +0 -1
  136. nucliadb/reader/api/v1/download.py +27 -94
  137. nucliadb/reader/api/v1/export_import.py +4 -4
  138. nucliadb/reader/api/v1/knowledgebox.py +13 -13
  139. nucliadb/reader/api/v1/learning_config.py +8 -12
  140. nucliadb/reader/api/v1/resource.py +67 -93
  141. nucliadb/reader/api/v1/services.py +70 -125
  142. nucliadb/reader/app.py +16 -46
  143. nucliadb/reader/lifecycle.py +18 -4
  144. nucliadb/reader/py.typed +0 -0
  145. nucliadb/reader/reader/notifications.py +10 -31
  146. nucliadb/search/__init__.py +1 -3
  147. nucliadb/search/api/v1/__init__.py +2 -2
  148. nucliadb/search/api/v1/ask.py +112 -0
  149. nucliadb/search/api/v1/catalog.py +184 -0
  150. nucliadb/search/api/v1/feedback.py +17 -25
  151. nucliadb/search/api/v1/find.py +41 -41
  152. nucliadb/search/api/v1/knowledgebox.py +90 -62
  153. nucliadb/search/api/v1/predict_proxy.py +2 -2
  154. nucliadb/search/api/v1/resource/ask.py +66 -117
  155. nucliadb/search/api/v1/resource/search.py +51 -72
  156. nucliadb/search/api/v1/router.py +1 -0
  157. nucliadb/search/api/v1/search.py +50 -197
  158. nucliadb/search/api/v1/suggest.py +40 -54
  159. nucliadb/search/api/v1/summarize.py +9 -5
  160. nucliadb/search/api/v1/utils.py +2 -1
  161. nucliadb/search/app.py +16 -48
  162. nucliadb/search/lifecycle.py +10 -3
  163. nucliadb/search/predict.py +176 -188
  164. nucliadb/search/py.typed +0 -0
  165. nucliadb/search/requesters/utils.py +41 -63
  166. nucliadb/search/search/cache.py +149 -20
  167. nucliadb/search/search/chat/ask.py +918 -0
  168. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -13
  169. nucliadb/search/search/chat/images.py +41 -17
  170. nucliadb/search/search/chat/prompt.py +851 -282
  171. nucliadb/search/search/chat/query.py +274 -267
  172. nucliadb/{writer/resource/slug.py → search/search/cut.py} +8 -6
  173. nucliadb/search/search/fetch.py +43 -36
  174. nucliadb/search/search/filters.py +9 -15
  175. nucliadb/search/search/find.py +214 -54
  176. nucliadb/search/search/find_merge.py +408 -391
  177. nucliadb/search/search/hydrator.py +191 -0
  178. nucliadb/search/search/merge.py +198 -234
  179. nucliadb/search/search/metrics.py +73 -2
  180. nucliadb/search/search/paragraphs.py +64 -106
  181. nucliadb/search/search/pgcatalog.py +233 -0
  182. nucliadb/search/search/predict_proxy.py +1 -1
  183. nucliadb/search/search/query.py +386 -257
  184. nucliadb/search/search/query_parser/exceptions.py +22 -0
  185. nucliadb/search/search/query_parser/models.py +101 -0
  186. nucliadb/search/search/query_parser/parser.py +183 -0
  187. nucliadb/search/search/rank_fusion.py +204 -0
  188. nucliadb/search/search/rerankers.py +270 -0
  189. nucliadb/search/search/shards.py +4 -38
  190. nucliadb/search/search/summarize.py +14 -18
  191. nucliadb/search/search/utils.py +27 -4
  192. nucliadb/search/settings.py +15 -1
  193. nucliadb/standalone/api_router.py +4 -10
  194. nucliadb/standalone/app.py +17 -14
  195. nucliadb/standalone/auth.py +7 -21
  196. nucliadb/standalone/config.py +9 -12
  197. nucliadb/standalone/introspect.py +5 -5
  198. nucliadb/standalone/lifecycle.py +26 -25
  199. nucliadb/standalone/migrations.py +58 -0
  200. nucliadb/standalone/purge.py +9 -8
  201. nucliadb/standalone/py.typed +0 -0
  202. nucliadb/standalone/run.py +25 -18
  203. nucliadb/standalone/settings.py +10 -14
  204. nucliadb/standalone/versions.py +15 -5
  205. nucliadb/tasks/consumer.py +8 -12
  206. nucliadb/tasks/producer.py +7 -6
  207. nucliadb/tests/config.py +53 -0
  208. nucliadb/train/__init__.py +1 -3
  209. nucliadb/train/api/utils.py +1 -2
  210. nucliadb/train/api/v1/shards.py +2 -2
  211. nucliadb/train/api/v1/trainset.py +4 -6
  212. nucliadb/train/app.py +14 -47
  213. nucliadb/train/generator.py +10 -19
  214. nucliadb/train/generators/field_classifier.py +7 -19
  215. nucliadb/train/generators/field_streaming.py +156 -0
  216. nucliadb/train/generators/image_classifier.py +12 -18
  217. nucliadb/train/generators/paragraph_classifier.py +5 -9
  218. nucliadb/train/generators/paragraph_streaming.py +6 -9
  219. nucliadb/train/generators/question_answer_streaming.py +19 -20
  220. nucliadb/train/generators/sentence_classifier.py +9 -15
  221. nucliadb/train/generators/token_classifier.py +45 -36
  222. nucliadb/train/generators/utils.py +14 -18
  223. nucliadb/train/lifecycle.py +7 -3
  224. nucliadb/train/nodes.py +23 -32
  225. nucliadb/train/py.typed +0 -0
  226. nucliadb/train/servicer.py +13 -21
  227. nucliadb/train/settings.py +2 -6
  228. nucliadb/train/types.py +13 -10
  229. nucliadb/train/upload.py +3 -6
  230. nucliadb/train/uploader.py +20 -25
  231. nucliadb/train/utils.py +1 -1
  232. nucliadb/writer/__init__.py +1 -3
  233. nucliadb/writer/api/constants.py +0 -5
  234. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  235. nucliadb/writer/api/v1/export_import.py +102 -49
  236. nucliadb/writer/api/v1/field.py +196 -620
  237. nucliadb/writer/api/v1/knowledgebox.py +221 -71
  238. nucliadb/writer/api/v1/learning_config.py +2 -2
  239. nucliadb/writer/api/v1/resource.py +114 -216
  240. nucliadb/writer/api/v1/services.py +64 -132
  241. nucliadb/writer/api/v1/slug.py +61 -0
  242. nucliadb/writer/api/v1/transaction.py +67 -0
  243. nucliadb/writer/api/v1/upload.py +184 -215
  244. nucliadb/writer/app.py +11 -61
  245. nucliadb/writer/back_pressure.py +62 -43
  246. nucliadb/writer/exceptions.py +0 -4
  247. nucliadb/writer/lifecycle.py +21 -15
  248. nucliadb/writer/py.typed +0 -0
  249. nucliadb/writer/resource/audit.py +2 -1
  250. nucliadb/writer/resource/basic.py +48 -62
  251. nucliadb/writer/resource/field.py +45 -135
  252. nucliadb/writer/resource/origin.py +1 -2
  253. nucliadb/writer/settings.py +14 -5
  254. nucliadb/writer/tus/__init__.py +17 -15
  255. nucliadb/writer/tus/azure.py +111 -0
  256. nucliadb/writer/tus/dm.py +17 -5
  257. nucliadb/writer/tus/exceptions.py +1 -3
  258. nucliadb/writer/tus/gcs.py +56 -84
  259. nucliadb/writer/tus/local.py +21 -37
  260. nucliadb/writer/tus/s3.py +28 -68
  261. nucliadb/writer/tus/storage.py +5 -56
  262. nucliadb/writer/vectorsets.py +125 -0
  263. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  264. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  265. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  266. nucliadb/common/maindb/redis.py +0 -194
  267. nucliadb/common/maindb/tikv.py +0 -412
  268. nucliadb/ingest/fields/layout.py +0 -58
  269. nucliadb/ingest/tests/conftest.py +0 -30
  270. nucliadb/ingest/tests/fixtures.py +0 -771
  271. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  272. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -80
  273. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -89
  274. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  275. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  276. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  277. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -691
  278. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  279. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  280. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  281. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -140
  282. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  283. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  284. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -139
  285. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  286. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  287. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  288. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  289. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -275
  290. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  291. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  292. nucliadb/middleware/transaction.py +0 -117
  293. nucliadb/reader/api/v1/learning_collector.py +0 -63
  294. nucliadb/reader/tests/__init__.py +0 -19
  295. nucliadb/reader/tests/conftest.py +0 -31
  296. nucliadb/reader/tests/fixtures.py +0 -136
  297. nucliadb/reader/tests/test_list_resources.py +0 -75
  298. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  299. nucliadb/reader/tests/test_reader_resource.py +0 -379
  300. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  301. nucliadb/search/api/v1/chat.py +0 -258
  302. nucliadb/search/api/v1/resource/chat.py +0 -94
  303. nucliadb/search/tests/__init__.py +0 -19
  304. nucliadb/search/tests/conftest.py +0 -33
  305. nucliadb/search/tests/fixtures.py +0 -199
  306. nucliadb/search/tests/node.py +0 -465
  307. nucliadb/search/tests/unit/__init__.py +0 -18
  308. nucliadb/search/tests/unit/api/__init__.py +0 -19
  309. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  310. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  311. nucliadb/search/tests/unit/api/v1/resource/test_ask.py +0 -67
  312. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -97
  313. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  314. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  315. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -93
  316. nucliadb/search/tests/unit/search/__init__.py +0 -18
  317. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  318. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -210
  319. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  320. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  321. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  322. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -266
  323. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  324. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  325. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  326. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  327. nucliadb/search/tests/unit/search/test_query.py +0 -201
  328. nucliadb/search/tests/unit/test_app.py +0 -79
  329. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  330. nucliadb/search/tests/unit/test_merge.py +0 -34
  331. nucliadb/search/tests/unit/test_predict.py +0 -584
  332. nucliadb/standalone/tests/__init__.py +0 -19
  333. nucliadb/standalone/tests/conftest.py +0 -33
  334. nucliadb/standalone/tests/fixtures.py +0 -38
  335. nucliadb/standalone/tests/unit/__init__.py +0 -18
  336. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  337. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  338. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  339. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  340. nucliadb/tests/benchmarks/__init__.py +0 -19
  341. nucliadb/tests/benchmarks/test_search.py +0 -99
  342. nucliadb/tests/conftest.py +0 -32
  343. nucliadb/tests/fixtures.py +0 -736
  344. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -203
  345. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -109
  346. nucliadb/tests/migrations/__init__.py +0 -19
  347. nucliadb/tests/migrations/test_migration_0017.py +0 -80
  348. nucliadb/tests/tikv.py +0 -240
  349. nucliadb/tests/unit/__init__.py +0 -19
  350. nucliadb/tests/unit/common/__init__.py +0 -19
  351. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  352. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  353. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -170
  354. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  355. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -113
  356. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -59
  357. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -399
  358. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -178
  359. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -279
  360. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  361. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  362. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  363. nucliadb/tests/unit/common/maindb/test_utils.py +0 -81
  364. nucliadb/tests/unit/common/test_context.py +0 -36
  365. nucliadb/tests/unit/export_import/__init__.py +0 -19
  366. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  367. nucliadb/tests/unit/export_import/test_utils.py +0 -294
  368. nucliadb/tests/unit/migrator/__init__.py +0 -19
  369. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  370. nucliadb/tests/unit/tasks/__init__.py +0 -19
  371. nucliadb/tests/unit/tasks/conftest.py +0 -42
  372. nucliadb/tests/unit/tasks/test_consumer.py +0 -93
  373. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  374. nucliadb/tests/unit/tasks/test_tasks.py +0 -60
  375. nucliadb/tests/unit/test_field_ids.py +0 -49
  376. nucliadb/tests/unit/test_health.py +0 -84
  377. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  378. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  379. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  380. nucliadb/tests/unit/test_purge.py +0 -138
  381. nucliadb/tests/utils/__init__.py +0 -74
  382. nucliadb/tests/utils/aiohttp_session.py +0 -44
  383. nucliadb/tests/utils/broker_messages/__init__.py +0 -167
  384. nucliadb/tests/utils/broker_messages/fields.py +0 -181
  385. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  386. nucliadb/tests/utils/entities.py +0 -78
  387. nucliadb/train/api/v1/check.py +0 -60
  388. nucliadb/train/tests/__init__.py +0 -19
  389. nucliadb/train/tests/conftest.py +0 -29
  390. nucliadb/train/tests/fixtures.py +0 -342
  391. nucliadb/train/tests/test_field_classification.py +0 -122
  392. nucliadb/train/tests/test_get_entities.py +0 -80
  393. nucliadb/train/tests/test_get_info.py +0 -51
  394. nucliadb/train/tests/test_get_ontology.py +0 -34
  395. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  396. nucliadb/train/tests/test_image_classification.py +0 -222
  397. nucliadb/train/tests/test_list_fields.py +0 -39
  398. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  399. nucliadb/train/tests/test_list_resources.py +0 -39
  400. nucliadb/train/tests/test_list_sentences.py +0 -71
  401. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  402. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  403. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  404. nucliadb/train/tests/test_sentence_classification.py +0 -143
  405. nucliadb/train/tests/test_token_classification.py +0 -136
  406. nucliadb/train/tests/utils.py +0 -108
  407. nucliadb/writer/layouts/__init__.py +0 -51
  408. nucliadb/writer/layouts/v1.py +0 -59
  409. nucliadb/writer/resource/vectors.py +0 -120
  410. nucliadb/writer/tests/__init__.py +0 -19
  411. nucliadb/writer/tests/conftest.py +0 -31
  412. nucliadb/writer/tests/fixtures.py +0 -192
  413. nucliadb/writer/tests/test_fields.py +0 -486
  414. nucliadb/writer/tests/test_files.py +0 -743
  415. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  416. nucliadb/writer/tests/test_reprocess_file_field.py +0 -139
  417. nucliadb/writer/tests/test_resources.py +0 -546
  418. nucliadb/writer/tests/test_service.py +0 -137
  419. nucliadb/writer/tests/test_tus.py +0 -203
  420. nucliadb/writer/tests/utils.py +0 -35
  421. nucliadb/writer/tus/pg.py +0 -125
  422. nucliadb-2.46.1.post382.dist-info/METADATA +0 -134
  423. nucliadb-2.46.1.post382.dist-info/RECORD +0 -451
  424. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  425. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  426. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  427. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  428. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  429. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  430. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  431. {nucliadb-2.46.1.post382.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -1,123 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import asyncio
21
- import uuid
22
-
23
- import aiohttp
24
- import pytest
25
- from nucliadb_protos.dataset_pb2 import ParagraphClassificationBatch, TaskType, TrainSet
26
- from nucliadb_protos.writer_pb2 import BrokerMessage
27
- from nucliadb_protos.writer_pb2_grpc import WriterStub
28
-
29
- from nucliadb.tests.utils import inject_message
30
- from nucliadb.tests.utils.broker_messages import BrokerMessageBuilder, FieldBuilder
31
- from nucliadb.train import API_PREFIX
32
- from nucliadb.train.api.v1.router import KB_PREFIX
33
- from nucliadb.train.tests.utils import get_batches_from_train_response_stream
34
- from nucliadb_protos import resources_pb2 as rpb
35
-
36
-
37
- @pytest.mark.asyncio
38
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
39
- async def test_generator_paragraph_classification(
40
- train_rest_api: aiohttp.ClientSession,
41
- nucliadb_grpc: WriterStub,
42
- knowledgebox_with_labels: str,
43
- ):
44
- kbid = knowledgebox_with_labels
45
-
46
- await inject_resource_with_paragraph_classification(kbid, nucliadb_grpc)
47
-
48
- async with train_rest_api.get(
49
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
50
- ) as partitions:
51
- assert partitions.status == 200
52
- data = await partitions.json()
53
- assert len(data["partitions"]) == 1
54
- partition_id = data["partitions"][0]
55
-
56
- trainset = TrainSet()
57
- trainset.type = TaskType.PARAGRAPH_CLASSIFICATION
58
- trainset.batch_size = 2
59
- trainset.filter.labels.append("labelset_paragraphs")
60
-
61
- async with train_rest_api.post(
62
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
63
- data=trainset.SerializeToString(),
64
- ) as response:
65
- assert response.status == 200
66
- batches = []
67
- async for batch in get_batches_from_train_response_stream(
68
- response, ParagraphClassificationBatch
69
- ):
70
- batches.append(batch)
71
- assert len(batch.data) == 2
72
- assert len(batches) == 2
73
-
74
-
75
- async def inject_resource_with_paragraph_classification(knowledgebox, writer):
76
- bm = broker_resource(knowledgebox)
77
- await inject_message(writer, bm)
78
- await asyncio.sleep(0.1)
79
- return bm.uuid
80
-
81
-
82
- def broker_resource(knowledgebox: str) -> BrokerMessage:
83
- rid = str(uuid.uuid4())
84
- bmb = BrokerMessageBuilder(kbid=knowledgebox, rid=rid)
85
- bmb.with_title("Title Resource")
86
- bmb.with_summary("Summary of document")
87
- bmb.with_resource_labels("labelset_resources", ["label_user"])
88
-
89
- file_field = FieldBuilder("file", rpb.FieldType.FILE)
90
- file_field.with_extracted_text(
91
- "My own text Ramon. This is great to be here. \n Where is my beer? Do you want to go shooping? This is a test!" # noqa
92
- )
93
-
94
- labelset = "labelset_paragraphs"
95
- labels = ["label_user"]
96
- file_field.with_user_paragraph_labels(f"{rid}/f/file/0-45", labelset, labels)
97
- file_field.with_user_paragraph_labels(f"{rid}/f/file/47-64", labelset, labels)
98
- file_field.with_user_paragraph_labels(f"{rid}/f/file/65-93", labelset, labels)
99
- file_field.with_user_paragraph_labels(f"{rid}/f/file/93-109", labelset, labels)
100
-
101
- classification = rpb.Classification(
102
- labelset="labelset_paragraphs", label="label_machine"
103
- )
104
- file_field.with_extracted_paragraph_metadata(
105
- rpb.Paragraph(start=0, end=45, classifications=[classification])
106
- )
107
- file_field.with_extracted_paragraph_metadata(
108
- rpb.Paragraph(start=47, end=64, classifications=[classification])
109
- )
110
- file_field.with_extracted_paragraph_metadata(
111
- rpb.Paragraph(start=65, end=93, classifications=[classification])
112
- )
113
- file_field.with_extracted_paragraph_metadata(
114
- rpb.Paragraph(start=93, end=109, classifications=[classification])
115
- )
116
-
117
- file_field.with_extracted_labels("labelset_resources", ["label_machine"])
118
-
119
- bmb.add_field_builder(file_field)
120
-
121
- bm = bmb.build()
122
-
123
- return bm
@@ -1,118 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import asyncio
21
- from typing import AsyncIterator
22
-
23
- import aiohttp
24
- import pytest
25
- from nucliadb_protos.dataset_pb2 import ParagraphStreamingBatch, TaskType, TrainSet
26
- from nucliadb_protos.writer_pb2 import BrokerMessage
27
- from nucliadb_protos.writer_pb2_grpc import WriterStub
28
-
29
- from nucliadb.tests.utils import inject_message
30
- from nucliadb.tests.utils.broker_messages import BrokerMessageBuilder, FieldBuilder
31
- from nucliadb.train import API_PREFIX
32
- from nucliadb.train.api.v1.router import KB_PREFIX
33
- from nucliadb.train.tests.utils import get_batches_from_train_response_stream
34
- from nucliadb_protos import resources_pb2 as rpb
35
-
36
-
37
- async def get_paragraph_streaming_batch_from_response(
38
- response: aiohttp.ClientResponse,
39
- ) -> AsyncIterator[ParagraphStreamingBatch]:
40
- while True:
41
- header = await response.content.read(4)
42
- if header == b"":
43
- break
44
- payload_size = int.from_bytes(header, byteorder="big", signed=False)
45
- payload = await response.content.read(payload_size)
46
- pcb = ParagraphStreamingBatch()
47
- pcb.ParseFromString(payload)
48
- assert pcb.data
49
- yield pcb
50
-
51
-
52
- @pytest.mark.asyncio
53
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
54
- async def test_generator_paragraph_streaming(
55
- train_rest_api: aiohttp.ClientSession,
56
- nucliadb_grpc: WriterStub,
57
- knowledgebox: str,
58
- ):
59
- kbid = knowledgebox
60
-
61
- await inject_resources_with_paragraphs(kbid, nucliadb_grpc)
62
-
63
- async with train_rest_api.get(
64
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
65
- ) as partitions:
66
- assert partitions.status == 200
67
- data = await partitions.json()
68
- assert len(data["partitions"]) == 1
69
- partition_id = data["partitions"][0]
70
-
71
- trainset = TrainSet()
72
- trainset.type = TaskType.PARAGRAPH_STREAMING
73
- trainset.batch_size = 5
74
- trainset.filter.labels.append("labelset_paragraphs")
75
-
76
- async with train_rest_api.post(
77
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
78
- data=trainset.SerializeToString(),
79
- ) as response:
80
- assert response.status == 200
81
- batches = []
82
- async for batch in get_batches_from_train_response_stream(
83
- response, ParagraphStreamingBatch
84
- ):
85
- batches.append(batch)
86
- assert len(batch.data) == 5
87
- assert len(batches) == 1
88
-
89
-
90
- async def inject_resources_with_paragraphs(kbid: str, nucliadb_grpc: WriterStub):
91
- await inject_message(nucliadb_grpc, smb_wonder_bm(kbid))
92
- await asyncio.sleep(0.1)
93
-
94
-
95
- def smb_wonder_bm(kbid: str) -> BrokerMessage:
96
- bmb = BrokerMessageBuilder(kbid=kbid)
97
- bmb.with_title("Super Mario Bros. Wonder")
98
- bmb.with_summary("SMB Wonder: the new Mario game from Nintendo")
99
-
100
- field_builder = FieldBuilder("smb-wonder", rpb.FieldType.FILE)
101
- paragraphs = [
102
- "Super Mario Bros. Wonder (SMB Wonder) is a 2023 platform game developed and published by Nintendo.\n", # noqa
103
- "SMB Wonder is a side-scrolling plaftorm game.\n",
104
- "As one of eight player characters, the player completes levels across the Flower Kingdom.", # noqa
105
- ]
106
- field_builder.with_extracted_text("".join(paragraphs))
107
- start = 0
108
- for paragraph in paragraphs:
109
- end = start + len(paragraph)
110
- field_builder.with_extracted_paragraph_metadata(
111
- rpb.Paragraph(start=start, end=end)
112
- )
113
- start = end
114
- bmb.add_field_builder(field_builder)
115
-
116
- bm = bmb.build()
117
-
118
- return bm
@@ -1,239 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import asyncio
21
- import uuid
22
- from typing import AsyncIterator
23
-
24
- import aiohttp
25
- import pytest
26
- from nucliadb_protos.dataset_pb2 import QuestionAnswerStreamingBatch, TaskType, TrainSet
27
- from nucliadb_protos.writer_pb2 import BrokerMessage
28
- from nucliadb_protos.writer_pb2_grpc import WriterStub
29
-
30
- from nucliadb.ingest.orm.resource import FIELD_TYPE_TO_ID
31
- from nucliadb.tests.utils import inject_message
32
- from nucliadb.tests.utils.broker_messages import BrokerMessageBuilder, FieldBuilder
33
- from nucliadb.train import API_PREFIX
34
- from nucliadb.train.api.v1.router import KB_PREFIX
35
- from nucliadb.train.tests.utils import get_batches_from_train_response_stream
36
- from nucliadb_protos import resources_pb2 as rpb
37
-
38
-
39
- async def get_question_answer_streaming_batch_from_response(
40
- response: aiohttp.ClientResponse,
41
- ) -> AsyncIterator[QuestionAnswerStreamingBatch]:
42
- while True:
43
- header = await response.content.read(4)
44
- if header == b"":
45
- break
46
- payload_size = int.from_bytes(header, byteorder="big", signed=False)
47
- payload = await response.content.read(payload_size)
48
- pcb = QuestionAnswerStreamingBatch()
49
- pcb.ParseFromString(payload)
50
- assert pcb.data
51
- yield pcb
52
-
53
-
54
- @pytest.mark.asyncio
55
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
56
- async def test_generator_question_answer_streaming(
57
- train_rest_api: aiohttp.ClientSession,
58
- nucliadb_grpc: WriterStub,
59
- knowledgebox: str,
60
- ):
61
- kbid = knowledgebox
62
-
63
- await inject_resources_with_question_answers(kbid, nucliadb_grpc)
64
-
65
- async with train_rest_api.get(
66
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
67
- ) as partitions:
68
- assert partitions.status == 200
69
- data = await partitions.json()
70
- assert len(data["partitions"]) == 1
71
- partition_id = data["partitions"][0]
72
-
73
- trainset = TrainSet()
74
- trainset.type = TaskType.QUESTION_ANSWER_STREAMING
75
- trainset.batch_size = 5
76
- trainset.filter.labels.append("labelset_paragraphs")
77
-
78
- async with train_rest_api.post(
79
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
80
- data=trainset.SerializeToString(),
81
- ) as response:
82
- assert response.status == 200
83
- batches = []
84
- questions = []
85
- answers = []
86
- question_paragraphs_count = 0
87
- answer_paragraphs_count = 0
88
- async for batch in get_batches_from_train_response_stream(
89
- response, QuestionAnswerStreamingBatch
90
- ):
91
- batches.append(batch)
92
- assert len(batch.data) == 3
93
- for data in batch.data:
94
- questions.append(data.question.text)
95
- question_paragraphs_count += len(data.question.paragraphs)
96
- answers.append(data.answer.text)
97
- answer_paragraphs_count += len(data.answer.paragraphs)
98
- assert len(batches) == 1
99
- assert len(questions) == len(answers) == 3
100
- assert len(set(questions)) == 2
101
- assert question_paragraphs_count == 2
102
- assert answer_paragraphs_count == 4
103
-
104
-
105
- async def inject_resources_with_question_answers(kbid: str, nucliadb_grpc: WriterStub):
106
- await inject_message(nucliadb_grpc, smb_wonder_bm(kbid))
107
- await asyncio.sleep(0.1)
108
-
109
-
110
- def smb_wonder_bm(kbid: str) -> BrokerMessage:
111
- rid = str(uuid.uuid4())
112
- bmb = BrokerMessageBuilder(kbid=kbid, rid=rid)
113
- bmb.with_title("Super Mario Bros. Wonder")
114
- bmb.with_summary("SMB Wonder: the new Mario game from Nintendo")
115
-
116
- field_builder = FieldBuilder("smb-wonder", rpb.FieldType.FILE)
117
- paragraphs = [
118
- "Super Mario Bros. Wonder (SMB Wonder) is a 2023 platform game developed and published by Nintendo.\n", # noqa
119
- "SMB Wonder is a side-scrolling plaftorm game.\n",
120
- "As one of eight player characters, the player completes levels across the Flower Kingdom.", # noqa
121
- ]
122
- field_builder.with_extracted_text("".join(paragraphs))
123
- start = 0
124
- for paragraph in paragraphs:
125
- end = start + len(paragraph)
126
- field_builder.with_extracted_paragraph_metadata(
127
- rpb.Paragraph(start=start, end=end)
128
- )
129
- start = end
130
-
131
- start = 0
132
- end = len(paragraphs[0])
133
- paragraph_0_id = (
134
- f"{rid}/{FIELD_TYPE_TO_ID[rpb.FieldType.FILE]}/smb-wonder/{start}-{end}"
135
- )
136
-
137
- start = len(paragraphs[0])
138
- end = len(paragraphs[0]) + len(paragraphs[1])
139
- paragraph_1_id = (
140
- f"{rid}/{FIELD_TYPE_TO_ID[rpb.FieldType.FILE]}/smb-wonder/{start}-{end}"
141
- )
142
-
143
- question = "What is SMB Wonder?"
144
- field_builder.add_question_answer(
145
- question=question,
146
- question_paragraph_ids=[paragraph_0_id],
147
- answer="SMB Wonder is a side-scrolling Nintendo Switch game",
148
- answer_paragraph_ids=[paragraph_0_id, paragraph_1_id],
149
- )
150
- field_builder.add_question_answer(
151
- question=question,
152
- question_paragraph_ids=[paragraph_0_id],
153
- answer="It's the new Mario game for Nintendo Switch",
154
- answer_paragraph_ids=[paragraph_0_id],
155
- )
156
-
157
- question = "Give me an example of side-scrolling game"
158
- field_builder.add_question_answer(
159
- question=question,
160
- answer="SMB Wonder game",
161
- answer_paragraph_ids=[paragraph_1_id],
162
- )
163
-
164
- bmb.add_field_builder(field_builder)
165
-
166
- bm = bmb.build()
167
-
168
- return bm
169
-
170
-
171
- @pytest.mark.asyncio
172
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
173
- async def test_generator_question_answer_streaming_streams_qa_annotations(
174
- train_rest_api: aiohttp.ClientSession,
175
- writer_rest_api: aiohttp.ClientSession,
176
- knowledgebox: str,
177
- ):
178
- kbid = knowledgebox
179
-
180
- resp = await writer_rest_api.post(
181
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/resources",
182
- json={
183
- "title": "Super Mario Bros. Wonder",
184
- "texts": {
185
- "smb-wonder": {
186
- "body": "Super Mario Bros. Wonder (SMB Wonder) is a 2023 platform game developed and published by Nintendo.\n" # noqa
187
- },
188
- },
189
- "fieldmetadata": [
190
- {
191
- "field": {"field_type": "text", "field": "smb-wonder"},
192
- "question_answers": [
193
- {
194
- "cancelled_by_user": True,
195
- "question_answer": {
196
- "question": {
197
- "text": "What is SMB Wonder?",
198
- "ids_paragraphs": [],
199
- },
200
- "answers": [
201
- {
202
- "ids_paragraphs": [],
203
- "language": "english",
204
- "text": "SMB Wonder is a Nintendo Switch game",
205
- }
206
- ],
207
- },
208
- }
209
- ],
210
- }
211
- ],
212
- },
213
- )
214
- assert resp.status == 201, resp.text
215
-
216
- async with train_rest_api.get(
217
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
218
- ) as partitions:
219
- assert partitions.status == 200
220
- data = await partitions.json()
221
- assert len(data["partitions"]) == 1
222
- partition_id = data["partitions"][0]
223
-
224
- trainset = TrainSet()
225
- trainset.type = TaskType.QUESTION_ANSWER_STREAMING
226
- trainset.batch_size = 5
227
-
228
- async with train_rest_api.post(
229
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
230
- data=trainset.SerializeToString(),
231
- ) as response:
232
- assert response.status == 200
233
- batches = []
234
- async for batch in get_batches_from_train_response_stream(
235
- response, QuestionAnswerStreamingBatch
236
- ):
237
- batches.append(batch)
238
- assert len(batch.data) == 1
239
- assert len(batches) == 1
@@ -1,143 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import asyncio
21
- import uuid
22
-
23
- import aiohttp
24
- import pytest
25
- from nucliadb_protos.dataset_pb2 import SentenceClassificationBatch, TaskType, TrainSet
26
- from nucliadb_protos.writer_pb2 import BrokerMessage
27
- from nucliadb_protos.writer_pb2_grpc import WriterStub
28
-
29
- from nucliadb.tests.utils import inject_message
30
- from nucliadb.tests.utils.broker_messages import BrokerMessageBuilder, FieldBuilder
31
- from nucliadb.train import API_PREFIX
32
- from nucliadb.train.api.v1.router import KB_PREFIX
33
- from nucliadb.train.tests.utils import get_batches_from_train_response_stream
34
- from nucliadb_protos import resources_pb2 as rpb
35
-
36
-
37
- @pytest.mark.asyncio
38
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
39
- async def test_generator_sentence_classification(
40
- train_rest_api: aiohttp.ClientSession,
41
- nucliadb_grpc: WriterStub,
42
- knowledgebox_with_labels: str,
43
- ):
44
- kbid = knowledgebox_with_labels
45
-
46
- await inject_resource_with_sentence_classification(kbid, nucliadb_grpc)
47
-
48
- async with train_rest_api.get(
49
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
50
- ) as partitions:
51
- assert partitions.status == 200
52
- data = await partitions.json()
53
- assert len(data["partitions"]) == 1
54
- partition_id = data["partitions"][0]
55
-
56
- trainset = TrainSet()
57
- trainset.type = TaskType.SENTENCE_CLASSIFICATION
58
- trainset.batch_size = 2
59
- trainset.filter.labels.append("labelset_paragraphs")
60
-
61
- async with train_rest_api.post(
62
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
63
- data=trainset.SerializeToString(),
64
- ) as response:
65
- assert response.status == 200
66
- batches = []
67
- async for batch in get_batches_from_train_response_stream(
68
- response, SentenceClassificationBatch
69
- ):
70
- batches.append(batch)
71
- assert len(batch.data) == 2
72
- assert len(batches) == 2
73
-
74
-
75
- async def inject_resource_with_sentence_classification(knowledgebox, writer):
76
- bm = broker_resource(knowledgebox)
77
- await inject_message(writer, bm)
78
- await asyncio.sleep(0.1)
79
- return bm.uuid
80
-
81
-
82
- def broker_resource(knowledgebox: str) -> BrokerMessage:
83
- rid = str(uuid.uuid4())
84
- bmb = BrokerMessageBuilder(kbid=knowledgebox, rid=rid)
85
- bmb.with_title("Title Resource")
86
- bmb.with_summary("Summary of document")
87
- bmb.with_resource_labels("labelset_resources", ["label_user"])
88
-
89
- file_field = FieldBuilder("file", rpb.FieldType.FILE)
90
- file_field.with_extracted_text(
91
- "My own text Ramon. This is great to be here. \n Where is my beer? Do you want to go shooping? This is a test!" # noqa
92
- )
93
-
94
- labelset = "labelset_paragraphs"
95
- labels = ["label_user"]
96
- file_field.with_user_paragraph_labels(f"{rid}/f/file/0-45", labelset, labels)
97
- file_field.with_user_paragraph_labels(f"{rid}/f/file/47-64", labelset, labels)
98
- file_field.with_user_paragraph_labels(f"{rid}/f/file/65-93", labelset, labels)
99
- file_field.with_user_paragraph_labels(f"{rid}/f/file/93-109", labelset, labels)
100
-
101
- classification = rpb.Classification(
102
- labelset="labelset_paragraphs", label="label_machine"
103
- )
104
- file_field.with_extracted_paragraph_metadata(
105
- rpb.Paragraph(
106
- start=0,
107
- end=45,
108
- classifications=[classification],
109
- sentences=[rpb.Sentence(start=0, end=45)],
110
- )
111
- )
112
- file_field.with_extracted_paragraph_metadata(
113
- rpb.Paragraph(
114
- start=47,
115
- end=64,
116
- classifications=[classification],
117
- sentences=[rpb.Sentence(start=47, end=64)],
118
- )
119
- )
120
- file_field.with_extracted_paragraph_metadata(
121
- rpb.Paragraph(
122
- start=65,
123
- end=93,
124
- classifications=[classification],
125
- sentences=[rpb.Sentence(start=65, end=93)],
126
- )
127
- )
128
- file_field.with_extracted_paragraph_metadata(
129
- rpb.Paragraph(
130
- start=93,
131
- end=109,
132
- classifications=[classification],
133
- sentences=[rpb.Sentence(start=94, end=109)],
134
- )
135
- )
136
-
137
- file_field.with_extracted_labels("labelset_resources", ["label_machine"])
138
-
139
- bmb.add_field_builder(file_field)
140
-
141
- bm = bmb.build()
142
-
143
- return bm