endoreg-db 0.8.9.2__py3-none-any.whl → 0.8.9.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of endoreg-db might be problematic (see the package registry page for more details).

Files changed (450)
  1. endoreg_db/admin.py +10 -5
  2. endoreg_db/apps.py +4 -7
  3. endoreg_db/authz/auth.py +1 -0
  4. endoreg_db/authz/backends.py +1 -1
  5. endoreg_db/authz/management/commands/list_routes.py +2 -0
  6. endoreg_db/authz/middleware.py +8 -7
  7. endoreg_db/authz/permissions.py +21 -10
  8. endoreg_db/authz/policy.py +14 -19
  9. endoreg_db/authz/views_auth.py +14 -10
  10. endoreg_db/codemods/rename_datetime_fields.py +8 -1
  11. endoreg_db/exceptions.py +5 -2
  12. endoreg_db/forms/__init__.py +0 -1
  13. endoreg_db/forms/examination_form.py +4 -3
  14. endoreg_db/forms/patient_finding_intervention_form.py +30 -8
  15. endoreg_db/forms/patient_form.py +9 -13
  16. endoreg_db/forms/questionnaires/__init__.py +1 -1
  17. endoreg_db/forms/settings/__init__.py +4 -1
  18. endoreg_db/forms/unit.py +2 -1
  19. endoreg_db/helpers/count_db.py +17 -14
  20. endoreg_db/helpers/default_objects.py +2 -1
  21. endoreg_db/helpers/download_segmentation_model.py +4 -3
  22. endoreg_db/helpers/interact.py +0 -5
  23. endoreg_db/helpers/test_video_helper.py +33 -25
  24. endoreg_db/import_files/__init__.py +1 -1
  25. endoreg_db/import_files/context/__init__.py +1 -1
  26. endoreg_db/import_files/context/default_sensitive_meta.py +11 -9
  27. endoreg_db/import_files/context/ensure_center.py +4 -4
  28. endoreg_db/import_files/context/file_lock.py +3 -3
  29. endoreg_db/import_files/context/import_context.py +11 -12
  30. endoreg_db/import_files/context/validate_directories.py +1 -0
  31. endoreg_db/import_files/file_storage/create_report_file.py +57 -34
  32. endoreg_db/import_files/file_storage/create_video_file.py +64 -35
  33. endoreg_db/import_files/file_storage/sensitive_meta_storage.py +5 -2
  34. endoreg_db/import_files/file_storage/state_management.py +89 -122
  35. endoreg_db/import_files/file_storage/storage.py +5 -1
  36. endoreg_db/import_files/processing/report_processing/report_anonymization.py +24 -19
  37. endoreg_db/import_files/processing/sensitive_meta_adapter.py +3 -3
  38. endoreg_db/import_files/processing/video_processing/video_anonymization.py +18 -18
  39. endoreg_db/import_files/pseudonymization/k_anonymity.py +8 -9
  40. endoreg_db/import_files/pseudonymization/k_pseudonymity.py +16 -5
  41. endoreg_db/import_files/report_import_service.py +36 -30
  42. endoreg_db/import_files/video_import_service.py +27 -23
  43. endoreg_db/logger_conf.py +56 -40
  44. endoreg_db/management/__init__.py +1 -1
  45. endoreg_db/management/commands/__init__.py +1 -1
  46. endoreg_db/management/commands/check_auth.py +45 -38
  47. endoreg_db/management/commands/create_model_meta_from_huggingface.py +53 -2
  48. endoreg_db/management/commands/create_multilabel_model_meta.py +54 -19
  49. endoreg_db/management/commands/fix_missing_patient_data.py +105 -71
  50. endoreg_db/management/commands/fix_video_paths.py +75 -54
  51. endoreg_db/management/commands/import_report.py +1 -3
  52. endoreg_db/management/commands/list_routes.py +2 -0
  53. endoreg_db/management/commands/load_ai_model_data.py +8 -2
  54. endoreg_db/management/commands/load_ai_model_label_data.py +0 -1
  55. endoreg_db/management/commands/load_center_data.py +3 -3
  56. endoreg_db/management/commands/load_distribution_data.py +35 -38
  57. endoreg_db/management/commands/load_endoscope_data.py +0 -3
  58. endoreg_db/management/commands/load_examination_data.py +20 -4
  59. endoreg_db/management/commands/load_finding_data.py +18 -3
  60. endoreg_db/management/commands/load_gender_data.py +17 -24
  61. endoreg_db/management/commands/load_green_endoscopy_wuerzburg_data.py +95 -85
  62. endoreg_db/management/commands/load_information_source.py +0 -3
  63. endoreg_db/management/commands/load_lab_value_data.py +14 -3
  64. endoreg_db/management/commands/load_legacy_data.py +303 -0
  65. endoreg_db/management/commands/load_name_data.py +1 -2
  66. endoreg_db/management/commands/load_pdf_type_data.py +4 -8
  67. endoreg_db/management/commands/load_profession_data.py +0 -1
  68. endoreg_db/management/commands/load_report_reader_flag_data.py +0 -4
  69. endoreg_db/management/commands/load_requirement_data.py +6 -2
  70. endoreg_db/management/commands/load_unit_data.py +0 -4
  71. endoreg_db/management/commands/load_user_groups.py +5 -7
  72. endoreg_db/management/commands/model_input.py +169 -0
  73. endoreg_db/management/commands/register_ai_model.py +22 -16
  74. endoreg_db/management/commands/setup_endoreg_db.py +110 -32
  75. endoreg_db/management/commands/storage_management.py +14 -8
  76. endoreg_db/management/commands/summarize_db_content.py +154 -63
  77. endoreg_db/management/commands/train_image_multilabel_model.py +144 -0
  78. endoreg_db/management/commands/validate_video_files.py +82 -50
  79. endoreg_db/management/commands/video_validation.py +4 -6
  80. endoreg_db/migrations/0001_initial.py +112 -63
  81. endoreg_db/models/__init__.py +8 -0
  82. endoreg_db/models/administration/ai/active_model.py +5 -5
  83. endoreg_db/models/administration/ai/ai_model.py +41 -18
  84. endoreg_db/models/administration/ai/model_type.py +1 -0
  85. endoreg_db/models/administration/case/case.py +22 -22
  86. endoreg_db/models/administration/center/__init__.py +5 -5
  87. endoreg_db/models/administration/center/center.py +6 -2
  88. endoreg_db/models/administration/center/center_resource.py +18 -4
  89. endoreg_db/models/administration/center/center_shift.py +3 -1
  90. endoreg_db/models/administration/center/center_waste.py +6 -2
  91. endoreg_db/models/administration/person/__init__.py +1 -1
  92. endoreg_db/models/administration/person/employee/__init__.py +1 -1
  93. endoreg_db/models/administration/person/employee/employee_type.py +3 -1
  94. endoreg_db/models/administration/person/examiner/__init__.py +1 -1
  95. endoreg_db/models/administration/person/examiner/examiner.py +10 -2
  96. endoreg_db/models/administration/person/names/first_name.py +6 -4
  97. endoreg_db/models/administration/person/names/last_name.py +4 -3
  98. endoreg_db/models/administration/person/patient/__init__.py +1 -1
  99. endoreg_db/models/administration/person/patient/patient.py +0 -1
  100. endoreg_db/models/administration/person/patient/patient_external_id.py +0 -1
  101. endoreg_db/models/administration/person/person.py +1 -1
  102. endoreg_db/models/administration/product/__init__.py +7 -6
  103. endoreg_db/models/administration/product/product.py +6 -2
  104. endoreg_db/models/administration/product/product_group.py +9 -7
  105. endoreg_db/models/administration/product/product_material.py +9 -2
  106. endoreg_db/models/administration/product/reference_product.py +64 -15
  107. endoreg_db/models/administration/qualification/qualification.py +3 -1
  108. endoreg_db/models/administration/shift/shift.py +3 -1
  109. endoreg_db/models/administration/shift/shift_type.py +12 -4
  110. endoreg_db/models/aidataset/__init__.py +5 -0
  111. endoreg_db/models/aidataset/aidataset.py +193 -0
  112. endoreg_db/models/label/__init__.py +1 -1
  113. endoreg_db/models/label/label.py +10 -2
  114. endoreg_db/models/label/label_set.py +3 -1
  115. endoreg_db/models/label/label_video_segment/_create_from_video.py +6 -2
  116. endoreg_db/models/label/label_video_segment/label_video_segment.py +148 -44
  117. endoreg_db/models/media/__init__.py +12 -5
  118. endoreg_db/models/media/frame/__init__.py +1 -1
  119. endoreg_db/models/media/frame/frame.py +34 -8
  120. endoreg_db/models/media/pdf/__init__.py +2 -1
  121. endoreg_db/models/media/pdf/raw_pdf.py +11 -4
  122. endoreg_db/models/media/pdf/report_file.py +6 -2
  123. endoreg_db/models/media/pdf/report_reader/__init__.py +3 -3
  124. endoreg_db/models/media/pdf/report_reader/report_reader_flag.py +15 -5
  125. endoreg_db/models/media/video/create_from_file.py +20 -41
  126. endoreg_db/models/media/video/pipe_1.py +75 -30
  127. endoreg_db/models/media/video/pipe_2.py +37 -12
  128. endoreg_db/models/media/video/video_file.py +36 -24
  129. endoreg_db/models/media/video/video_file_ai.py +235 -70
  130. endoreg_db/models/media/video/video_file_anonymize.py +240 -65
  131. endoreg_db/models/media/video/video_file_frames/_bulk_create_frames.py +6 -1
  132. endoreg_db/models/media/video/video_file_frames/_create_frame_object.py +3 -1
  133. endoreg_db/models/media/video/video_file_frames/_delete_frames.py +30 -9
  134. endoreg_db/models/media/video/video_file_frames/_extract_frames.py +95 -29
  135. endoreg_db/models/media/video/video_file_frames/_get_frame.py +13 -3
  136. endoreg_db/models/media/video/video_file_frames/_get_frame_path.py +4 -1
  137. endoreg_db/models/media/video/video_file_frames/_get_frame_paths.py +15 -3
  138. endoreg_db/models/media/video/video_file_frames/_get_frame_range.py +15 -3
  139. endoreg_db/models/media/video/video_file_frames/_get_frames.py +7 -2
  140. endoreg_db/models/media/video/video_file_frames/_initialize_frames.py +109 -23
  141. endoreg_db/models/media/video/video_file_frames/_manage_frame_range.py +111 -27
  142. endoreg_db/models/media/video/video_file_frames/_mark_frames_extracted_status.py +46 -13
  143. endoreg_db/models/media/video/video_file_io.py +85 -33
  144. endoreg_db/models/media/video/video_file_meta/__init__.py +6 -6
  145. endoreg_db/models/media/video/video_file_meta/get_crop_template.py +17 -4
  146. endoreg_db/models/media/video/video_file_meta/get_endo_roi.py +28 -7
  147. endoreg_db/models/media/video/video_file_meta/get_fps.py +46 -13
  148. endoreg_db/models/media/video/video_file_meta/initialize_video_specs.py +81 -20
  149. endoreg_db/models/media/video/video_file_meta/text_meta.py +61 -20
  150. endoreg_db/models/media/video/video_file_meta/video_meta.py +40 -12
  151. endoreg_db/models/media/video/video_file_segments.py +118 -27
  152. endoreg_db/models/media/video/video_metadata.py +25 -6
  153. endoreg_db/models/media/video/video_processing.py +54 -15
  154. endoreg_db/models/medical/__init__.py +3 -13
  155. endoreg_db/models/medical/contraindication/__init__.py +3 -1
  156. endoreg_db/models/medical/disease.py +18 -6
  157. endoreg_db/models/medical/event.py +6 -2
  158. endoreg_db/models/medical/examination/__init__.py +5 -1
  159. endoreg_db/models/medical/examination/examination.py +22 -6
  160. endoreg_db/models/medical/examination/examination_indication.py +23 -7
  161. endoreg_db/models/medical/examination/examination_time.py +6 -2
  162. endoreg_db/models/medical/finding/__init__.py +3 -1
  163. endoreg_db/models/medical/finding/finding.py +37 -12
  164. endoreg_db/models/medical/finding/finding_classification.py +27 -8
  165. endoreg_db/models/medical/finding/finding_intervention.py +19 -6
  166. endoreg_db/models/medical/finding/finding_type.py +3 -1
  167. endoreg_db/models/medical/hardware/__init__.py +1 -1
  168. endoreg_db/models/medical/hardware/endoscope.py +14 -2
  169. endoreg_db/models/medical/laboratory/__init__.py +1 -1
  170. endoreg_db/models/medical/laboratory/lab_value.py +139 -39
  171. endoreg_db/models/medical/medication/__init__.py +7 -3
  172. endoreg_db/models/medical/medication/medication.py +3 -1
  173. endoreg_db/models/medical/medication/medication_indication.py +3 -1
  174. endoreg_db/models/medical/medication/medication_indication_type.py +11 -3
  175. endoreg_db/models/medical/medication/medication_intake_time.py +3 -1
  176. endoreg_db/models/medical/medication/medication_schedule.py +3 -1
  177. endoreg_db/models/medical/patient/__init__.py +2 -10
  178. endoreg_db/models/medical/patient/medication_examples.py +3 -14
  179. endoreg_db/models/medical/patient/patient_disease.py +17 -5
  180. endoreg_db/models/medical/patient/patient_event.py +12 -4
  181. endoreg_db/models/medical/patient/patient_examination.py +52 -15
  182. endoreg_db/models/medical/patient/patient_examination_indication.py +15 -4
  183. endoreg_db/models/medical/patient/patient_finding.py +105 -29
  184. endoreg_db/models/medical/patient/patient_finding_classification.py +41 -12
  185. endoreg_db/models/medical/patient/patient_finding_intervention.py +11 -3
  186. endoreg_db/models/medical/patient/patient_lab_sample.py +6 -2
  187. endoreg_db/models/medical/patient/patient_lab_value.py +42 -10
  188. endoreg_db/models/medical/patient/patient_medication.py +25 -7
  189. endoreg_db/models/medical/patient/patient_medication_schedule.py +34 -10
  190. endoreg_db/models/metadata/model_meta.py +40 -12
  191. endoreg_db/models/metadata/model_meta_logic.py +51 -16
  192. endoreg_db/models/metadata/sensitive_meta.py +65 -28
  193. endoreg_db/models/metadata/sensitive_meta_logic.py +28 -26
  194. endoreg_db/models/metadata/video_meta.py +146 -39
  195. endoreg_db/models/metadata/video_prediction_logic.py +70 -21
  196. endoreg_db/models/metadata/video_prediction_meta.py +80 -27
  197. endoreg_db/models/operation_log.py +63 -0
  198. endoreg_db/models/other/__init__.py +10 -10
  199. endoreg_db/models/other/distribution/__init__.py +9 -7
  200. endoreg_db/models/other/distribution/base_value_distribution.py +3 -1
  201. endoreg_db/models/other/distribution/date_value_distribution.py +19 -5
  202. endoreg_db/models/other/distribution/multiple_categorical_value_distribution.py +3 -1
  203. endoreg_db/models/other/distribution/numeric_value_distribution.py +34 -9
  204. endoreg_db/models/other/emission/__init__.py +1 -1
  205. endoreg_db/models/other/emission/emission_factor.py +9 -3
  206. endoreg_db/models/other/information_source.py +15 -5
  207. endoreg_db/models/other/material.py +3 -1
  208. endoreg_db/models/other/transport_route.py +3 -1
  209. endoreg_db/models/other/unit.py +6 -2
  210. endoreg_db/models/report/report.py +0 -1
  211. endoreg_db/models/requirement/requirement.py +84 -27
  212. endoreg_db/models/requirement/requirement_error.py +5 -6
  213. endoreg_db/models/requirement/requirement_evaluation/__init__.py +1 -1
  214. endoreg_db/models/requirement/requirement_evaluation/evaluate_with_dependencies.py +8 -8
  215. endoreg_db/models/requirement/requirement_evaluation/get_values.py +3 -3
  216. endoreg_db/models/requirement/requirement_evaluation/requirement_type_parser.py +24 -8
  217. endoreg_db/models/requirement/requirement_operator.py +28 -8
  218. endoreg_db/models/requirement/requirement_set.py +34 -11
  219. endoreg_db/models/state/__init__.py +1 -0
  220. endoreg_db/models/state/audit_ledger.py +9 -2
  221. endoreg_db/models/{media → state}/processing_history/__init__.py +1 -3
  222. endoreg_db/models/state/processing_history/processing_history.py +136 -0
  223. endoreg_db/models/state/raw_pdf.py +0 -1
  224. endoreg_db/models/state/video.py +2 -4
  225. endoreg_db/models/utils.py +4 -2
  226. endoreg_db/queries/__init__.py +2 -6
  227. endoreg_db/queries/annotations/__init__.py +1 -3
  228. endoreg_db/queries/annotations/legacy.py +37 -26
  229. endoreg_db/root_urls.py +3 -4
  230. endoreg_db/schemas/examination_evaluation.py +3 -0
  231. endoreg_db/serializers/Frames_NICE_and_PARIS_classifications.py +249 -163
  232. endoreg_db/serializers/__init__.py +2 -8
  233. endoreg_db/serializers/administration/__init__.py +1 -2
  234. endoreg_db/serializers/administration/ai/__init__.py +0 -1
  235. endoreg_db/serializers/administration/ai/active_model.py +3 -1
  236. endoreg_db/serializers/administration/ai/ai_model.py +5 -3
  237. endoreg_db/serializers/administration/ai/model_type.py +3 -1
  238. endoreg_db/serializers/administration/center.py +7 -2
  239. endoreg_db/serializers/administration/gender.py +4 -2
  240. endoreg_db/serializers/anonymization.py +13 -13
  241. endoreg_db/serializers/evaluation/examination_evaluation.py +0 -1
  242. endoreg_db/serializers/examination/__init__.py +1 -1
  243. endoreg_db/serializers/examination/base.py +12 -13
  244. endoreg_db/serializers/examination/dropdown.py +6 -7
  245. endoreg_db/serializers/examination_serializer.py +3 -6
  246. endoreg_db/serializers/finding/__init__.py +1 -1
  247. endoreg_db/serializers/finding/finding.py +14 -7
  248. endoreg_db/serializers/finding_classification/__init__.py +3 -3
  249. endoreg_db/serializers/finding_classification/choice.py +3 -3
  250. endoreg_db/serializers/finding_classification/classification.py +2 -4
  251. endoreg_db/serializers/label_video_segment/__init__.py +5 -3
  252. endoreg_db/serializers/{label → label_video_segment}/image_classification_annotation.py +5 -5
  253. endoreg_db/serializers/label_video_segment/label/__init__.py +6 -0
  254. endoreg_db/serializers/{label → label_video_segment/label}/label.py +1 -1
  255. endoreg_db/serializers/label_video_segment/label_video_segment.py +338 -228
  256. endoreg_db/serializers/meta/__init__.py +1 -2
  257. endoreg_db/serializers/meta/sensitive_meta_detail.py +28 -13
  258. endoreg_db/serializers/meta/sensitive_meta_update.py +51 -46
  259. endoreg_db/serializers/meta/sensitive_meta_verification.py +19 -16
  260. endoreg_db/serializers/misc/__init__.py +2 -2
  261. endoreg_db/serializers/misc/file_overview.py +11 -7
  262. endoreg_db/serializers/misc/stats.py +10 -8
  263. endoreg_db/serializers/misc/translatable_field_mix_in.py +6 -6
  264. endoreg_db/serializers/misc/upload_job.py +32 -29
  265. endoreg_db/serializers/patient/__init__.py +2 -1
  266. endoreg_db/serializers/patient/patient.py +32 -15
  267. endoreg_db/serializers/patient/patient_dropdown.py +11 -3
  268. endoreg_db/serializers/patient_examination/__init__.py +1 -1
  269. endoreg_db/serializers/patient_examination/patient_examination.py +67 -40
  270. endoreg_db/serializers/patient_finding/__init__.py +1 -1
  271. endoreg_db/serializers/patient_finding/patient_finding.py +2 -1
  272. endoreg_db/serializers/patient_finding/patient_finding_classification.py +17 -9
  273. endoreg_db/serializers/patient_finding/patient_finding_detail.py +26 -17
  274. endoreg_db/serializers/patient_finding/patient_finding_intervention.py +7 -5
  275. endoreg_db/serializers/patient_finding/patient_finding_list.py +10 -11
  276. endoreg_db/serializers/patient_finding/patient_finding_write.py +36 -27
  277. endoreg_db/serializers/pdf/__init__.py +1 -3
  278. endoreg_db/serializers/requirements/requirement_schema.py +1 -6
  279. endoreg_db/serializers/sensitive_meta_serializer.py +100 -81
  280. endoreg_db/serializers/video/__init__.py +2 -2
  281. endoreg_db/serializers/video/{segmentation.py → video_file.py} +66 -47
  282. endoreg_db/serializers/video/video_file_brief.py +6 -2
  283. endoreg_db/serializers/video/video_file_detail.py +36 -23
  284. endoreg_db/serializers/video/video_file_list.py +4 -2
  285. endoreg_db/serializers/video/video_processing_history.py +54 -50
  286. endoreg_db/services/__init__.py +1 -1
  287. endoreg_db/services/anonymization.py +2 -2
  288. endoreg_db/services/examination_evaluation.py +40 -17
  289. endoreg_db/services/model_meta_from_hf.py +76 -0
  290. endoreg_db/services/polling_coordinator.py +101 -70
  291. endoreg_db/services/pseudonym_service.py +27 -22
  292. endoreg_db/services/report_import.py +6 -3
  293. endoreg_db/services/segment_sync.py +75 -59
  294. endoreg_db/services/video_import.py +6 -7
  295. endoreg_db/urls/__init__.py +2 -2
  296. endoreg_db/urls/ai.py +7 -25
  297. endoreg_db/urls/anonymization.py +61 -15
  298. endoreg_db/urls/auth.py +4 -4
  299. endoreg_db/urls/classification.py +4 -9
  300. endoreg_db/urls/examination.py +27 -18
  301. endoreg_db/urls/media.py +27 -34
  302. endoreg_db/urls/patient.py +11 -7
  303. endoreg_db/urls/requirements.py +3 -1
  304. endoreg_db/urls/root_urls.py +2 -3
  305. endoreg_db/urls/stats.py +24 -16
  306. endoreg_db/urls/upload.py +3 -11
  307. endoreg_db/utils/__init__.py +14 -15
  308. endoreg_db/utils/ai/__init__.py +1 -1
  309. endoreg_db/utils/ai/data_loader_for_model_input.py +262 -0
  310. endoreg_db/utils/ai/data_loader_for_model_training.py +262 -0
  311. endoreg_db/utils/ai/get.py +2 -1
  312. endoreg_db/utils/ai/inference_dataset.py +14 -15
  313. endoreg_db/utils/ai/model_training/config.py +117 -0
  314. endoreg_db/utils/ai/model_training/dataset.py +74 -0
  315. endoreg_db/utils/ai/model_training/losses.py +68 -0
  316. endoreg_db/utils/ai/model_training/metrics.py +78 -0
  317. endoreg_db/utils/ai/model_training/model_backbones.py +155 -0
  318. endoreg_db/utils/ai/model_training/model_gastronet_resnet.py +118 -0
  319. endoreg_db/utils/ai/model_training/trainer_gastronet_multilabel.py +771 -0
  320. endoreg_db/utils/ai/multilabel_classification_net.py +21 -6
  321. endoreg_db/utils/ai/predict.py +4 -4
  322. endoreg_db/utils/ai/preprocess.py +19 -11
  323. endoreg_db/utils/calc_duration_seconds.py +4 -4
  324. endoreg_db/utils/case_generator/lab_sample_factory.py +3 -4
  325. endoreg_db/utils/check_video_files.py +74 -47
  326. endoreg_db/utils/cropping.py +10 -9
  327. endoreg_db/utils/dataloader.py +11 -3
  328. endoreg_db/utils/dates.py +3 -4
  329. endoreg_db/utils/defaults/set_default_center.py +7 -6
  330. endoreg_db/utils/env.py +6 -2
  331. endoreg_db/utils/extract_specific_frames.py +24 -9
  332. endoreg_db/utils/file_operations.py +30 -18
  333. endoreg_db/utils/fix_video_path_direct.py +57 -41
  334. endoreg_db/utils/frame_anonymization_utils.py +157 -157
  335. endoreg_db/utils/hashs.py +3 -18
  336. endoreg_db/utils/links/requirement_link.py +96 -52
  337. endoreg_db/utils/ocr.py +30 -25
  338. endoreg_db/utils/operation_log.py +61 -0
  339. endoreg_db/utils/parse_and_generate_yaml.py +12 -13
  340. endoreg_db/utils/paths.py +6 -6
  341. endoreg_db/utils/permissions.py +40 -24
  342. endoreg_db/utils/pipelines/process_video_dir.py +50 -26
  343. endoreg_db/utils/product/sum_emissions.py +5 -3
  344. endoreg_db/utils/product/sum_weights.py +4 -2
  345. endoreg_db/utils/pydantic_models/__init__.py +3 -4
  346. endoreg_db/utils/requirement_operator_logic/_old/lab_value_operators.py +207 -107
  347. endoreg_db/utils/requirement_operator_logic/_old/model_evaluators.py +252 -65
  348. endoreg_db/utils/requirement_operator_logic/new_operator_logic.py +27 -10
  349. endoreg_db/utils/setup_config.py +21 -5
  350. endoreg_db/utils/storage.py +3 -1
  351. endoreg_db/utils/translation.py +19 -15
  352. endoreg_db/utils/uuid.py +1 -0
  353. endoreg_db/utils/validate_endo_roi.py +12 -4
  354. endoreg_db/utils/validate_subcategory_dict.py +26 -24
  355. endoreg_db/utils/validate_video_detailed.py +207 -149
  356. endoreg_db/utils/video/__init__.py +7 -3
  357. endoreg_db/utils/video/extract_frames.py +30 -18
  358. endoreg_db/utils/video/names.py +11 -6
  359. endoreg_db/utils/video/streaming_processor.py +175 -101
  360. endoreg_db/utils/video/video_splitter.py +30 -19
  361. endoreg_db/views/Frames_NICE_and_PARIS_classifications_views.py +59 -50
  362. endoreg_db/views/__init__.py +0 -20
  363. endoreg_db/views/anonymization/__init__.py +6 -2
  364. endoreg_db/views/anonymization/media_management.py +2 -6
  365. endoreg_db/views/anonymization/overview.py +34 -1
  366. endoreg_db/views/anonymization/validate.py +79 -18
  367. endoreg_db/views/auth/__init__.py +1 -1
  368. endoreg_db/views/auth/keycloak.py +16 -14
  369. endoreg_db/views/examination/__init__.py +12 -15
  370. endoreg_db/views/examination/examination.py +5 -5
  371. endoreg_db/views/examination/examination_manifest_cache.py +5 -5
  372. endoreg_db/views/examination/get_finding_classification_choices.py +8 -5
  373. endoreg_db/views/examination/get_finding_classifications.py +9 -7
  374. endoreg_db/views/examination/get_findings.py +8 -10
  375. endoreg_db/views/examination/get_instruments.py +3 -2
  376. endoreg_db/views/examination/get_interventions.py +1 -1
  377. endoreg_db/views/finding/__init__.py +2 -2
  378. endoreg_db/views/finding/finding.py +58 -54
  379. endoreg_db/views/finding/get_classifications.py +1 -1
  380. endoreg_db/views/finding/get_interventions.py +1 -1
  381. endoreg_db/views/finding_classification/__init__.py +5 -5
  382. endoreg_db/views/finding_classification/finding_classification.py +5 -6
  383. endoreg_db/views/finding_classification/get_classification_choices.py +3 -4
  384. endoreg_db/views/media/__init__.py +13 -13
  385. endoreg_db/views/media/pdf_media.py +9 -9
  386. endoreg_db/views/media/sensitive_metadata.py +10 -7
  387. endoreg_db/views/media/video_media.py +4 -4
  388. endoreg_db/views/meta/__init__.py +1 -1
  389. endoreg_db/views/meta/sensitive_meta_list.py +20 -22
  390. endoreg_db/views/meta/sensitive_meta_verification.py +14 -11
  391. endoreg_db/views/misc/__init__.py +6 -34
  392. endoreg_db/views/misc/center.py +2 -1
  393. endoreg_db/views/misc/csrf.py +2 -1
  394. endoreg_db/views/misc/gender.py +2 -1
  395. endoreg_db/views/misc/stats.py +141 -106
  396. endoreg_db/views/patient/__init__.py +1 -3
  397. endoreg_db/views/patient/patient.py +141 -99
  398. endoreg_db/views/patient_examination/__init__.py +5 -5
  399. endoreg_db/views/patient_examination/patient_examination.py +43 -42
  400. endoreg_db/views/patient_examination/patient_examination_create.py +10 -15
  401. endoreg_db/views/patient_examination/patient_examination_detail.py +12 -15
  402. endoreg_db/views/patient_examination/patient_examination_list.py +21 -17
  403. endoreg_db/views/patient_examination/video.py +114 -80
  404. endoreg_db/views/patient_finding/__init__.py +1 -1
  405. endoreg_db/views/patient_finding/patient_finding.py +17 -10
  406. endoreg_db/views/patient_finding/patient_finding_optimized.py +127 -95
  407. endoreg_db/views/patient_finding_classification/__init__.py +1 -1
  408. endoreg_db/views/patient_finding_classification/pfc_create.py +35 -27
  409. endoreg_db/views/report/reimport.py +1 -1
  410. endoreg_db/views/report/report_stream.py +5 -8
  411. endoreg_db/views/requirement/__init__.py +2 -1
  412. endoreg_db/views/requirement/evaluate.py +7 -9
  413. endoreg_db/views/requirement/lookup.py +2 -3
  414. endoreg_db/views/requirement/lookup_store.py +0 -1
  415. endoreg_db/views/requirement/requirement_utils.py +2 -4
  416. endoreg_db/views/stats/__init__.py +4 -4
  417. endoreg_db/views/stats/stats_views.py +152 -115
  418. endoreg_db/views/video/__init__.py +18 -27
  419. endoreg_db/views/{ai → video/ai}/__init__.py +2 -2
  420. endoreg_db/views/{ai → video/ai}/label.py +20 -16
  421. endoreg_db/views/video/correction.py +5 -6
  422. endoreg_db/views/video/reimport.py +134 -99
  423. endoreg_db/views/video/segments_crud.py +134 -44
  424. endoreg_db/views/video/video_apply_mask.py +13 -12
  425. endoreg_db/views/video/video_correction.py +2 -1
  426. endoreg_db/views/video/video_download_processed.py +15 -15
  427. endoreg_db/views/video/video_meta_stats.py +7 -6
  428. endoreg_db/views/video/video_processing_history.py +3 -2
  429. endoreg_db/views/video/video_remove_frames.py +13 -12
  430. endoreg_db/views/video/video_stream.py +110 -82
  431. {endoreg_db-0.8.9.2.dist-info → endoreg_db-0.8.9.10.dist-info}/METADATA +9 -3
  432. {endoreg_db-0.8.9.2.dist-info → endoreg_db-0.8.9.10.dist-info}/RECORD +434 -431
  433. endoreg_db/management/commands/import_fallback_video.py +0 -203
  434. endoreg_db/management/commands/import_video.py +0 -422
  435. endoreg_db/management/commands/import_video_with_classification.py +0 -367
  436. endoreg_db/models/media/processing_history/processing_history.py +0 -96
  437. endoreg_db/serializers/label/__init__.py +0 -7
  438. endoreg_db/serializers/label_video_segment/_lvs_create.py +0 -149
  439. endoreg_db/serializers/label_video_segment/_lvs_update.py +0 -138
  440. endoreg_db/serializers/label_video_segment/_lvs_validate.py +0 -149
  441. endoreg_db/serializers/label_video_segment/label_video_segment_annotation.py +0 -99
  442. endoreg_db/serializers/label_video_segment/label_video_segment_update.py +0 -163
  443. endoreg_db/services/__old/pdf_import.py +0 -1487
  444. endoreg_db/services/__old/video_import.py +0 -1306
  445. endoreg_db/tasks/upload_tasks.py +0 -216
  446. endoreg_db/tasks/video_ingest.py +0 -161
  447. endoreg_db/tasks/video_processing_tasks.py +0 -327
  448. endoreg_db/views/misc/translation.py +0 -182
  449. {endoreg_db-0.8.9.2.dist-info → endoreg_db-0.8.9.10.dist-info}/WHEEL +0 -0
  450. {endoreg_db-0.8.9.2.dist-info → endoreg_db-0.8.9.10.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,771 @@
1
+ # endoreg_db/utils/ai/model_training/trainer_gastronet_multilabel.py
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import random
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional, Sequence, Tuple
9
+
10
+
11
+ import torch
12
+ from torch.utils.data import DataLoader
13
+ from torch.optim.lr_scheduler import CosineAnnealingLR
14
+
15
+ from django.db import models
16
+
17
+ from endoreg_db.models import AIDataSet
18
+ from endoreg_db.utils.ai.data_loader_for_model_input import build_dataset_for_training
19
+ from endoreg_db.utils.ai.model_training.config import (
20
+ TrainingConfig,
21
+ RUNS_DIR,
22
+ )
23
+ from endoreg_db.utils.ai.model_training.dataset import EndoMultiLabelDataset
24
+ from endoreg_db.utils.ai.model_training.losses import (
25
+ compute_class_weights,
26
+ focal_loss_with_mask,
27
+ )
28
+ from endoreg_db.utils.ai.model_training.metrics import compute_metrics
29
+
30
+ from endoreg_db.utils.ai.model_training.model_backbones import (
31
+ create_multilabel_model,
32
+ )
33
+
34
+ # ---------------------------------------------------------------------
35
+ # HELPER: FILTER LABELS BY LABELSET VERSION
36
+ # ---------------------------------------------------------------------
37
+
38
+
39
def filter_labels_by_labelset_version(
    labels: Sequence[models.Model],
    label_vectors: Sequence[Sequence[Optional[int]]],
    label_masks: Sequence[Sequence[int]],
    target_version: int,
) -> Tuple[
    List[List[Optional[int]]],
    List[List[int]],
    List[models.Model],
    List[int],
]:
    """
    Restrict a multi-label dataset to the labels of one LabelSet version.

    A label (column) is retained when ``label.label_sets`` contains at least
    one LabelSet whose ``version`` equals ``target_version``. Every per-sample
    vector and mask row is then reduced to the retained columns, keeping their
    original relative order.

    Args:
        labels: Label objects, one per column of the vectors/masks.
        label_vectors: Per-sample label values (0, 1 or None).
        label_masks: Per-sample known-flags, aligned with ``label_vectors``.
        target_version: ``LabelSet.version`` value to keep.

    Returns:
        ``(filtered_vectors, filtered_masks, filtered_labels, kept_indices)``
        where ``kept_indices`` are the surviving column positions in the
        original ``labels`` order.

    Raises:
        ValueError: if no label belongs to a LabelSet with ``target_version``.
    """
    # ``label_sets`` is the reverse side of the "LabelSet.labels" M2M
    # relation; one ``.exists()`` check is made per label.
    kept_indices: List[int] = [
        position
        for position, label in enumerate(labels)
        if label.label_sets.filter(version=target_version).exists()
    ]

    if not kept_indices:
        raise ValueError(
            f"No labels in this dataset belong to any LabelSet with version={target_version}. "
            "Check your LabelSet configuration or change labelset_version_to_train "
            "in config.py."
        )

    def project(row):
        # Pick the retained columns of a single per-sample row.
        return [row[column] for column in kept_indices]

    filtered_vectors: List[List[Optional[int]]] = []
    filtered_masks: List[List[int]] = []
    # zip() stops at the shorter of the two sequences.
    for vector_row, mask_row in zip(label_vectors, label_masks):
        filtered_vectors.append(project(vector_row))
        filtered_masks.append(project(mask_row))

    filtered_labels = [labels[column] for column in kept_indices]
    return filtered_vectors, filtered_masks, filtered_labels, kept_indices
92
+
93
+
94
+ # ---------------------------------------------------------------------
95
+ # GROUP-WISE SPLIT BY old_examination_id
96
+ # ---------------------------------------------------------------------
97
+
98
+
99
def groupwise_split_indices_by_examination(
    frame_ids: Sequence[int],
    old_examination_ids: Sequence[Optional[int]],
    val_split: float,
    test_split: float,
    seed: int = 42,
) -> Tuple[List[int], List[int], List[int]]:
    """
    Split sample indices into train / val / test by ``old_examination_id``.

    All samples sharing an ``old_examination_id`` land in the same split, so
    frames of one examination never leak between training and evaluation.
    Samples whose examination id is ``None`` each form their own group
    (keyed by their frame id).

    Groups are shuffled with a private ``random.Random(seed)`` instance (the
    global RNG is not touched). The first groups go to train, the next to val
    and the rest to test. The val/test group counts are
    ``round(split * n_groups)``, clamped so that together they never exceed
    the number of groups.

    Args:
        frame_ids: One frame id per sample.
        old_examination_ids: One examination id (or None) per sample,
            aligned with ``frame_ids``.
        val_split: Fraction of groups for validation (>= 0).
        test_split: Fraction of groups for test (>= 0).
        seed: Seed for the group shuffle.

    Returns:
        ``(train_indices, val_indices, test_indices)``, each sorted ascending.

    Raises:
        ValueError: if the two sequences differ in length, a split fraction
            is negative, or ``val_split + test_split`` exceeds 1.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if len(frame_ids) != len(old_examination_ids):
        raise ValueError(
            "frame_ids and old_examination_ids must have the same length "
            f"(got {len(frame_ids)} and {len(old_examination_ids)})."
        )
    if val_split < 0 or test_split < 0 or val_split + test_split > 1 + 1e-9:
        raise ValueError(
            "val_split and test_split must be non-negative and sum to at most 1 "
            f"(got val_split={val_split}, test_split={test_split})."
        )

    # 1) Build mapping: group_id -> list of sample indices
    groups: Dict[object, List[int]] = {}
    for idx, (fid, exam_id) in enumerate(zip(frame_ids, old_examination_ids)):
        group_key = exam_id if exam_id is not None else f"no_exam_{fid}"
        groups.setdefault(group_key, []).append(idx)

    group_ids = list(groups.keys())
    rng = random.Random(seed)
    rng.shuffle(group_ids)

    n_groups = len(group_ids)
    # round() can over-allocate (e.g. 3 groups at 0.5/0.5 -> 2 + 2 groups),
    # which used to make n_train negative and silently misassign groups via
    # negative slicing. Clamp so n_train is always >= 0.
    n_test = min(int(round(test_split * n_groups)), n_groups)
    n_val = min(int(round(val_split * n_groups)), n_groups - n_test)
    n_train = n_groups - n_val - n_test

    train_group_ids = group_ids[:n_train]
    val_group_ids = group_ids[n_train : n_train + n_val]
    test_group_ids = group_ids[n_train + n_val :]

    # Sorted for reproducibility.
    train_indices = sorted(i for gid in train_group_ids for i in groups[gid])
    val_indices = sorted(i for gid in val_group_ids for i in groups[gid])
    test_indices = sorted(i for gid in test_group_ids for i in groups[gid])

    print(
        f"[TRAIN] Group-wise split by old_examination_id: "
        f"#groups={n_groups}, train_groups={len(train_group_ids)}, "
        f"val_groups={len(val_group_ids)}, test_groups={len(test_group_ids)}"
    )

    return train_indices, val_indices, test_indices
159
+
160
+
161
+ # ---------------------------------------------------------------------
162
+ # MAIN TRAINING FUNCTION
163
+ # ---------------------------------------------------------------------
164
+
165
+
166
def _resolve_unlabeled_entries(
    label_vectors: Sequence[Sequence[Optional[int]]],
    label_masks: Sequence[Sequence[int]],
    treat_unlabeled_as_negative: bool,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Replace ``None`` label entries with 0 and derive the matching masks.

    With ``treat_unlabeled_as_negative`` (Option A) every entry is marked known
    (mask 1), so unlabeled entries count as negatives. Otherwise unlabeled
    entries get mask 0 (ignored by loss/metrics) and explicit entries keep
    their original mask value.
    """
    resolved_vectors: List[List[int]] = []
    resolved_masks: List[List[int]] = []

    if treat_unlabeled_as_negative:
        for vec in label_vectors:
            resolved_vectors.append([0 if x is None else int(x) for x in vec])
            # explicit labels AND unlabeled entries are both marked as known
            resolved_masks.append([1] * len(vec))
        return resolved_vectors, resolved_masks

    for vec, mask in zip(label_vectors, label_masks):
        pairs = list(zip(vec, mask))
        resolved_vectors.append([0 if x is None else int(x) for x, _ in pairs])
        resolved_masks.append([0 if x is None else int(ms) for x, ms in pairs])
    return resolved_vectors, resolved_masks


def _print_dataset_statistics(
    label_vectors: Sequence[Sequence[int]],
    label_masks: Sequence[Sequence[int]],
    num_samples: int,
    num_labels: int,
) -> None:
    """Print known/positive entry counts overall and per label (debug output)."""
    labels_tensor = torch.tensor(
        [[int(x) for x in vec] for vec in label_vectors], dtype=torch.float32
    )
    masks_tensor = torch.tensor(
        [[int(x) for x in mask] for mask in label_masks], dtype=torch.float32
    )
    known_positives = labels_tensor * masks_tensor

    print("[DEBUG] Dataset statistics AFTER label filtering:")
    print(f" #samples = {num_samples}")
    print(f" #labels = {num_labels}")
    print(f" total known entries= {masks_tensor.sum().item()}")
    print(f" total positive labels (over known) = {known_positives.sum().item()}")

    print("[DEBUG] Positives per label (index: count):")
    for idx, c in enumerate(known_positives.sum(dim=0).tolist()):
        print(f" [{idx}] = {int(c)}")


def _subset_dataset(
    ds: EndoMultiLabelDataset, indices: List[int]
) -> EndoMultiLabelDataset:
    """Build a new EndoMultiLabelDataset containing only ``indices`` of ``ds``."""
    return EndoMultiLabelDataset(
        image_paths=[ds.image_paths[i] for i in indices],
        label_vectors=ds.labels[indices].tolist(),
        label_masks=ds.masks[indices].tolist(),
        image_size=ds.image_size,
    )


def _debug_first_batch(
    model: torch.nn.Module, loader: DataLoader, device: torch.device
) -> None:
    """Print shapes, targets/mask and model outputs for sample 0 of the first batch."""
    imgs_dbg, y_dbg, m_dbg = next(iter(loader))
    print("[DEBUG] First training batch shapes:")
    print(" imgs:", imgs_dbg.shape)
    print(" y: ", y_dbg.shape)
    print(" m: ", m_dbg.shape)
    print("[DEBUG] First sample labels (y[0]):")
    print(y_dbg[0].tolist())
    print("[DEBUG] First sample mask (m[0]):")
    print(m_dbg[0].tolist())

    model.eval()
    with torch.no_grad():
        logits_dbg = model(imgs_dbg.to(device))
        probs_dbg = torch.sigmoid(logits_dbg)
    print("[DEBUG] First sample logits:")
    print(logits_dbg[0].cpu().tolist())
    print("[DEBUG] First sample probs (sigmoid):")
    print(probs_dbg[0].cpu().tolist())


def _train_one_epoch(
    model: torch.nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    class_weights: torch.Tensor,
    config: TrainingConfig,
) -> float:
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    loss_sum = 0.0
    batch_count = 0

    for imgs, y, m in loader:
        imgs = imgs.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        m = m.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = focal_loss_with_mask(
            logits=logits,
            targets=y,
            masks=m,
            class_weights=class_weights,
            alpha=config.alpha_focal,
            gamma=config.gamma_focal,
        )
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        batch_count += 1

    return loss_sum / max(batch_count, 1)


def _evaluate(
    model: torch.nn.Module,
    loader: Optional[DataLoader],
    device: torch.device,
    class_weights: torch.Tensor,
    config: TrainingConfig,
) -> Tuple[Optional[float], Optional[Dict]]:
    """Run a no-grad pass over ``loader`` and compute loss + metrics.

    Returns ``(mean focal loss, compute_metrics(...) result)``, or
    ``(None, None)`` when ``loader`` is None or yields no batches (an empty
    split), instead of crashing on ``torch.cat([])``.
    """
    if loader is None:
        return None, None

    model.eval()
    loss_sum = 0.0
    batch_count = 0
    logits_parts: List[torch.Tensor] = []
    target_parts: List[torch.Tensor] = []
    mask_parts: List[torch.Tensor] = []

    with torch.no_grad():
        for imgs, y, m in loader:
            imgs = imgs.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            m = m.to(device, non_blocking=True)

            logits = model(imgs)
            loss = focal_loss_with_mask(
                logits=logits,
                targets=y,
                masks=m,
                class_weights=class_weights,
                alpha=config.alpha_focal,
                gamma=config.gamma_focal,
            )
            loss_sum += loss.item()
            batch_count += 1

            logits_parts.append(logits)
            target_parts.append(y)
            mask_parts.append(m)

    if batch_count == 0:
        return None, None

    metrics = compute_metrics(
        logits=torch.cat(logits_parts, dim=0),
        targets=torch.cat(target_parts, dim=0),
        masks=torch.cat(mask_parts, dim=0),
        threshold=0.5,
    )
    return loss_sum / batch_count, metrics


def _format_loss(value: Optional[float]) -> str:
    """Format a loss for logging; ``None`` (empty split) renders as ``n/a``."""
    return "n/a" if value is None else f"{value:.4f}"


def _print_metrics_summary(split_tag: str, metrics: Dict) -> None:
    """Print the aggregate precision/recall/F1/accuracy/confusion line."""
    print(
        f"[{split_tag} METRICS] "
        f"Precision={metrics['precision']:.4f} "
        f"Recall={metrics['recall']:.4f} "
        f"F1={metrics['f1']:.4f} "
        f"Acc={metrics['accuracy']:.4f} "
        f"TP={metrics['tp']} FP={metrics['fp']} "
        f"TN={metrics['tn']} FN={metrics['fn']}"
    )


def _print_per_label_table(split_tag: str, metrics: Dict, labels: Sequence) -> None:
    """Print a per-label precision/recall/F1/support table; None values show N/A."""

    def _cell(value: Optional[float]) -> str:
        return f"{'N/A':>8}" if value is None else f"{value:8.4f}"

    print(f"\n[{split_tag} PER-LABEL METRICS]")
    print(f"{'Label':20s} {'Prec':>8s} {'Rec':>8s} {'F1':>8s} {'Support':>8s}")
    print("-" * 60)
    for lbl, stats in zip(labels, metrics["per_label"]):
        print(
            f"{lbl.name:20s} {_cell(stats['precision'])} {_cell(stats['recall'])} "
            f"{_cell(stats['f1'])} {stats['support']:8d}"
        )
    print("-" * 60)


def _build_run_name(config: TrainingConfig) -> str:
    """Return the artifact base name; the GastroNet RN50 backbone keeps its legacy name."""
    version = config.labelset_version_to_train
    if getattr(config, "backbone_name", "gastro_rn50") == "gastro_rn50":
        return (
            f"aidataset_{config.dataset_id}_"
            f"RN50_GastroNet1M_DINO_v{version}_multilabel"
        )
    # For all other backbones, use a generic name that includes backbone_name
    backbone_tag = config.backbone_name.replace(" ", "_")
    return f"aidataset_{config.dataset_id}_{backbone_tag}_v{version}_multilabel"


def train_gastronet_multilabel(config: TrainingConfig) -> Dict:
    """
    High-level training entry point.

    Pipeline:
        1. Load AIDataSet from DB and build raw dataset via build_dataset_for_training.
        2. Filter labels by LabelSet.version == config.labelset_version_to_train.
        3. Resolve unlabeled entries (negative+known with
           treat_unlabeled_as_negative, otherwise masked out).
        4. Print dataset statistics (positives per label, etc.).
        5. Group-wise split by old_examination_id into train/val/test.
        6. Wrap in PyTorch Dataset + DataLoaders.
        7. Build backbone + multi-label head via create_multilabel_model.
        8. Train with focal loss + class weights (+ mask).
        9. LR schedule: linear warm-up + cosine decay (if enabled).
        10. Evaluate on the test split.
        11. Save model state_dict + JSON metadata in RUNS_DIR.

    Returns:
        dict with ``model_path``, ``meta_path`` (str) and ``history``
        (``train_loss``/``val_loss`` per epoch, final ``test_loss``; a loss is
        None when its split is empty).

    Raises:
        ValueError: if the dataset has no samples, no label matches the target
            LabelSet version, or the training split ends up empty.
    """
    # ------------------------------------------------------------------
    # 1. Load dataset from DB
    # ------------------------------------------------------------------
    dataset_obj = AIDataSet.objects.get(id=config.dataset_id)
    data = build_dataset_for_training(dataset_obj)

    image_paths: List[str] = data["image_paths"]
    label_vectors: List[List[Optional[int]]] = data["label_vectors"]
    label_masks: List[List[int]] = data["label_masks"]
    labels = data["labels"]  # list[Label]
    labelset = data["labelset"]
    frame_ids: List[int] = data.get("frame_ids", [])
    old_exam_ids: List[Optional[int]] = data.get("old_examination_ids", [])

    num_samples_raw = len(image_paths)
    num_labels_raw = len(labels)

    print(f"[TRAIN] AIDataSet id={dataset_obj.id}")
    print(
        f"[TRAIN] #samples (raw) = {num_samples_raw}, #labels (raw) = {num_labels_raw}"
    )
    print(
        f"[TRAIN] LabelSet id={labelset.id}, "
        f"name={labelset.name}, version={labelset.version}"
    )
    print("[TRAIN] Labels (raw):")
    for idx, lbl in enumerate(labels):
        print(f" [{idx}] {lbl.name}")

    if num_samples_raw == 0:
        raise ValueError(
            f"AIDataSet id={dataset_obj.id} contains no samples; nothing to train on."
        )

    # ------------------------------------------------------------------
    # 2. Filter labels by LabelSet.version == config.labelset_version_to_train
    # ------------------------------------------------------------------
    target_version = config.labelset_version_to_train
    print(
        f"[TRAIN] Filtering labels to those belonging to ANY LabelSet with version={target_version}..."
    )

    label_vectors, label_masks, labels, kept_indices = filter_labels_by_labelset_version(
        labels=labels,
        label_vectors=label_vectors,
        label_masks=label_masks,
        target_version=target_version,
    )

    num_labels_filtered = len(labels)
    print(
        f"[TRAIN] Label filtering done. "
        f"Kept {num_labels_filtered} / {num_labels_raw} labels."
    )
    print("[TRAIN] Kept labels (new index -> original index -> name):")
    for new_idx, orig_idx in enumerate(kept_indices):
        print(f" [{new_idx}] (orig {orig_idx}) {labels[new_idx].name}")

    # ------------------------------------------------------------------
    # 2b. Interpret unlabeled (None) entries
    #   treat_unlabeled_as_negative=True  (Option A): None -> 0, mask 1
    #   treat_unlabeled_as_negative=False           : None -> 0, mask 0
    # ------------------------------------------------------------------
    label_vectors, label_masks = _resolve_unlabeled_entries(
        label_vectors, label_masks, config.treat_unlabeled_as_negative
    )

    # ------------------------------------------------------------------
    # 3. Dataset statistics AFTER filtering + unlabeled resolution
    # ------------------------------------------------------------------
    _print_dataset_statistics(
        label_vectors, label_masks, len(image_paths), num_labels_filtered
    )

    # ------------------------------------------------------------------
    # 4. Group-wise split by old_examination_id (train/val/test)
    # ------------------------------------------------------------------
    if not frame_ids or not old_exam_ids:
        # No examination info: every sample becomes its own group.
        frame_ids = list(range(len(image_paths)))
        old_exam_ids = [None] * len(image_paths)

    train_indices, val_indices, test_indices = groupwise_split_indices_by_examination(
        frame_ids=frame_ids,
        old_examination_ids=old_exam_ids,
        val_split=config.val_split,
        test_split=config.test_split,
        seed=config.random_seed,
    )

    print(
        f"[TRAIN] Train size: {len(train_indices)}, "
        f"Val size: {len(val_indices)}, "
        f"Test size: {len(test_indices)}"
    )
    if not train_indices:
        raise ValueError(
            "Training split is empty; lower val_split/test_split or add more "
            "examinations to the dataset."
        )

    # ------------------------------------------------------------------
    # 5. Device, PyTorch datasets + loaders
    # ------------------------------------------------------------------
    if config.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(config.device)

    image_size = getattr(config, "image_size", 224)
    full_ds = EndoMultiLabelDataset(
        image_paths=image_paths,
        label_vectors=label_vectors,
        label_masks=label_masks,
        image_size=image_size,
    )

    loader_kwargs = {
        "batch_size": config.batch_size,
        "num_workers": getattr(config, "num_workers", 4),
        # Pinned host memory only helps host->CUDA copies; on CPU it just warns.
        "pin_memory": device.type == "cuda",
    }
    train_loader = DataLoader(
        _subset_dataset(full_ds, train_indices), shuffle=True, **loader_kwargs
    )
    # Empty splits get no loader; _evaluate then reports (None, None).
    val_loader = (
        DataLoader(_subset_dataset(full_ds, val_indices), shuffle=False, **loader_kwargs)
        if val_indices
        else None
    )
    test_loader = (
        DataLoader(_subset_dataset(full_ds, test_indices), shuffle=False, **loader_kwargs)
        if test_indices
        else None
    )

    # ------------------------------------------------------------------
    # 6. Build model
    # ------------------------------------------------------------------
    backbone_ckpt = (
        Path(config.backbone_checkpoint)
        if config.backbone_checkpoint is not None
        else None
    )

    model = create_multilabel_model(
        backbone_name=config.backbone_name,
        num_labels=num_labels_filtered,
        backbone_checkpoint=backbone_ckpt,
        freeze_backbone=config.freeze_backbone,
    )
    model.to(device)

    # ------------------------------------------------------------------
    # 7. Class weights from full (filtered) dataset
    # NOTE(review): this includes the val/test samples, so label frequencies
    # of the evaluation splits influence training; consider train-only.
    # ------------------------------------------------------------------
    class_weights = compute_class_weights(full_ds.labels, full_ds.masks).to(device)
    print("[TRAIN] Computed class weights per label:", class_weights.cpu().tolist())
    print(
        "[DEBUG] class_weights range: "
        f"min={float(class_weights.min()):.6f}, max={float(class_weights.max()):.6f}"
    )

    # ------------------------------------------------------------------
    # 8. Optimizer + LR SCHEDULER (warm-up + cosine)
    # ------------------------------------------------------------------
    head_params = list(model.classifier.parameters())
    backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params, "lr": config.lr_head},
            {"params": backbone_params, "lr": config.lr_backbone},
        ]
    )

    # Base LRs per param group (head, backbone), used by the linear warm-up.
    base_lrs = [config.lr_head, config.lr_backbone]

    scheduler = None
    warmup_epochs = 0
    if config.use_scheduler:
        warmup_epochs = max(config.warmup_epochs, 0)
        # Cosine decay runs over the epochs AFTER warm-up.
        t_max = max(config.num_epochs - warmup_epochs, 1)
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=t_max,
            eta_min=config.min_lr,
        )
        print(
            f"[LR] Using warm-up + cosine decay: warmup_epochs={warmup_epochs}, "
            f"T_max={t_max}, min_lr={config.min_lr}"
        )
    else:
        print("[LR] No LR scheduler used (fixed learning rate).")

    # ------------------------------------------------------------------
    # 9. Training loop
    # ------------------------------------------------------------------
    history = {"train_loss": [], "val_loss": [], "test_loss": None}

    _debug_first_batch(model, train_loader, device)

    for epoch in range(1, config.num_epochs + 1):
        # Linear warm-up: base_lr * epoch / warmup_epochs for the first epochs.
        if scheduler is not None and epoch <= warmup_epochs:
            warmup_factor = epoch / float(warmup_epochs)
            for group, base_lr in zip(optimizer.param_groups, base_lrs):
                group["lr"] = base_lr * warmup_factor

        current_lrs = [pg["lr"] for pg in optimizer.param_groups]
        print(
            f"[LR] Epoch {epoch:03d}: "
            f"head_lr={current_lrs[0]:.6g}, backbone_lr={current_lrs[1]:.6g}"
        )

        train_loss = _train_one_epoch(
            model, train_loader, optimizer, device, class_weights, config
        )
        history["train_loss"].append(train_loss)

        val_loss, val_metrics = _evaluate(
            model, val_loader, device, class_weights, config
        )
        history["val_loss"].append(val_loss)

        if val_metrics is not None:
            _print_metrics_summary("VAL", val_metrics)
        else:
            print("[VAL METRICS] skipped: validation split is empty.")

        print(
            f"[EPOCH {epoch:03d}/{config.num_epochs:03d}] "
            f"train_loss={train_loss:.4f} val_loss={_format_loss(val_loss)}"
        )

        if val_metrics is not None:
            _print_per_label_table("VAL", val_metrics, labels)

        # Step the cosine schedule AFTER this epoch's optimizer updates.
        # PyTorch expects scheduler.step() to follow optimizer.step(); stepping
        # first (as before) skipped the base LR when warmup_epochs == 0. For
        # warmup_epochs > 0 the resulting per-epoch LRs are unchanged.
        if scheduler is not None and epoch >= warmup_epochs:
            scheduler.step()

    # ------------------------------------------------------------------
    # 10. Final test loss + metrics
    # ------------------------------------------------------------------
    test_loss, test_metrics = _evaluate(
        model, test_loader, device, class_weights, config
    )
    history["test_loss"] = test_loss
    print(f"[TEST] test_loss={_format_loss(test_loss)}")

    if test_metrics is not None:
        _print_metrics_summary("TEST", test_metrics)
        # Previously printed the *validation* per-label table here by mistake.
        _print_per_label_table("TEST", test_metrics, labels)
    else:
        print("[TEST METRICS] skipped: test split is empty.")

    # ------------------------------------------------------------------
    # 11. Save model + metadata
    # ------------------------------------------------------------------
    run_name = _build_run_name(config)

    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    model_path = RUNS_DIR / f"{run_name}.pth"
    meta_path = RUNS_DIR / f"{run_name}_meta.json"

    torch.save(model.state_dict(), model_path)

    meta = {
        "config": {
            "dataset_id": config.dataset_id,
            "labelset_version_to_train": config.labelset_version_to_train,
            # str() so a pathlib.Path checkpoint stays JSON-serialisable.
            "backbone_checkpoint": (
                str(config.backbone_checkpoint)
                if config.backbone_checkpoint is not None
                else None
            ),
            "backbone_name": config.backbone_name,
            "freeze_backbone": config.freeze_backbone,
            "image_size": image_size,
            "num_epochs": config.num_epochs,
            "batch_size": config.batch_size,
            "val_split": config.val_split,
            "test_split": config.test_split,
            "lr_head": config.lr_head,
            "lr_backbone": config.lr_backbone,
            "gamma_focal": config.gamma_focal,
            "alpha_focal": config.alpha_focal,
            "device": str(config.device),
            "random_seed": config.random_seed,
            "treat_unlabeled_as_negative": config.treat_unlabeled_as_negative,
            "use_scheduler": config.use_scheduler,
            "warmup_epochs": config.warmup_epochs,
            "min_lr": config.min_lr,
        },
        "original_labelset_id": labelset.id,
        "original_labelset_name": labelset.name,
        "original_labelset_version": labelset.version,
        "used_label_names": [lbl.name for lbl in labels],
        "used_label_indices_original": kept_indices,
        "history": history,
    }
    with meta_path.open("w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    print("[TRAIN] Saved model to:", model_path)
    print("[TRAIN] Saved metadata to:", meta_path)

    return {
        "model_path": str(model_path),
        "meta_path": str(meta_path),
        "history": history,
    }