nucliadb 4.0.0.post542__py3-none-any.whl → 6.2.1.post2777__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (418) hide show
  1. migrations/0003_allfields_key.py +1 -35
  2. migrations/0009_upgrade_relations_and_texts_to_v2.py +4 -2
  3. migrations/0010_fix_corrupt_indexes.py +10 -10
  4. migrations/0011_materialize_labelset_ids.py +1 -16
  5. migrations/0012_rollover_shards.py +5 -10
  6. migrations/0014_rollover_shards.py +4 -5
  7. migrations/0015_targeted_rollover.py +5 -10
  8. migrations/0016_upgrade_to_paragraphs_v2.py +25 -28
  9. migrations/0017_multiple_writable_shards.py +2 -4
  10. migrations/0018_purge_orphan_kbslugs.py +5 -7
  11. migrations/0019_upgrade_to_paragraphs_v3.py +25 -28
  12. migrations/0020_drain_nodes_from_cluster.py +3 -3
  13. nucliadb/standalone/tests/unit/test_run.py → migrations/0021_overwrite_vectorsets_key.py +16 -19
  14. nucliadb/tests/unit/test_openapi.py → migrations/0022_fix_paragraph_deletion_bug.py +16 -11
  15. migrations/0023_backfill_pg_catalog.py +80 -0
  16. migrations/0025_assign_models_to_kbs_v2.py +113 -0
  17. migrations/0026_fix_high_cardinality_content_types.py +61 -0
  18. migrations/0027_rollover_texts3.py +73 -0
  19. nucliadb/ingest/fields/date.py → migrations/pg/0001_bootstrap.py +10 -12
  20. migrations/pg/0002_catalog.py +42 -0
  21. nucliadb/ingest/tests/unit/test_settings.py → migrations/pg/0003_catalog_kbid_index.py +5 -3
  22. nucliadb/common/cluster/base.py +30 -16
  23. nucliadb/common/cluster/discovery/base.py +6 -14
  24. nucliadb/common/cluster/discovery/k8s.py +9 -19
  25. nucliadb/common/cluster/discovery/manual.py +1 -3
  26. nucliadb/common/cluster/discovery/utils.py +1 -3
  27. nucliadb/common/cluster/grpc_node_dummy.py +3 -11
  28. nucliadb/common/cluster/index_node.py +10 -19
  29. nucliadb/common/cluster/manager.py +174 -59
  30. nucliadb/common/cluster/rebalance.py +27 -29
  31. nucliadb/common/cluster/rollover.py +353 -194
  32. nucliadb/common/cluster/settings.py +6 -0
  33. nucliadb/common/cluster/standalone/grpc_node_binding.py +13 -64
  34. nucliadb/common/cluster/standalone/index_node.py +4 -11
  35. nucliadb/common/cluster/standalone/service.py +2 -6
  36. nucliadb/common/cluster/standalone/utils.py +2 -6
  37. nucliadb/common/cluster/utils.py +29 -22
  38. nucliadb/common/constants.py +20 -0
  39. nucliadb/common/context/__init__.py +3 -0
  40. nucliadb/common/context/fastapi.py +8 -5
  41. nucliadb/{tests/knowledgeboxes/__init__.py → common/counters.py} +8 -2
  42. nucliadb/common/datamanagers/__init__.py +7 -1
  43. nucliadb/common/datamanagers/atomic.py +22 -4
  44. nucliadb/common/datamanagers/cluster.py +5 -5
  45. nucliadb/common/datamanagers/entities.py +6 -16
  46. nucliadb/common/datamanagers/fields.py +84 -0
  47. nucliadb/common/datamanagers/kb.py +83 -37
  48. nucliadb/common/datamanagers/labels.py +26 -56
  49. nucliadb/common/datamanagers/processing.py +2 -6
  50. nucliadb/common/datamanagers/resources.py +41 -103
  51. nucliadb/common/datamanagers/rollover.py +76 -15
  52. nucliadb/common/datamanagers/synonyms.py +1 -1
  53. nucliadb/common/datamanagers/utils.py +15 -6
  54. nucliadb/common/datamanagers/vectorsets.py +110 -0
  55. nucliadb/common/external_index_providers/base.py +257 -0
  56. nucliadb/{ingest/tests/unit/orm/test_orm_utils.py → common/external_index_providers/exceptions.py} +9 -8
  57. nucliadb/common/external_index_providers/manager.py +101 -0
  58. nucliadb/common/external_index_providers/pinecone.py +933 -0
  59. nucliadb/common/external_index_providers/settings.py +52 -0
  60. nucliadb/common/http_clients/auth.py +3 -6
  61. nucliadb/common/http_clients/processing.py +6 -11
  62. nucliadb/common/http_clients/utils.py +1 -3
  63. nucliadb/common/ids.py +240 -0
  64. nucliadb/common/locking.py +29 -7
  65. nucliadb/common/maindb/driver.py +11 -35
  66. nucliadb/common/maindb/exceptions.py +3 -0
  67. nucliadb/common/maindb/local.py +22 -9
  68. nucliadb/common/maindb/pg.py +206 -111
  69. nucliadb/common/maindb/utils.py +11 -42
  70. nucliadb/common/models_utils/from_proto.py +479 -0
  71. nucliadb/common/models_utils/to_proto.py +60 -0
  72. nucliadb/common/nidx.py +260 -0
  73. nucliadb/export_import/datamanager.py +25 -19
  74. nucliadb/export_import/exporter.py +5 -11
  75. nucliadb/export_import/importer.py +5 -7
  76. nucliadb/export_import/models.py +3 -3
  77. nucliadb/export_import/tasks.py +4 -4
  78. nucliadb/export_import/utils.py +25 -37
  79. nucliadb/health.py +1 -3
  80. nucliadb/ingest/app.py +15 -11
  81. nucliadb/ingest/consumer/auditing.py +21 -19
  82. nucliadb/ingest/consumer/consumer.py +82 -47
  83. nucliadb/ingest/consumer/materializer.py +5 -12
  84. nucliadb/ingest/consumer/pull.py +12 -27
  85. nucliadb/ingest/consumer/service.py +19 -17
  86. nucliadb/ingest/consumer/shard_creator.py +2 -4
  87. nucliadb/ingest/consumer/utils.py +1 -3
  88. nucliadb/ingest/fields/base.py +137 -105
  89. nucliadb/ingest/fields/conversation.py +18 -5
  90. nucliadb/ingest/fields/exceptions.py +1 -4
  91. nucliadb/ingest/fields/file.py +7 -16
  92. nucliadb/ingest/fields/link.py +5 -10
  93. nucliadb/ingest/fields/text.py +9 -4
  94. nucliadb/ingest/orm/brain.py +200 -213
  95. nucliadb/ingest/orm/broker_message.py +181 -0
  96. nucliadb/ingest/orm/entities.py +36 -51
  97. nucliadb/ingest/orm/exceptions.py +12 -0
  98. nucliadb/ingest/orm/knowledgebox.py +322 -197
  99. nucliadb/ingest/orm/processor/__init__.py +2 -700
  100. nucliadb/ingest/orm/processor/auditing.py +4 -23
  101. nucliadb/ingest/orm/processor/data_augmentation.py +164 -0
  102. nucliadb/ingest/orm/processor/pgcatalog.py +84 -0
  103. nucliadb/ingest/orm/processor/processor.py +752 -0
  104. nucliadb/ingest/orm/processor/sequence_manager.py +1 -1
  105. nucliadb/ingest/orm/resource.py +249 -402
  106. nucliadb/ingest/orm/utils.py +4 -4
  107. nucliadb/ingest/partitions.py +3 -9
  108. nucliadb/ingest/processing.py +64 -73
  109. nucliadb/ingest/py.typed +0 -0
  110. nucliadb/ingest/serialize.py +37 -167
  111. nucliadb/ingest/service/__init__.py +1 -3
  112. nucliadb/ingest/service/writer.py +185 -412
  113. nucliadb/ingest/settings.py +10 -20
  114. nucliadb/ingest/utils.py +3 -6
  115. nucliadb/learning_proxy.py +242 -55
  116. nucliadb/metrics_exporter.py +30 -19
  117. nucliadb/middleware/__init__.py +1 -3
  118. nucliadb/migrator/command.py +1 -3
  119. nucliadb/migrator/datamanager.py +13 -13
  120. nucliadb/migrator/migrator.py +47 -30
  121. nucliadb/migrator/utils.py +18 -10
  122. nucliadb/purge/__init__.py +139 -33
  123. nucliadb/purge/orphan_shards.py +7 -13
  124. nucliadb/reader/__init__.py +1 -3
  125. nucliadb/reader/api/models.py +1 -12
  126. nucliadb/reader/api/v1/__init__.py +0 -1
  127. nucliadb/reader/api/v1/download.py +21 -88
  128. nucliadb/reader/api/v1/export_import.py +1 -1
  129. nucliadb/reader/api/v1/knowledgebox.py +10 -10
  130. nucliadb/reader/api/v1/learning_config.py +2 -6
  131. nucliadb/reader/api/v1/resource.py +62 -88
  132. nucliadb/reader/api/v1/services.py +64 -83
  133. nucliadb/reader/app.py +12 -29
  134. nucliadb/reader/lifecycle.py +18 -4
  135. nucliadb/reader/py.typed +0 -0
  136. nucliadb/reader/reader/notifications.py +10 -28
  137. nucliadb/search/__init__.py +1 -3
  138. nucliadb/search/api/v1/__init__.py +1 -2
  139. nucliadb/search/api/v1/ask.py +17 -10
  140. nucliadb/search/api/v1/catalog.py +184 -0
  141. nucliadb/search/api/v1/feedback.py +16 -24
  142. nucliadb/search/api/v1/find.py +36 -36
  143. nucliadb/search/api/v1/knowledgebox.py +89 -60
  144. nucliadb/search/api/v1/resource/ask.py +2 -8
  145. nucliadb/search/api/v1/resource/search.py +49 -70
  146. nucliadb/search/api/v1/search.py +44 -210
  147. nucliadb/search/api/v1/suggest.py +39 -54
  148. nucliadb/search/app.py +12 -32
  149. nucliadb/search/lifecycle.py +10 -3
  150. nucliadb/search/predict.py +136 -187
  151. nucliadb/search/py.typed +0 -0
  152. nucliadb/search/requesters/utils.py +25 -58
  153. nucliadb/search/search/cache.py +149 -20
  154. nucliadb/search/search/chat/ask.py +571 -123
  155. nucliadb/search/{tests/unit/test_run.py → search/chat/exceptions.py} +14 -14
  156. nucliadb/search/search/chat/images.py +41 -17
  157. nucliadb/search/search/chat/prompt.py +817 -266
  158. nucliadb/search/search/chat/query.py +213 -309
  159. nucliadb/{tests/migrations/__init__.py → search/search/cut.py} +8 -8
  160. nucliadb/search/search/fetch.py +43 -36
  161. nucliadb/search/search/filters.py +9 -15
  162. nucliadb/search/search/find.py +214 -53
  163. nucliadb/search/search/find_merge.py +408 -391
  164. nucliadb/search/search/hydrator.py +191 -0
  165. nucliadb/search/search/merge.py +187 -223
  166. nucliadb/search/search/metrics.py +73 -2
  167. nucliadb/search/search/paragraphs.py +64 -106
  168. nucliadb/search/search/pgcatalog.py +233 -0
  169. nucliadb/search/search/predict_proxy.py +1 -1
  170. nucliadb/search/search/query.py +305 -150
  171. nucliadb/search/search/query_parser/exceptions.py +22 -0
  172. nucliadb/search/search/query_parser/models.py +101 -0
  173. nucliadb/search/search/query_parser/parser.py +183 -0
  174. nucliadb/search/search/rank_fusion.py +204 -0
  175. nucliadb/search/search/rerankers.py +270 -0
  176. nucliadb/search/search/shards.py +3 -32
  177. nucliadb/search/search/summarize.py +7 -18
  178. nucliadb/search/search/utils.py +27 -4
  179. nucliadb/search/settings.py +15 -1
  180. nucliadb/standalone/api_router.py +4 -10
  181. nucliadb/standalone/app.py +8 -14
  182. nucliadb/standalone/auth.py +7 -21
  183. nucliadb/standalone/config.py +7 -10
  184. nucliadb/standalone/lifecycle.py +26 -25
  185. nucliadb/standalone/migrations.py +1 -3
  186. nucliadb/standalone/purge.py +1 -1
  187. nucliadb/standalone/py.typed +0 -0
  188. nucliadb/standalone/run.py +3 -6
  189. nucliadb/standalone/settings.py +9 -16
  190. nucliadb/standalone/versions.py +15 -5
  191. nucliadb/tasks/consumer.py +8 -12
  192. nucliadb/tasks/producer.py +7 -6
  193. nucliadb/tests/config.py +53 -0
  194. nucliadb/train/__init__.py +1 -3
  195. nucliadb/train/api/utils.py +1 -2
  196. nucliadb/train/api/v1/shards.py +1 -1
  197. nucliadb/train/api/v1/trainset.py +2 -4
  198. nucliadb/train/app.py +10 -31
  199. nucliadb/train/generator.py +10 -19
  200. nucliadb/train/generators/field_classifier.py +7 -19
  201. nucliadb/train/generators/field_streaming.py +156 -0
  202. nucliadb/train/generators/image_classifier.py +12 -18
  203. nucliadb/train/generators/paragraph_classifier.py +5 -9
  204. nucliadb/train/generators/paragraph_streaming.py +6 -9
  205. nucliadb/train/generators/question_answer_streaming.py +19 -20
  206. nucliadb/train/generators/sentence_classifier.py +9 -15
  207. nucliadb/train/generators/token_classifier.py +48 -39
  208. nucliadb/train/generators/utils.py +14 -18
  209. nucliadb/train/lifecycle.py +7 -3
  210. nucliadb/train/nodes.py +23 -32
  211. nucliadb/train/py.typed +0 -0
  212. nucliadb/train/servicer.py +13 -21
  213. nucliadb/train/settings.py +2 -6
  214. nucliadb/train/types.py +13 -10
  215. nucliadb/train/upload.py +3 -6
  216. nucliadb/train/uploader.py +19 -23
  217. nucliadb/train/utils.py +1 -1
  218. nucliadb/writer/__init__.py +1 -3
  219. nucliadb/{ingest/fields/keywordset.py → writer/api/utils.py} +13 -10
  220. nucliadb/writer/api/v1/export_import.py +67 -14
  221. nucliadb/writer/api/v1/field.py +16 -269
  222. nucliadb/writer/api/v1/knowledgebox.py +218 -68
  223. nucliadb/writer/api/v1/resource.py +68 -88
  224. nucliadb/writer/api/v1/services.py +51 -70
  225. nucliadb/writer/api/v1/slug.py +61 -0
  226. nucliadb/writer/api/v1/transaction.py +67 -0
  227. nucliadb/writer/api/v1/upload.py +114 -113
  228. nucliadb/writer/app.py +6 -43
  229. nucliadb/writer/back_pressure.py +16 -38
  230. nucliadb/writer/exceptions.py +0 -4
  231. nucliadb/writer/lifecycle.py +21 -15
  232. nucliadb/writer/py.typed +0 -0
  233. nucliadb/writer/resource/audit.py +2 -1
  234. nucliadb/writer/resource/basic.py +48 -46
  235. nucliadb/writer/resource/field.py +25 -127
  236. nucliadb/writer/resource/origin.py +1 -2
  237. nucliadb/writer/settings.py +6 -2
  238. nucliadb/writer/tus/__init__.py +17 -15
  239. nucliadb/writer/tus/azure.py +111 -0
  240. nucliadb/writer/tus/dm.py +17 -5
  241. nucliadb/writer/tus/exceptions.py +1 -3
  242. nucliadb/writer/tus/gcs.py +49 -84
  243. nucliadb/writer/tus/local.py +21 -37
  244. nucliadb/writer/tus/s3.py +28 -68
  245. nucliadb/writer/tus/storage.py +5 -56
  246. nucliadb/writer/vectorsets.py +125 -0
  247. nucliadb-6.2.1.post2777.dist-info/METADATA +148 -0
  248. nucliadb-6.2.1.post2777.dist-info/RECORD +343 -0
  249. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/WHEEL +1 -1
  250. nucliadb/common/maindb/redis.py +0 -194
  251. nucliadb/common/maindb/tikv.py +0 -433
  252. nucliadb/ingest/fields/layout.py +0 -58
  253. nucliadb/ingest/tests/conftest.py +0 -30
  254. nucliadb/ingest/tests/fixtures.py +0 -764
  255. nucliadb/ingest/tests/integration/consumer/__init__.py +0 -18
  256. nucliadb/ingest/tests/integration/consumer/test_auditing.py +0 -78
  257. nucliadb/ingest/tests/integration/consumer/test_materializer.py +0 -126
  258. nucliadb/ingest/tests/integration/consumer/test_pull.py +0 -144
  259. nucliadb/ingest/tests/integration/consumer/test_service.py +0 -81
  260. nucliadb/ingest/tests/integration/consumer/test_shard_creator.py +0 -68
  261. nucliadb/ingest/tests/integration/ingest/test_ingest.py +0 -684
  262. nucliadb/ingest/tests/integration/ingest/test_processing_engine.py +0 -95
  263. nucliadb/ingest/tests/integration/ingest/test_relations.py +0 -272
  264. nucliadb/ingest/tests/unit/consumer/__init__.py +0 -18
  265. nucliadb/ingest/tests/unit/consumer/test_auditing.py +0 -139
  266. nucliadb/ingest/tests/unit/consumer/test_consumer.py +0 -69
  267. nucliadb/ingest/tests/unit/consumer/test_pull.py +0 -60
  268. nucliadb/ingest/tests/unit/consumer/test_shard_creator.py +0 -140
  269. nucliadb/ingest/tests/unit/consumer/test_utils.py +0 -67
  270. nucliadb/ingest/tests/unit/orm/__init__.py +0 -19
  271. nucliadb/ingest/tests/unit/orm/test_brain.py +0 -247
  272. nucliadb/ingest/tests/unit/orm/test_brain_vectors.py +0 -74
  273. nucliadb/ingest/tests/unit/orm/test_processor.py +0 -131
  274. nucliadb/ingest/tests/unit/orm/test_resource.py +0 -331
  275. nucliadb/ingest/tests/unit/test_cache.py +0 -31
  276. nucliadb/ingest/tests/unit/test_partitions.py +0 -40
  277. nucliadb/ingest/tests/unit/test_processing.py +0 -171
  278. nucliadb/middleware/transaction.py +0 -117
  279. nucliadb/reader/api/v1/learning_collector.py +0 -63
  280. nucliadb/reader/tests/__init__.py +0 -19
  281. nucliadb/reader/tests/conftest.py +0 -31
  282. nucliadb/reader/tests/fixtures.py +0 -136
  283. nucliadb/reader/tests/test_list_resources.py +0 -75
  284. nucliadb/reader/tests/test_reader_file_download.py +0 -273
  285. nucliadb/reader/tests/test_reader_resource.py +0 -353
  286. nucliadb/reader/tests/test_reader_resource_field.py +0 -219
  287. nucliadb/search/api/v1/chat.py +0 -263
  288. nucliadb/search/api/v1/resource/chat.py +0 -174
  289. nucliadb/search/tests/__init__.py +0 -19
  290. nucliadb/search/tests/conftest.py +0 -33
  291. nucliadb/search/tests/fixtures.py +0 -199
  292. nucliadb/search/tests/node.py +0 -466
  293. nucliadb/search/tests/unit/__init__.py +0 -18
  294. nucliadb/search/tests/unit/api/__init__.py +0 -19
  295. nucliadb/search/tests/unit/api/v1/__init__.py +0 -19
  296. nucliadb/search/tests/unit/api/v1/resource/__init__.py +0 -19
  297. nucliadb/search/tests/unit/api/v1/resource/test_chat.py +0 -98
  298. nucliadb/search/tests/unit/api/v1/test_ask.py +0 -120
  299. nucliadb/search/tests/unit/api/v1/test_chat.py +0 -96
  300. nucliadb/search/tests/unit/api/v1/test_predict_proxy.py +0 -98
  301. nucliadb/search/tests/unit/api/v1/test_summarize.py +0 -99
  302. nucliadb/search/tests/unit/search/__init__.py +0 -18
  303. nucliadb/search/tests/unit/search/requesters/__init__.py +0 -18
  304. nucliadb/search/tests/unit/search/requesters/test_utils.py +0 -211
  305. nucliadb/search/tests/unit/search/search/__init__.py +0 -19
  306. nucliadb/search/tests/unit/search/search/test_shards.py +0 -45
  307. nucliadb/search/tests/unit/search/search/test_utils.py +0 -82
  308. nucliadb/search/tests/unit/search/test_chat_prompt.py +0 -270
  309. nucliadb/search/tests/unit/search/test_fetch.py +0 -108
  310. nucliadb/search/tests/unit/search/test_filters.py +0 -125
  311. nucliadb/search/tests/unit/search/test_paragraphs.py +0 -157
  312. nucliadb/search/tests/unit/search/test_predict_proxy.py +0 -106
  313. nucliadb/search/tests/unit/search/test_query.py +0 -153
  314. nucliadb/search/tests/unit/test_app.py +0 -79
  315. nucliadb/search/tests/unit/test_find_merge.py +0 -112
  316. nucliadb/search/tests/unit/test_merge.py +0 -34
  317. nucliadb/search/tests/unit/test_predict.py +0 -525
  318. nucliadb/standalone/tests/__init__.py +0 -19
  319. nucliadb/standalone/tests/conftest.py +0 -33
  320. nucliadb/standalone/tests/fixtures.py +0 -38
  321. nucliadb/standalone/tests/unit/__init__.py +0 -18
  322. nucliadb/standalone/tests/unit/test_api_router.py +0 -61
  323. nucliadb/standalone/tests/unit/test_auth.py +0 -169
  324. nucliadb/standalone/tests/unit/test_introspect.py +0 -35
  325. nucliadb/standalone/tests/unit/test_migrations.py +0 -63
  326. nucliadb/standalone/tests/unit/test_versions.py +0 -68
  327. nucliadb/tests/benchmarks/__init__.py +0 -19
  328. nucliadb/tests/benchmarks/test_search.py +0 -99
  329. nucliadb/tests/conftest.py +0 -32
  330. nucliadb/tests/fixtures.py +0 -735
  331. nucliadb/tests/knowledgeboxes/philosophy_books.py +0 -202
  332. nucliadb/tests/knowledgeboxes/ten_dummy_resources.py +0 -107
  333. nucliadb/tests/migrations/test_migration_0017.py +0 -76
  334. nucliadb/tests/migrations/test_migration_0018.py +0 -95
  335. nucliadb/tests/tikv.py +0 -240
  336. nucliadb/tests/unit/__init__.py +0 -19
  337. nucliadb/tests/unit/common/__init__.py +0 -19
  338. nucliadb/tests/unit/common/cluster/__init__.py +0 -19
  339. nucliadb/tests/unit/common/cluster/discovery/__init__.py +0 -19
  340. nucliadb/tests/unit/common/cluster/discovery/test_k8s.py +0 -172
  341. nucliadb/tests/unit/common/cluster/standalone/__init__.py +0 -18
  342. nucliadb/tests/unit/common/cluster/standalone/test_service.py +0 -114
  343. nucliadb/tests/unit/common/cluster/standalone/test_utils.py +0 -61
  344. nucliadb/tests/unit/common/cluster/test_cluster.py +0 -408
  345. nucliadb/tests/unit/common/cluster/test_kb_shard_manager.py +0 -173
  346. nucliadb/tests/unit/common/cluster/test_rebalance.py +0 -38
  347. nucliadb/tests/unit/common/cluster/test_rollover.py +0 -282
  348. nucliadb/tests/unit/common/maindb/__init__.py +0 -18
  349. nucliadb/tests/unit/common/maindb/test_driver.py +0 -127
  350. nucliadb/tests/unit/common/maindb/test_tikv.py +0 -53
  351. nucliadb/tests/unit/common/maindb/test_utils.py +0 -92
  352. nucliadb/tests/unit/common/test_context.py +0 -36
  353. nucliadb/tests/unit/export_import/__init__.py +0 -19
  354. nucliadb/tests/unit/export_import/test_datamanager.py +0 -37
  355. nucliadb/tests/unit/export_import/test_utils.py +0 -301
  356. nucliadb/tests/unit/migrator/__init__.py +0 -19
  357. nucliadb/tests/unit/migrator/test_migrator.py +0 -87
  358. nucliadb/tests/unit/tasks/__init__.py +0 -19
  359. nucliadb/tests/unit/tasks/conftest.py +0 -42
  360. nucliadb/tests/unit/tasks/test_consumer.py +0 -92
  361. nucliadb/tests/unit/tasks/test_producer.py +0 -95
  362. nucliadb/tests/unit/tasks/test_tasks.py +0 -58
  363. nucliadb/tests/unit/test_field_ids.py +0 -49
  364. nucliadb/tests/unit/test_health.py +0 -86
  365. nucliadb/tests/unit/test_kb_slugs.py +0 -54
  366. nucliadb/tests/unit/test_learning_proxy.py +0 -252
  367. nucliadb/tests/unit/test_metrics_exporter.py +0 -77
  368. nucliadb/tests/unit/test_purge.py +0 -136
  369. nucliadb/tests/utils/__init__.py +0 -74
  370. nucliadb/tests/utils/aiohttp_session.py +0 -44
  371. nucliadb/tests/utils/broker_messages/__init__.py +0 -171
  372. nucliadb/tests/utils/broker_messages/fields.py +0 -197
  373. nucliadb/tests/utils/broker_messages/helpers.py +0 -33
  374. nucliadb/tests/utils/entities.py +0 -78
  375. nucliadb/train/api/v1/check.py +0 -60
  376. nucliadb/train/tests/__init__.py +0 -19
  377. nucliadb/train/tests/conftest.py +0 -29
  378. nucliadb/train/tests/fixtures.py +0 -342
  379. nucliadb/train/tests/test_field_classification.py +0 -122
  380. nucliadb/train/tests/test_get_entities.py +0 -80
  381. nucliadb/train/tests/test_get_info.py +0 -51
  382. nucliadb/train/tests/test_get_ontology.py +0 -34
  383. nucliadb/train/tests/test_get_ontology_count.py +0 -63
  384. nucliadb/train/tests/test_image_classification.py +0 -221
  385. nucliadb/train/tests/test_list_fields.py +0 -39
  386. nucliadb/train/tests/test_list_paragraphs.py +0 -73
  387. nucliadb/train/tests/test_list_resources.py +0 -39
  388. nucliadb/train/tests/test_list_sentences.py +0 -71
  389. nucliadb/train/tests/test_paragraph_classification.py +0 -123
  390. nucliadb/train/tests/test_paragraph_streaming.py +0 -118
  391. nucliadb/train/tests/test_question_answer_streaming.py +0 -239
  392. nucliadb/train/tests/test_sentence_classification.py +0 -143
  393. nucliadb/train/tests/test_token_classification.py +0 -136
  394. nucliadb/train/tests/utils.py +0 -101
  395. nucliadb/writer/layouts/__init__.py +0 -51
  396. nucliadb/writer/layouts/v1.py +0 -59
  397. nucliadb/writer/tests/__init__.py +0 -19
  398. nucliadb/writer/tests/conftest.py +0 -31
  399. nucliadb/writer/tests/fixtures.py +0 -191
  400. nucliadb/writer/tests/test_fields.py +0 -475
  401. nucliadb/writer/tests/test_files.py +0 -740
  402. nucliadb/writer/tests/test_knowledgebox.py +0 -49
  403. nucliadb/writer/tests/test_reprocess_file_field.py +0 -133
  404. nucliadb/writer/tests/test_resources.py +0 -476
  405. nucliadb/writer/tests/test_service.py +0 -137
  406. nucliadb/writer/tests/test_tus.py +0 -203
  407. nucliadb/writer/tests/utils.py +0 -35
  408. nucliadb/writer/tus/pg.py +0 -125
  409. nucliadb-4.0.0.post542.dist-info/METADATA +0 -135
  410. nucliadb-4.0.0.post542.dist-info/RECORD +0 -462
  411. {nucliadb/ingest/tests → migrations/pg}/__init__.py +0 -0
  412. /nucliadb/{ingest/tests/integration → common/external_index_providers}/__init__.py +0 -0
  413. /nucliadb/{ingest/tests/integration/ingest → common/models_utils}/__init__.py +0 -0
  414. /nucliadb/{ingest/tests/unit → search/search/query_parser}/__init__.py +0 -0
  415. /nucliadb/{ingest/tests → tests}/vectors.py +0 -0
  416. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/entry_points.txt +0 -0
  417. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/top_level.txt +0 -0
  418. {nucliadb-4.0.0.post542.dist-info → nucliadb-6.2.1.post2777.dist-info}/zip-safe +0 -0
@@ -1,123 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import asyncio
21
- import uuid
22
-
23
- import aiohttp
24
- import pytest
25
- from nucliadb_protos.dataset_pb2 import ParagraphClassificationBatch, TaskType, TrainSet
26
- from nucliadb_protos.writer_pb2 import BrokerMessage
27
- from nucliadb_protos.writer_pb2_grpc import WriterStub
28
-
29
- from nucliadb.tests.utils import inject_message
30
- from nucliadb.tests.utils.broker_messages import BrokerMessageBuilder, FieldBuilder
31
- from nucliadb.train import API_PREFIX
32
- from nucliadb.train.api.v1.router import KB_PREFIX
33
- from nucliadb.train.tests.utils import get_batches_from_train_response_stream
34
- from nucliadb_protos import resources_pb2 as rpb
35
-
36
-
37
- @pytest.mark.asyncio
38
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
39
- async def test_generator_paragraph_classification(
40
- train_rest_api: aiohttp.ClientSession,
41
- nucliadb_grpc: WriterStub,
42
- knowledgebox_with_labels: str,
43
- ):
44
- kbid = knowledgebox_with_labels
45
-
46
- await inject_resource_with_paragraph_classification(kbid, nucliadb_grpc)
47
-
48
- async with train_rest_api.get(
49
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
50
- ) as partitions:
51
- assert partitions.status == 200
52
- data = await partitions.json()
53
- assert len(data["partitions"]) == 1
54
- partition_id = data["partitions"][0]
55
-
56
- trainset = TrainSet()
57
- trainset.type = TaskType.PARAGRAPH_CLASSIFICATION
58
- trainset.batch_size = 2
59
- trainset.filter.labels.append("labelset_paragraphs")
60
-
61
- async with train_rest_api.post(
62
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
63
- data=trainset.SerializeToString(),
64
- ) as response:
65
- assert response.status == 200
66
- batches = []
67
- async for batch in get_batches_from_train_response_stream(
68
- response, ParagraphClassificationBatch
69
- ):
70
- batches.append(batch)
71
- assert len(batch.data) == 2
72
- assert len(batches) == 2
73
-
74
-
75
- async def inject_resource_with_paragraph_classification(knowledgebox, writer):
76
- bm = broker_resource(knowledgebox)
77
- await inject_message(writer, bm)
78
- await asyncio.sleep(0.1)
79
- return bm.uuid
80
-
81
-
82
- def broker_resource(knowledgebox: str) -> BrokerMessage:
83
- rid = str(uuid.uuid4())
84
- bmb = BrokerMessageBuilder(kbid=knowledgebox, rid=rid)
85
- bmb.with_title("Title Resource")
86
- bmb.with_summary("Summary of document")
87
- bmb.with_resource_labels("labelset_resources", ["label_user"])
88
-
89
- file_field = FieldBuilder("file", rpb.FieldType.FILE)
90
- file_field.with_extracted_text(
91
- "My own text Ramon. This is great to be here. \n Where is my beer? Do you want to go shooping? This is a test!" # noqa
92
- )
93
-
94
- labelset = "labelset_paragraphs"
95
- labels = ["label_user"]
96
- file_field.with_user_paragraph_labels(f"{rid}/f/file/0-45", labelset, labels)
97
- file_field.with_user_paragraph_labels(f"{rid}/f/file/47-64", labelset, labels)
98
- file_field.with_user_paragraph_labels(f"{rid}/f/file/65-93", labelset, labels)
99
- file_field.with_user_paragraph_labels(f"{rid}/f/file/93-109", labelset, labels)
100
-
101
- classification = rpb.Classification(
102
- labelset="labelset_paragraphs", label="label_machine"
103
- )
104
- file_field.with_extracted_paragraph_metadata(
105
- rpb.Paragraph(start=0, end=45, classifications=[classification])
106
- )
107
- file_field.with_extracted_paragraph_metadata(
108
- rpb.Paragraph(start=47, end=64, classifications=[classification])
109
- )
110
- file_field.with_extracted_paragraph_metadata(
111
- rpb.Paragraph(start=65, end=93, classifications=[classification])
112
- )
113
- file_field.with_extracted_paragraph_metadata(
114
- rpb.Paragraph(start=93, end=109, classifications=[classification])
115
- )
116
-
117
- file_field.with_extracted_labels("labelset_resources", ["label_machine"])
118
-
119
- bmb.add_field_builder(file_field)
120
-
121
- bm = bmb.build()
122
-
123
- return bm
@@ -1,118 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import asyncio
21
- from typing import AsyncIterator
22
-
23
- import aiohttp
24
- import pytest
25
- from nucliadb_protos.dataset_pb2 import ParagraphStreamingBatch, TaskType, TrainSet
26
- from nucliadb_protos.writer_pb2 import BrokerMessage
27
- from nucliadb_protos.writer_pb2_grpc import WriterStub
28
-
29
- from nucliadb.tests.utils import inject_message
30
- from nucliadb.tests.utils.broker_messages import BrokerMessageBuilder, FieldBuilder
31
- from nucliadb.train import API_PREFIX
32
- from nucliadb.train.api.v1.router import KB_PREFIX
33
- from nucliadb.train.tests.utils import get_batches_from_train_response_stream
34
- from nucliadb_protos import resources_pb2 as rpb
35
-
36
-
37
- async def get_paragraph_streaming_batch_from_response(
38
- response: aiohttp.ClientResponse,
39
- ) -> AsyncIterator[ParagraphStreamingBatch]:
40
- while True:
41
- header = await response.content.read(4)
42
- if header == b"":
43
- break
44
- payload_size = int.from_bytes(header, byteorder="big", signed=False)
45
- payload = await response.content.read(payload_size)
46
- pcb = ParagraphStreamingBatch()
47
- pcb.ParseFromString(payload)
48
- assert pcb.data
49
- yield pcb
50
-
51
-
52
- @pytest.mark.asyncio
53
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
54
- async def test_generator_paragraph_streaming(
55
- train_rest_api: aiohttp.ClientSession,
56
- nucliadb_grpc: WriterStub,
57
- knowledgebox: str,
58
- ):
59
- kbid = knowledgebox
60
-
61
- await inject_resources_with_paragraphs(kbid, nucliadb_grpc)
62
-
63
- async with train_rest_api.get(
64
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
65
- ) as partitions:
66
- assert partitions.status == 200
67
- data = await partitions.json()
68
- assert len(data["partitions"]) == 1
69
- partition_id = data["partitions"][0]
70
-
71
- trainset = TrainSet()
72
- trainset.type = TaskType.PARAGRAPH_STREAMING
73
- trainset.batch_size = 5
74
- trainset.filter.labels.append("labelset_paragraphs")
75
-
76
- async with train_rest_api.post(
77
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
78
- data=trainset.SerializeToString(),
79
- ) as response:
80
- assert response.status == 200
81
- batches = []
82
- async for batch in get_batches_from_train_response_stream(
83
- response, ParagraphStreamingBatch
84
- ):
85
- batches.append(batch)
86
- assert len(batch.data) == 5
87
- assert len(batches) == 1
88
-
89
-
90
- async def inject_resources_with_paragraphs(kbid: str, nucliadb_grpc: WriterStub):
91
- await inject_message(nucliadb_grpc, smb_wonder_bm(kbid))
92
- await asyncio.sleep(0.1)
93
-
94
-
95
- def smb_wonder_bm(kbid: str) -> BrokerMessage:
96
- bmb = BrokerMessageBuilder(kbid=kbid)
97
- bmb.with_title("Super Mario Bros. Wonder")
98
- bmb.with_summary("SMB Wonder: the new Mario game from Nintendo")
99
-
100
- field_builder = FieldBuilder("smb-wonder", rpb.FieldType.FILE)
101
- paragraphs = [
102
- "Super Mario Bros. Wonder (SMB Wonder) is a 2023 platform game developed and published by Nintendo.\n", # noqa
103
- "SMB Wonder is a side-scrolling plaftorm game.\n",
104
- "As one of eight player characters, the player completes levels across the Flower Kingdom.", # noqa
105
- ]
106
- field_builder.with_extracted_text("".join(paragraphs))
107
- start = 0
108
- for paragraph in paragraphs:
109
- end = start + len(paragraph)
110
- field_builder.with_extracted_paragraph_metadata(
111
- rpb.Paragraph(start=start, end=end)
112
- )
113
- start = end
114
- bmb.add_field_builder(field_builder)
115
-
116
- bm = bmb.build()
117
-
118
- return bm
@@ -1,239 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import asyncio
21
- import uuid
22
- from typing import AsyncIterator
23
-
24
- import aiohttp
25
- import pytest
26
- from nucliadb_protos.dataset_pb2 import QuestionAnswerStreamingBatch, TaskType, TrainSet
27
- from nucliadb_protos.writer_pb2 import BrokerMessage
28
- from nucliadb_protos.writer_pb2_grpc import WriterStub
29
-
30
- from nucliadb.ingest.orm.resource import FIELD_TYPE_TO_ID
31
- from nucliadb.tests.utils import inject_message
32
- from nucliadb.tests.utils.broker_messages import BrokerMessageBuilder, FieldBuilder
33
- from nucliadb.train import API_PREFIX
34
- from nucliadb.train.api.v1.router import KB_PREFIX
35
- from nucliadb.train.tests.utils import get_batches_from_train_response_stream
36
- from nucliadb_protos import resources_pb2 as rpb
37
-
38
-
39
- async def get_question_answer_streaming_batch_from_response(
40
- response: aiohttp.ClientResponse,
41
- ) -> AsyncIterator[QuestionAnswerStreamingBatch]:
42
- while True:
43
- header = await response.content.read(4)
44
- if header == b"":
45
- break
46
- payload_size = int.from_bytes(header, byteorder="big", signed=False)
47
- payload = await response.content.read(payload_size)
48
- pcb = QuestionAnswerStreamingBatch()
49
- pcb.ParseFromString(payload)
50
- assert pcb.data
51
- yield pcb
52
-
53
-
54
- @pytest.mark.asyncio
55
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
56
- async def test_generator_question_answer_streaming(
57
- train_rest_api: aiohttp.ClientSession,
58
- nucliadb_grpc: WriterStub,
59
- knowledgebox: str,
60
- ):
61
- kbid = knowledgebox
62
-
63
- await inject_resources_with_question_answers(kbid, nucliadb_grpc)
64
-
65
- async with train_rest_api.get(
66
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
67
- ) as partitions:
68
- assert partitions.status == 200
69
- data = await partitions.json()
70
- assert len(data["partitions"]) == 1
71
- partition_id = data["partitions"][0]
72
-
73
- trainset = TrainSet()
74
- trainset.type = TaskType.QUESTION_ANSWER_STREAMING
75
- trainset.batch_size = 5
76
- trainset.filter.labels.append("labelset_paragraphs")
77
-
78
- async with train_rest_api.post(
79
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
80
- data=trainset.SerializeToString(),
81
- ) as response:
82
- assert response.status == 200
83
- batches = []
84
- questions = []
85
- answers = []
86
- question_paragraphs_count = 0
87
- answer_paragraphs_count = 0
88
- async for batch in get_batches_from_train_response_stream(
89
- response, QuestionAnswerStreamingBatch
90
- ):
91
- batches.append(batch)
92
- assert len(batch.data) == 3
93
- for data in batch.data:
94
- questions.append(data.question.text)
95
- question_paragraphs_count += len(data.question.paragraphs)
96
- answers.append(data.answer.text)
97
- answer_paragraphs_count += len(data.answer.paragraphs)
98
- assert len(batches) == 1
99
- assert len(questions) == len(answers) == 3
100
- assert len(set(questions)) == 2
101
- assert question_paragraphs_count == 2
102
- assert answer_paragraphs_count == 4
103
-
104
-
105
- async def inject_resources_with_question_answers(kbid: str, nucliadb_grpc: WriterStub):
106
- await inject_message(nucliadb_grpc, smb_wonder_bm(kbid))
107
- await asyncio.sleep(0.1)
108
-
109
-
110
- def smb_wonder_bm(kbid: str) -> BrokerMessage:
111
- rid = str(uuid.uuid4())
112
- bmb = BrokerMessageBuilder(kbid=kbid, rid=rid)
113
- bmb.with_title("Super Mario Bros. Wonder")
114
- bmb.with_summary("SMB Wonder: the new Mario game from Nintendo")
115
-
116
- field_builder = FieldBuilder("smb-wonder", rpb.FieldType.FILE)
117
- paragraphs = [
118
- "Super Mario Bros. Wonder (SMB Wonder) is a 2023 platform game developed and published by Nintendo.\n", # noqa
119
- "SMB Wonder is a side-scrolling plaftorm game.\n",
120
- "As one of eight player characters, the player completes levels across the Flower Kingdom.", # noqa
121
- ]
122
- field_builder.with_extracted_text("".join(paragraphs))
123
- start = 0
124
- for paragraph in paragraphs:
125
- end = start + len(paragraph)
126
- field_builder.with_extracted_paragraph_metadata(
127
- rpb.Paragraph(start=start, end=end)
128
- )
129
- start = end
130
-
131
- start = 0
132
- end = len(paragraphs[0])
133
- paragraph_0_id = (
134
- f"{rid}/{FIELD_TYPE_TO_ID[rpb.FieldType.FILE]}/smb-wonder/{start}-{end}"
135
- )
136
-
137
- start = len(paragraphs[0])
138
- end = len(paragraphs[0]) + len(paragraphs[1])
139
- paragraph_1_id = (
140
- f"{rid}/{FIELD_TYPE_TO_ID[rpb.FieldType.FILE]}/smb-wonder/{start}-{end}"
141
- )
142
-
143
- question = "What is SMB Wonder?"
144
- field_builder.add_question_answer(
145
- question=question,
146
- question_paragraph_ids=[paragraph_0_id],
147
- answer="SMB Wonder is a side-scrolling Nintendo Switch game",
148
- answer_paragraph_ids=[paragraph_0_id, paragraph_1_id],
149
- )
150
- field_builder.add_question_answer(
151
- question=question,
152
- question_paragraph_ids=[paragraph_0_id],
153
- answer="It's the new Mario game for Nintendo Switch",
154
- answer_paragraph_ids=[paragraph_0_id],
155
- )
156
-
157
- question = "Give me an example of side-scrolling game"
158
- field_builder.add_question_answer(
159
- question=question,
160
- answer="SMB Wonder game",
161
- answer_paragraph_ids=[paragraph_1_id],
162
- )
163
-
164
- bmb.add_field_builder(field_builder)
165
-
166
- bm = bmb.build()
167
-
168
- return bm
169
-
170
-
171
- @pytest.mark.asyncio
172
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
173
- async def test_generator_question_answer_streaming_streams_qa_annotations(
174
- train_rest_api: aiohttp.ClientSession,
175
- writer_rest_api: aiohttp.ClientSession,
176
- knowledgebox: str,
177
- ):
178
- kbid = knowledgebox
179
-
180
- resp = await writer_rest_api.post(
181
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/resources",
182
- json={
183
- "title": "Super Mario Bros. Wonder",
184
- "texts": {
185
- "smb-wonder": {
186
- "body": "Super Mario Bros. Wonder (SMB Wonder) is a 2023 platform game developed and published by Nintendo.\n" # noqa
187
- },
188
- },
189
- "fieldmetadata": [
190
- {
191
- "field": {"field_type": "text", "field": "smb-wonder"},
192
- "question_answers": [
193
- {
194
- "cancelled_by_user": True,
195
- "question_answer": {
196
- "question": {
197
- "text": "What is SMB Wonder?",
198
- "ids_paragraphs": [],
199
- },
200
- "answers": [
201
- {
202
- "ids_paragraphs": [],
203
- "language": "english",
204
- "text": "SMB Wonder is a Nintendo Switch game",
205
- }
206
- ],
207
- },
208
- }
209
- ],
210
- }
211
- ],
212
- },
213
- )
214
- assert resp.status == 201, resp.text
215
-
216
- async with train_rest_api.get(
217
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
218
- ) as partitions:
219
- assert partitions.status == 200
220
- data = await partitions.json()
221
- assert len(data["partitions"]) == 1
222
- partition_id = data["partitions"][0]
223
-
224
- trainset = TrainSet()
225
- trainset.type = TaskType.QUESTION_ANSWER_STREAMING
226
- trainset.batch_size = 5
227
-
228
- async with train_rest_api.post(
229
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
230
- data=trainset.SerializeToString(),
231
- ) as response:
232
- assert response.status == 200
233
- batches = []
234
- async for batch in get_batches_from_train_response_stream(
235
- response, QuestionAnswerStreamingBatch
236
- ):
237
- batches.append(batch)
238
- assert len(batch.data) == 1
239
- assert len(batches) == 1
@@ -1,143 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import asyncio
21
- import uuid
22
-
23
- import aiohttp
24
- import pytest
25
- from nucliadb_protos.dataset_pb2 import SentenceClassificationBatch, TaskType, TrainSet
26
- from nucliadb_protos.writer_pb2 import BrokerMessage
27
- from nucliadb_protos.writer_pb2_grpc import WriterStub
28
-
29
- from nucliadb.tests.utils import inject_message
30
- from nucliadb.tests.utils.broker_messages import BrokerMessageBuilder, FieldBuilder
31
- from nucliadb.train import API_PREFIX
32
- from nucliadb.train.api.v1.router import KB_PREFIX
33
- from nucliadb.train.tests.utils import get_batches_from_train_response_stream
34
- from nucliadb_protos import resources_pb2 as rpb
35
-
36
-
37
- @pytest.mark.asyncio
38
- @pytest.mark.parametrize("knowledgebox", ["STABLE", "EXPERIMENTAL"], indirect=True)
39
- async def test_generator_sentence_classification(
40
- train_rest_api: aiohttp.ClientSession,
41
- nucliadb_grpc: WriterStub,
42
- knowledgebox_with_labels: str,
43
- ):
44
- kbid = knowledgebox_with_labels
45
-
46
- await inject_resource_with_sentence_classification(kbid, nucliadb_grpc)
47
-
48
- async with train_rest_api.get(
49
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset"
50
- ) as partitions:
51
- assert partitions.status == 200
52
- data = await partitions.json()
53
- assert len(data["partitions"]) == 1
54
- partition_id = data["partitions"][0]
55
-
56
- trainset = TrainSet()
57
- trainset.type = TaskType.SENTENCE_CLASSIFICATION
58
- trainset.batch_size = 2
59
- trainset.filter.labels.append("labelset_paragraphs")
60
-
61
- async with train_rest_api.post(
62
- f"/{API_PREFIX}/v1/{KB_PREFIX}/{kbid}/trainset/{partition_id}",
63
- data=trainset.SerializeToString(),
64
- ) as response:
65
- assert response.status == 200
66
- batches = []
67
- async for batch in get_batches_from_train_response_stream(
68
- response, SentenceClassificationBatch
69
- ):
70
- batches.append(batch)
71
- assert len(batch.data) == 2
72
- assert len(batches) == 2
73
-
74
-
75
- async def inject_resource_with_sentence_classification(knowledgebox, writer):
76
- bm = broker_resource(knowledgebox)
77
- await inject_message(writer, bm)
78
- await asyncio.sleep(0.1)
79
- return bm.uuid
80
-
81
-
82
- def broker_resource(knowledgebox: str) -> BrokerMessage:
83
- rid = str(uuid.uuid4())
84
- bmb = BrokerMessageBuilder(kbid=knowledgebox, rid=rid)
85
- bmb.with_title("Title Resource")
86
- bmb.with_summary("Summary of document")
87
- bmb.with_resource_labels("labelset_resources", ["label_user"])
88
-
89
- file_field = FieldBuilder("file", rpb.FieldType.FILE)
90
- file_field.with_extracted_text(
91
- "My own text Ramon. This is great to be here. \n Where is my beer? Do you want to go shooping? This is a test!" # noqa
92
- )
93
-
94
- labelset = "labelset_paragraphs"
95
- labels = ["label_user"]
96
- file_field.with_user_paragraph_labels(f"{rid}/f/file/0-45", labelset, labels)
97
- file_field.with_user_paragraph_labels(f"{rid}/f/file/47-64", labelset, labels)
98
- file_field.with_user_paragraph_labels(f"{rid}/f/file/65-93", labelset, labels)
99
- file_field.with_user_paragraph_labels(f"{rid}/f/file/93-109", labelset, labels)
100
-
101
- classification = rpb.Classification(
102
- labelset="labelset_paragraphs", label="label_machine"
103
- )
104
- file_field.with_extracted_paragraph_metadata(
105
- rpb.Paragraph(
106
- start=0,
107
- end=45,
108
- classifications=[classification],
109
- sentences=[rpb.Sentence(start=0, end=45)],
110
- )
111
- )
112
- file_field.with_extracted_paragraph_metadata(
113
- rpb.Paragraph(
114
- start=47,
115
- end=64,
116
- classifications=[classification],
117
- sentences=[rpb.Sentence(start=47, end=64)],
118
- )
119
- )
120
- file_field.with_extracted_paragraph_metadata(
121
- rpb.Paragraph(
122
- start=65,
123
- end=93,
124
- classifications=[classification],
125
- sentences=[rpb.Sentence(start=65, end=93)],
126
- )
127
- )
128
- file_field.with_extracted_paragraph_metadata(
129
- rpb.Paragraph(
130
- start=93,
131
- end=109,
132
- classifications=[classification],
133
- sentences=[rpb.Sentence(start=94, end=109)],
134
- )
135
- )
136
-
137
- file_field.with_extracted_labels("labelset_resources", ["label_machine"])
138
-
139
- bmb.add_field_builder(file_field)
140
-
141
- bm = bmb.build()
142
-
143
- return bm