ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (535)
  1. agentui/.prettierrc +15 -0
  2. agentui/QUICKSTART.md +272 -0
  3. agentui/README.md +59 -0
  4. agentui/env.example +16 -0
  5. agentui/jsconfig.json +14 -0
  6. agentui/package-lock.json +4242 -0
  7. agentui/package.json +34 -0
  8. agentui/scripts/postinstall/apply-patches.mjs +260 -0
  9. agentui/src/app.css +61 -0
  10. agentui/src/app.d.ts +13 -0
  11. agentui/src/app.html +12 -0
  12. agentui/src/components/LoadingSpinner.svelte +64 -0
  13. agentui/src/components/ThemeSwitcher.svelte +159 -0
  14. agentui/src/components/index.js +4 -0
  15. agentui/src/lib/api/bots.ts +60 -0
  16. agentui/src/lib/api/chat.ts +22 -0
  17. agentui/src/lib/api/http.ts +25 -0
  18. agentui/src/lib/components/BotCard.svelte +33 -0
  19. agentui/src/lib/components/ChatBubble.svelte +63 -0
  20. agentui/src/lib/components/Toast.svelte +21 -0
  21. agentui/src/lib/config.ts +20 -0
  22. agentui/src/lib/stores/auth.svelte.ts +73 -0
  23. agentui/src/lib/stores/theme.svelte.js +64 -0
  24. agentui/src/lib/stores/toast.svelte.ts +31 -0
  25. agentui/src/lib/utils/conversation.ts +39 -0
  26. agentui/src/routes/+layout.svelte +20 -0
  27. agentui/src/routes/+page.svelte +232 -0
  28. agentui/src/routes/login/+page.svelte +200 -0
  29. agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
  30. agentui/src/routes/talk/[agentId]/+page.ts +7 -0
  31. agentui/static/README.md +1 -0
  32. agentui/svelte.config.js +11 -0
  33. agentui/tailwind.config.ts +53 -0
  34. agentui/tsconfig.json +3 -0
  35. agentui/vite.config.ts +10 -0
  36. ai_parrot-0.17.2.dist-info/METADATA +472 -0
  37. ai_parrot-0.17.2.dist-info/RECORD +535 -0
  38. ai_parrot-0.17.2.dist-info/WHEEL +6 -0
  39. ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
  40. ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
  41. ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
  42. crew-builder/.prettierrc +15 -0
  43. crew-builder/QUICKSTART.md +259 -0
  44. crew-builder/README.md +113 -0
  45. crew-builder/env.example +17 -0
  46. crew-builder/jsconfig.json +14 -0
  47. crew-builder/package-lock.json +4182 -0
  48. crew-builder/package.json +37 -0
  49. crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
  50. crew-builder/src/app.css +62 -0
  51. crew-builder/src/app.d.ts +13 -0
  52. crew-builder/src/app.html +12 -0
  53. crew-builder/src/components/LoadingSpinner.svelte +64 -0
  54. crew-builder/src/components/ThemeSwitcher.svelte +149 -0
  55. crew-builder/src/components/index.js +9 -0
  56. crew-builder/src/lib/api/bots.ts +60 -0
  57. crew-builder/src/lib/api/chat.ts +80 -0
  58. crew-builder/src/lib/api/client.ts +56 -0
  59. crew-builder/src/lib/api/crew/crew.ts +136 -0
  60. crew-builder/src/lib/api/index.ts +5 -0
  61. crew-builder/src/lib/api/o365/auth.ts +65 -0
  62. crew-builder/src/lib/auth/auth.ts +54 -0
  63. crew-builder/src/lib/components/AgentNode.svelte +43 -0
  64. crew-builder/src/lib/components/BotCard.svelte +33 -0
  65. crew-builder/src/lib/components/ChatBubble.svelte +67 -0
  66. crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
  67. crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
  68. crew-builder/src/lib/components/JsonViewer.svelte +24 -0
  69. crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
  70. crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
  71. crew-builder/src/lib/components/Toast.svelte +67 -0
  72. crew-builder/src/lib/components/Toolbar.svelte +157 -0
  73. crew-builder/src/lib/components/index.ts +10 -0
  74. crew-builder/src/lib/config.ts +8 -0
  75. crew-builder/src/lib/stores/auth.svelte.ts +228 -0
  76. crew-builder/src/lib/stores/crewStore.ts +369 -0
  77. crew-builder/src/lib/stores/theme.svelte.js +145 -0
  78. crew-builder/src/lib/stores/toast.svelte.ts +69 -0
  79. crew-builder/src/lib/utils/conversation.ts +39 -0
  80. crew-builder/src/lib/utils/markdown.ts +122 -0
  81. crew-builder/src/lib/utils/talkHistory.ts +47 -0
  82. crew-builder/src/routes/+layout.svelte +20 -0
  83. crew-builder/src/routes/+page.svelte +539 -0
  84. crew-builder/src/routes/agents/+page.svelte +247 -0
  85. crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
  86. crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
  87. crew-builder/src/routes/builder/+page.svelte +204 -0
  88. crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
  89. crew-builder/src/routes/crew/ask/+page.ts +1 -0
  90. crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
  91. crew-builder/src/routes/login/+page.svelte +197 -0
  92. crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
  93. crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
  94. crew-builder/static/README.md +1 -0
  95. crew-builder/svelte.config.js +11 -0
  96. crew-builder/tailwind.config.ts +53 -0
  97. crew-builder/tsconfig.json +3 -0
  98. crew-builder/vite.config.ts +10 -0
  99. mcp_servers/calculator_server.py +309 -0
  100. parrot/__init__.py +27 -0
  101. parrot/__pycache__/__init__.cpython-310.pyc +0 -0
  102. parrot/__pycache__/version.cpython-310.pyc +0 -0
  103. parrot/_version.py +34 -0
  104. parrot/a2a/__init__.py +48 -0
  105. parrot/a2a/client.py +658 -0
  106. parrot/a2a/discovery.py +89 -0
  107. parrot/a2a/mixin.py +257 -0
  108. parrot/a2a/models.py +376 -0
  109. parrot/a2a/server.py +770 -0
  110. parrot/agents/__init__.py +29 -0
  111. parrot/bots/__init__.py +12 -0
  112. parrot/bots/a2a_agent.py +19 -0
  113. parrot/bots/abstract.py +3139 -0
  114. parrot/bots/agent.py +1129 -0
  115. parrot/bots/basic.py +9 -0
  116. parrot/bots/chatbot.py +669 -0
  117. parrot/bots/data.py +1618 -0
  118. parrot/bots/database/__init__.py +5 -0
  119. parrot/bots/database/abstract.py +3071 -0
  120. parrot/bots/database/cache.py +286 -0
  121. parrot/bots/database/models.py +468 -0
  122. parrot/bots/database/prompts.py +154 -0
  123. parrot/bots/database/retries.py +98 -0
  124. parrot/bots/database/router.py +269 -0
  125. parrot/bots/database/sql.py +41 -0
  126. parrot/bots/db/__init__.py +6 -0
  127. parrot/bots/db/abstract.py +556 -0
  128. parrot/bots/db/bigquery.py +602 -0
  129. parrot/bots/db/cache.py +85 -0
  130. parrot/bots/db/documentdb.py +668 -0
  131. parrot/bots/db/elastic.py +1014 -0
  132. parrot/bots/db/influx.py +898 -0
  133. parrot/bots/db/mock.py +96 -0
  134. parrot/bots/db/multi.py +783 -0
  135. parrot/bots/db/prompts.py +185 -0
  136. parrot/bots/db/sql.py +1255 -0
  137. parrot/bots/db/tools.py +212 -0
  138. parrot/bots/document.py +680 -0
  139. parrot/bots/hrbot.py +15 -0
  140. parrot/bots/kb.py +170 -0
  141. parrot/bots/mcp.py +36 -0
  142. parrot/bots/orchestration/README.md +463 -0
  143. parrot/bots/orchestration/__init__.py +1 -0
  144. parrot/bots/orchestration/agent.py +155 -0
  145. parrot/bots/orchestration/crew.py +3330 -0
  146. parrot/bots/orchestration/fsm.py +1179 -0
  147. parrot/bots/orchestration/hr.py +434 -0
  148. parrot/bots/orchestration/storage/__init__.py +4 -0
  149. parrot/bots/orchestration/storage/memory.py +100 -0
  150. parrot/bots/orchestration/storage/mixin.py +119 -0
  151. parrot/bots/orchestration/verify.py +202 -0
  152. parrot/bots/product.py +204 -0
  153. parrot/bots/prompts/__init__.py +96 -0
  154. parrot/bots/prompts/agents.py +155 -0
  155. parrot/bots/prompts/data.py +216 -0
  156. parrot/bots/prompts/output_generation.py +8 -0
  157. parrot/bots/scraper/__init__.py +3 -0
  158. parrot/bots/scraper/models.py +122 -0
  159. parrot/bots/scraper/scraper.py +1173 -0
  160. parrot/bots/scraper/templates.py +115 -0
  161. parrot/bots/stores/__init__.py +5 -0
  162. parrot/bots/stores/local.py +172 -0
  163. parrot/bots/webdev.py +81 -0
  164. parrot/cli.py +17 -0
  165. parrot/clients/__init__.py +16 -0
  166. parrot/clients/base.py +1491 -0
  167. parrot/clients/claude.py +1191 -0
  168. parrot/clients/factory.py +129 -0
  169. parrot/clients/google.py +4567 -0
  170. parrot/clients/gpt.py +1975 -0
  171. parrot/clients/grok.py +432 -0
  172. parrot/clients/groq.py +986 -0
  173. parrot/clients/hf.py +582 -0
  174. parrot/clients/models.py +18 -0
  175. parrot/conf.py +395 -0
  176. parrot/embeddings/__init__.py +9 -0
  177. parrot/embeddings/base.py +157 -0
  178. parrot/embeddings/google.py +98 -0
  179. parrot/embeddings/huggingface.py +74 -0
  180. parrot/embeddings/openai.py +84 -0
  181. parrot/embeddings/processor.py +88 -0
  182. parrot/exceptions.c +13868 -0
  183. parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
  184. parrot/exceptions.pxd +22 -0
  185. parrot/exceptions.pxi +15 -0
  186. parrot/exceptions.pyx +44 -0
  187. parrot/generators/__init__.py +29 -0
  188. parrot/generators/base.py +200 -0
  189. parrot/generators/html.py +293 -0
  190. parrot/generators/react.py +205 -0
  191. parrot/generators/streamlit.py +203 -0
  192. parrot/generators/template.py +105 -0
  193. parrot/handlers/__init__.py +4 -0
  194. parrot/handlers/agent.py +861 -0
  195. parrot/handlers/agents/__init__.py +1 -0
  196. parrot/handlers/agents/abstract.py +900 -0
  197. parrot/handlers/bots.py +338 -0
  198. parrot/handlers/chat.py +915 -0
  199. parrot/handlers/creation.sql +192 -0
  200. parrot/handlers/crew/ARCHITECTURE.md +362 -0
  201. parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
  202. parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
  203. parrot/handlers/crew/__init__.py +0 -0
  204. parrot/handlers/crew/handler.py +801 -0
  205. parrot/handlers/crew/models.py +229 -0
  206. parrot/handlers/crew/redis_persistence.py +523 -0
  207. parrot/handlers/jobs/__init__.py +10 -0
  208. parrot/handlers/jobs/job.py +384 -0
  209. parrot/handlers/jobs/mixin.py +627 -0
  210. parrot/handlers/jobs/models.py +115 -0
  211. parrot/handlers/jobs/worker.py +31 -0
  212. parrot/handlers/models.py +596 -0
  213. parrot/handlers/o365_auth.py +105 -0
  214. parrot/handlers/stream.py +337 -0
  215. parrot/interfaces/__init__.py +6 -0
  216. parrot/interfaces/aws.py +143 -0
  217. parrot/interfaces/credentials.py +113 -0
  218. parrot/interfaces/database.py +27 -0
  219. parrot/interfaces/google.py +1123 -0
  220. parrot/interfaces/hierarchy.py +1227 -0
  221. parrot/interfaces/http.py +651 -0
  222. parrot/interfaces/images/__init__.py +0 -0
  223. parrot/interfaces/images/plugins/__init__.py +24 -0
  224. parrot/interfaces/images/plugins/abstract.py +58 -0
  225. parrot/interfaces/images/plugins/analisys.py +148 -0
  226. parrot/interfaces/images/plugins/classify.py +150 -0
  227. parrot/interfaces/images/plugins/classifybase.py +182 -0
  228. parrot/interfaces/images/plugins/detect.py +150 -0
  229. parrot/interfaces/images/plugins/exif.py +1103 -0
  230. parrot/interfaces/images/plugins/hash.py +52 -0
  231. parrot/interfaces/images/plugins/vision.py +104 -0
  232. parrot/interfaces/images/plugins/yolo.py +66 -0
  233. parrot/interfaces/images/plugins/zerodetect.py +197 -0
  234. parrot/interfaces/o365.py +978 -0
  235. parrot/interfaces/onedrive.py +822 -0
  236. parrot/interfaces/sharepoint.py +1435 -0
  237. parrot/interfaces/soap.py +257 -0
  238. parrot/loaders/__init__.py +8 -0
  239. parrot/loaders/abstract.py +1131 -0
  240. parrot/loaders/audio.py +199 -0
  241. parrot/loaders/basepdf.py +53 -0
  242. parrot/loaders/basevideo.py +1568 -0
  243. parrot/loaders/csv.py +409 -0
  244. parrot/loaders/docx.py +116 -0
  245. parrot/loaders/epubloader.py +316 -0
  246. parrot/loaders/excel.py +199 -0
  247. parrot/loaders/factory.py +55 -0
  248. parrot/loaders/files/__init__.py +0 -0
  249. parrot/loaders/files/abstract.py +39 -0
  250. parrot/loaders/files/html.py +26 -0
  251. parrot/loaders/files/text.py +63 -0
  252. parrot/loaders/html.py +152 -0
  253. parrot/loaders/markdown.py +442 -0
  254. parrot/loaders/pdf.py +373 -0
  255. parrot/loaders/pdfmark.py +320 -0
  256. parrot/loaders/pdftables.py +506 -0
  257. parrot/loaders/ppt.py +476 -0
  258. parrot/loaders/qa.py +63 -0
  259. parrot/loaders/splitters/__init__.py +10 -0
  260. parrot/loaders/splitters/base.py +138 -0
  261. parrot/loaders/splitters/md.py +228 -0
  262. parrot/loaders/splitters/token.py +143 -0
  263. parrot/loaders/txt.py +26 -0
  264. parrot/loaders/video.py +89 -0
  265. parrot/loaders/videolocal.py +218 -0
  266. parrot/loaders/videounderstanding.py +377 -0
  267. parrot/loaders/vimeo.py +167 -0
  268. parrot/loaders/web.py +599 -0
  269. parrot/loaders/youtube.py +504 -0
  270. parrot/manager/__init__.py +5 -0
  271. parrot/manager/manager.py +1030 -0
  272. parrot/mcp/__init__.py +28 -0
  273. parrot/mcp/adapter.py +105 -0
  274. parrot/mcp/cli.py +174 -0
  275. parrot/mcp/client.py +119 -0
  276. parrot/mcp/config.py +75 -0
  277. parrot/mcp/integration.py +842 -0
  278. parrot/mcp/oauth.py +933 -0
  279. parrot/mcp/server.py +225 -0
  280. parrot/mcp/transports/__init__.py +3 -0
  281. parrot/mcp/transports/base.py +279 -0
  282. parrot/mcp/transports/grpc_session.py +163 -0
  283. parrot/mcp/transports/http.py +312 -0
  284. parrot/mcp/transports/mcp.proto +108 -0
  285. parrot/mcp/transports/quic.py +1082 -0
  286. parrot/mcp/transports/sse.py +330 -0
  287. parrot/mcp/transports/stdio.py +309 -0
  288. parrot/mcp/transports/unix.py +395 -0
  289. parrot/mcp/transports/websocket.py +547 -0
  290. parrot/memory/__init__.py +16 -0
  291. parrot/memory/abstract.py +209 -0
  292. parrot/memory/agent.py +32 -0
  293. parrot/memory/cache.py +175 -0
  294. parrot/memory/core.py +555 -0
  295. parrot/memory/file.py +153 -0
  296. parrot/memory/mem.py +131 -0
  297. parrot/memory/redis.py +613 -0
  298. parrot/models/__init__.py +46 -0
  299. parrot/models/basic.py +118 -0
  300. parrot/models/compliance.py +208 -0
  301. parrot/models/crew.py +395 -0
  302. parrot/models/detections.py +654 -0
  303. parrot/models/generation.py +85 -0
  304. parrot/models/google.py +223 -0
  305. parrot/models/groq.py +23 -0
  306. parrot/models/openai.py +30 -0
  307. parrot/models/outputs.py +285 -0
  308. parrot/models/responses.py +938 -0
  309. parrot/notifications/__init__.py +743 -0
  310. parrot/openapi/__init__.py +3 -0
  311. parrot/openapi/components.yaml +641 -0
  312. parrot/openapi/config.py +322 -0
  313. parrot/outputs/__init__.py +32 -0
  314. parrot/outputs/formats/__init__.py +108 -0
  315. parrot/outputs/formats/altair.py +359 -0
  316. parrot/outputs/formats/application.py +122 -0
  317. parrot/outputs/formats/base.py +351 -0
  318. parrot/outputs/formats/bokeh.py +356 -0
  319. parrot/outputs/formats/card.py +424 -0
  320. parrot/outputs/formats/chart.py +436 -0
  321. parrot/outputs/formats/d3.py +255 -0
  322. parrot/outputs/formats/echarts.py +310 -0
  323. parrot/outputs/formats/generators/__init__.py +0 -0
  324. parrot/outputs/formats/generators/abstract.py +61 -0
  325. parrot/outputs/formats/generators/panel.py +145 -0
  326. parrot/outputs/formats/generators/streamlit.py +86 -0
  327. parrot/outputs/formats/generators/terminal.py +63 -0
  328. parrot/outputs/formats/holoviews.py +310 -0
  329. parrot/outputs/formats/html.py +147 -0
  330. parrot/outputs/formats/jinja2.py +46 -0
  331. parrot/outputs/formats/json.py +87 -0
  332. parrot/outputs/formats/map.py +933 -0
  333. parrot/outputs/formats/markdown.py +172 -0
  334. parrot/outputs/formats/matplotlib.py +237 -0
  335. parrot/outputs/formats/mixins/__init__.py +0 -0
  336. parrot/outputs/formats/mixins/emaps.py +855 -0
  337. parrot/outputs/formats/plotly.py +341 -0
  338. parrot/outputs/formats/seaborn.py +310 -0
  339. parrot/outputs/formats/table.py +397 -0
  340. parrot/outputs/formats/template_report.py +138 -0
  341. parrot/outputs/formats/yaml.py +125 -0
  342. parrot/outputs/formatter.py +152 -0
  343. parrot/outputs/templates/__init__.py +95 -0
  344. parrot/pipelines/__init__.py +0 -0
  345. parrot/pipelines/abstract.py +210 -0
  346. parrot/pipelines/detector.py +124 -0
  347. parrot/pipelines/models.py +90 -0
  348. parrot/pipelines/planogram.py +3002 -0
  349. parrot/pipelines/table.sql +97 -0
  350. parrot/plugins/__init__.py +106 -0
  351. parrot/plugins/importer.py +80 -0
  352. parrot/py.typed +0 -0
  353. parrot/registry/__init__.py +18 -0
  354. parrot/registry/registry.py +594 -0
  355. parrot/scheduler/__init__.py +1189 -0
  356. parrot/scheduler/models.py +60 -0
  357. parrot/security/__init__.py +16 -0
  358. parrot/security/prompt_injection.py +268 -0
  359. parrot/security/security_events.sql +25 -0
  360. parrot/services/__init__.py +1 -0
  361. parrot/services/mcp/__init__.py +8 -0
  362. parrot/services/mcp/config.py +13 -0
  363. parrot/services/mcp/server.py +295 -0
  364. parrot/services/o365_remote_auth.py +235 -0
  365. parrot/stores/__init__.py +7 -0
  366. parrot/stores/abstract.py +352 -0
  367. parrot/stores/arango.py +1090 -0
  368. parrot/stores/bigquery.py +1377 -0
  369. parrot/stores/cache.py +106 -0
  370. parrot/stores/empty.py +10 -0
  371. parrot/stores/faiss_store.py +1157 -0
  372. parrot/stores/kb/__init__.py +9 -0
  373. parrot/stores/kb/abstract.py +68 -0
  374. parrot/stores/kb/cache.py +165 -0
  375. parrot/stores/kb/doc.py +325 -0
  376. parrot/stores/kb/hierarchy.py +346 -0
  377. parrot/stores/kb/local.py +457 -0
  378. parrot/stores/kb/prompt.py +28 -0
  379. parrot/stores/kb/redis.py +659 -0
  380. parrot/stores/kb/store.py +115 -0
  381. parrot/stores/kb/user.py +374 -0
  382. parrot/stores/models.py +59 -0
  383. parrot/stores/pgvector.py +3 -0
  384. parrot/stores/postgres.py +2853 -0
  385. parrot/stores/utils/__init__.py +0 -0
  386. parrot/stores/utils/chunking.py +197 -0
  387. parrot/telemetry/__init__.py +3 -0
  388. parrot/telemetry/mixin.py +111 -0
  389. parrot/template/__init__.py +3 -0
  390. parrot/template/engine.py +259 -0
  391. parrot/tools/__init__.py +23 -0
  392. parrot/tools/abstract.py +644 -0
  393. parrot/tools/agent.py +363 -0
  394. parrot/tools/arangodbsearch.py +537 -0
  395. parrot/tools/arxiv_tool.py +188 -0
  396. parrot/tools/calculator/__init__.py +3 -0
  397. parrot/tools/calculator/operations/__init__.py +38 -0
  398. parrot/tools/calculator/operations/calculus.py +80 -0
  399. parrot/tools/calculator/operations/statistics.py +76 -0
  400. parrot/tools/calculator/tool.py +150 -0
  401. parrot/tools/cloudwatch.py +988 -0
  402. parrot/tools/codeinterpreter/__init__.py +127 -0
  403. parrot/tools/codeinterpreter/executor.py +371 -0
  404. parrot/tools/codeinterpreter/internals.py +473 -0
  405. parrot/tools/codeinterpreter/models.py +643 -0
  406. parrot/tools/codeinterpreter/prompts.py +224 -0
  407. parrot/tools/codeinterpreter/tool.py +664 -0
  408. parrot/tools/company_info/__init__.py +6 -0
  409. parrot/tools/company_info/tool.py +1138 -0
  410. parrot/tools/correlationanalysis.py +437 -0
  411. parrot/tools/database/abstract.py +286 -0
  412. parrot/tools/database/bq.py +115 -0
  413. parrot/tools/database/cache.py +284 -0
  414. parrot/tools/database/models.py +95 -0
  415. parrot/tools/database/pg.py +343 -0
  416. parrot/tools/databasequery.py +1159 -0
  417. parrot/tools/db.py +1800 -0
  418. parrot/tools/ddgo.py +370 -0
  419. parrot/tools/decorators.py +271 -0
  420. parrot/tools/dftohtml.py +282 -0
  421. parrot/tools/document.py +549 -0
  422. parrot/tools/ecs.py +819 -0
  423. parrot/tools/edareport.py +368 -0
  424. parrot/tools/elasticsearch.py +1049 -0
  425. parrot/tools/employees.py +462 -0
  426. parrot/tools/epson/__init__.py +96 -0
  427. parrot/tools/excel.py +683 -0
  428. parrot/tools/file/__init__.py +13 -0
  429. parrot/tools/file/abstract.py +76 -0
  430. parrot/tools/file/gcs.py +378 -0
  431. parrot/tools/file/local.py +284 -0
  432. parrot/tools/file/s3.py +511 -0
  433. parrot/tools/file/tmp.py +309 -0
  434. parrot/tools/file/tool.py +501 -0
  435. parrot/tools/file_reader.py +129 -0
  436. parrot/tools/flowtask/__init__.py +19 -0
  437. parrot/tools/flowtask/tool.py +761 -0
  438. parrot/tools/gittoolkit.py +508 -0
  439. parrot/tools/google/__init__.py +18 -0
  440. parrot/tools/google/base.py +169 -0
  441. parrot/tools/google/tools.py +1251 -0
  442. parrot/tools/googlelocation.py +5 -0
  443. parrot/tools/googleroutes.py +5 -0
  444. parrot/tools/googlesearch.py +5 -0
  445. parrot/tools/googlesitesearch.py +5 -0
  446. parrot/tools/googlevoice.py +2 -0
  447. parrot/tools/gvoice.py +695 -0
  448. parrot/tools/ibisworld/README.md +225 -0
  449. parrot/tools/ibisworld/__init__.py +11 -0
  450. parrot/tools/ibisworld/tool.py +366 -0
  451. parrot/tools/jiratoolkit.py +1718 -0
  452. parrot/tools/manager.py +1098 -0
  453. parrot/tools/math.py +152 -0
  454. parrot/tools/metadata.py +476 -0
  455. parrot/tools/msteams.py +1621 -0
  456. parrot/tools/msword.py +635 -0
  457. parrot/tools/multidb.py +580 -0
  458. parrot/tools/multistoresearch.py +369 -0
  459. parrot/tools/networkninja.py +167 -0
  460. parrot/tools/nextstop/__init__.py +4 -0
  461. parrot/tools/nextstop/base.py +286 -0
  462. parrot/tools/nextstop/employee.py +733 -0
  463. parrot/tools/nextstop/store.py +462 -0
  464. parrot/tools/notification.py +435 -0
  465. parrot/tools/o365/__init__.py +42 -0
  466. parrot/tools/o365/base.py +295 -0
  467. parrot/tools/o365/bundle.py +522 -0
  468. parrot/tools/o365/events.py +554 -0
  469. parrot/tools/o365/mail.py +992 -0
  470. parrot/tools/o365/onedrive.py +497 -0
  471. parrot/tools/o365/sharepoint.py +641 -0
  472. parrot/tools/openapi_toolkit.py +904 -0
  473. parrot/tools/openweather.py +527 -0
  474. parrot/tools/pdfprint.py +1001 -0
  475. parrot/tools/powerbi.py +518 -0
  476. parrot/tools/powerpoint.py +1113 -0
  477. parrot/tools/pricestool.py +146 -0
  478. parrot/tools/products/__init__.py +246 -0
  479. parrot/tools/prophet_tool.py +171 -0
  480. parrot/tools/pythonpandas.py +630 -0
  481. parrot/tools/pythonrepl.py +910 -0
  482. parrot/tools/qsource.py +436 -0
  483. parrot/tools/querytoolkit.py +395 -0
  484. parrot/tools/quickeda.py +827 -0
  485. parrot/tools/resttool.py +553 -0
  486. parrot/tools/retail/__init__.py +0 -0
  487. parrot/tools/retail/bby.py +528 -0
  488. parrot/tools/sandboxtool.py +703 -0
  489. parrot/tools/sassie/__init__.py +352 -0
  490. parrot/tools/scraping/__init__.py +7 -0
  491. parrot/tools/scraping/docs/select.md +466 -0
  492. parrot/tools/scraping/documentation.md +1278 -0
  493. parrot/tools/scraping/driver.py +436 -0
  494. parrot/tools/scraping/models.py +576 -0
  495. parrot/tools/scraping/options.py +85 -0
  496. parrot/tools/scraping/orchestrator.py +517 -0
  497. parrot/tools/scraping/readme.md +740 -0
  498. parrot/tools/scraping/tool.py +3115 -0
  499. parrot/tools/seasonaldetection.py +642 -0
  500. parrot/tools/shell_tool/__init__.py +5 -0
  501. parrot/tools/shell_tool/actions.py +408 -0
  502. parrot/tools/shell_tool/engine.py +155 -0
  503. parrot/tools/shell_tool/models.py +322 -0
  504. parrot/tools/shell_tool/tool.py +442 -0
  505. parrot/tools/site_search.py +214 -0
  506. parrot/tools/textfile.py +418 -0
  507. parrot/tools/think.py +378 -0
  508. parrot/tools/toolkit.py +298 -0
  509. parrot/tools/webapp_tool.py +187 -0
  510. parrot/tools/whatif.py +1279 -0
  511. parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
  512. parrot/tools/workday/__init__.py +6 -0
  513. parrot/tools/workday/models.py +1389 -0
  514. parrot/tools/workday/tool.py +1293 -0
  515. parrot/tools/yfinance_tool.py +306 -0
  516. parrot/tools/zipcode.py +217 -0
  517. parrot/utils/__init__.py +2 -0
  518. parrot/utils/helpers.py +73 -0
  519. parrot/utils/parsers/__init__.py +5 -0
  520. parrot/utils/parsers/toml.c +12078 -0
  521. parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
  522. parrot/utils/parsers/toml.pyx +21 -0
  523. parrot/utils/toml.py +11 -0
  524. parrot/utils/types.cpp +20936 -0
  525. parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
  526. parrot/utils/types.pyx +213 -0
  527. parrot/utils/uv.py +11 -0
  528. parrot/version.py +10 -0
  529. parrot/yaml-rs/Cargo.lock +350 -0
  530. parrot/yaml-rs/Cargo.toml +19 -0
  531. parrot/yaml-rs/pyproject.toml +19 -0
  532. parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
  533. parrot/yaml-rs/src/lib.rs +222 -0
  534. requirements/docker-compose.yml +24 -0
  535. requirements/requirements-dev.txt +21 -0
@@ -0,0 +1,3002 @@
1
+ """
2
+ 3-Step Planogram Compliance Pipeline
3
+ Step 1: Object Detection (YOLO/ResNet)
4
+ Step 2: LLM Object Identification with Reference Images
5
+ Step 3: Planogram Comparison and Compliance Verification
6
+ """
7
+ import asyncio
8
+ import os
9
+ from typing import List, Dict, Any, Optional, Union, Tuple
10
+ from collections import defaultdict, Counter
11
+ from datetime import datetime
12
+ import unicodedata
13
+ import re
14
+ import traceback
15
+ from pathlib import Path
16
+ import math
17
+ import pytesseract
18
+ from PIL import (
19
+ Image,
20
+ ImageDraw,
21
+ ImageFont,
22
+ ImageEnhance,
23
+ ImageOps
24
+ )
25
+ import numpy as np
26
+ from pydantic import BaseModel, Field
27
+ import cv2
28
+ import torch
29
+ from google.genai.errors import ServerError
30
+ from .abstract import AbstractPipeline
31
+ from ..models.detections import (
32
+ DetectionBox,
33
+ Detection,
34
+ Detections,
35
+ ShelfRegion,
36
+ IdentifiedProduct,
37
+ PlanogramDescription
38
+ )
39
+ from ..models.compliance import (
40
+ ComplianceResult,
41
+ ComplianceStatus,
42
+ TextComplianceResult,
43
+ TextMatcher,
44
+ BrandComplianceResult
45
+ )
46
+ from .detector import AbstractDetector
47
+ from .models import PlanogramConfig
48
+
49
+
50
# Class-ID map for Phase-1 proposal labels.
# IDs 100-106 identify candidate object classes emitted by the detector;
# 190 marks a shelf band region (see the `shelves` dict built in detect()).
CID = {
    "promotional_candidate": 103,
    "product_candidate": 100,
    "box_candidate": 101,
    "price_tag": 102,
    "shelf_region": 190,
    "brand_logo": 105,
    "poster_text": 106,
}
59
+
60
+ class RetailDetector(AbstractDetector):
61
+ """
62
+ Reference-guided Phase-1 detector.
63
+
64
+ 1) Enhance image (contrast/brightness) to help OCR/YOLO/CLIP.
65
+ 2) Localize the promotional poster using:
66
+ - OCR ('EPSON', 'Hello', 'Savings', etc.)
67
+ - CLIP similarity with your FIRST reference image.
68
+ 3) Crop to poster width (+ margin) to form an endcap ROI (remember offsets).
69
+ 4) Detect shelf lines within ROI (Hough) => top/middle/bottom bands.
70
+ 5) YOLO proposals inside ROI (low conf, class-agnostic).
71
+ 6) For each proposal: OCR + CLIP vs remaining reference images
72
+ => label as promotional/product/box candidate.
73
+ 7) Shrink, merge, suppress items that are inside the poster.
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ yolo_model: str = "yolo12l.pt",
79
+ conf: float = 0.15,
80
+ iou: float = 0.5,
81
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
82
+ reference_images: Optional[List[str]] = None, # first is the poster
83
+ **kwargs
84
+ ):
85
+ super().__init__(
86
+ yolo_model=yolo_model,
87
+ conf=conf,
88
+ iou=iou,
89
+ device=device,
90
+ **kwargs
91
+ )
92
+ # Shelf split defaults: header/middle/bottom
93
+ self.shelf_split = (0.40, 0.25, 0.35) # sums to ~1.0
94
+ # Useful elsewhere (price tag guardrails, etc.)
95
+ self.label_strip_ratio = 0.06
96
+ self.ref_paths = reference_images or []
97
+ self.ref_ad = self.ref_paths[0] if self.ref_paths else None
98
+ self.ref_products = self.ref_paths[1:] if len(self.ref_paths) > 1 else []
99
+ self.ref_ad_feat = self._embed_image(self.ref_ad) if self.ref_ad else None
100
+ self.ref_prod_feats = [
101
+ self._embed_image(p) for p in self.ref_products
102
+ ] if self.ref_products else []
103
+
104
+ # -------------------------- Main Detection Entry ---------------------------------
105
+ async def detect(
106
+ self,
107
+ image: Image.Image,
108
+ image_array: np.array,
109
+ endcap: Detection,
110
+ ad: Detection,
111
+ planogram: Optional[PlanogramDescription] = None,
112
+ debug_yolo: Optional[str] = None,
113
+ debug_phase1: Optional[str] = None,
114
+ debug_phases: Optional[str] = None,
115
+ ):
116
+ h, w = image_array.shape[:2]
117
+ # text prompts (backup if no product refs)
118
+ text = [f"a photo of a {t}" for t in planogram.text_tokens if t]
119
+ if not text:
120
+ text = [
121
+ "a photo of a retail promotional poster lightbox",
122
+ "a photo of a product box",
123
+ "a photo of a product cartridge bottle",
124
+ "a photo of a price tag"
125
+ ]
126
+ self.text_tokens = self.proc(
127
+ text=text,
128
+ return_tensors="pt",
129
+ padding=True
130
+ ).to(self.device)
131
+ with torch.no_grad():
132
+ self.text_feats = self.clip.get_text_features(**self.text_tokens)
133
+ self.text_feats = self.text_feats / self.text_feats.norm(dim=-1, keepdim=True)
134
+
135
+ # Check if detections are valid before proceeding
136
+ if not endcap or not ad:
137
+ print("ERROR: Failed to get required detections.")
138
+ return # or raise an exception
139
+
140
+ # 2) endcap ROI
141
+ roi_box = endcap.bbox.get_pixel_coordinates(width=w, height=h)
142
+ ad_box = ad.bbox.get_pixel_coordinates(width=w, height=h)
143
+
144
+ # Unpack the Pixel coordinates
145
+ rx1, ry1, rx2, ry2 = roi_box
146
+
147
+ roi = image_array[ry1:ry2, rx1:rx2]
148
+
149
+ # 4) YOLO inside ROI
150
+ yolo_props = self._yolo_props(roi, rx1, ry1)
151
+
152
+ # Extract planogram config for shelf layout
153
+ planogram_config = None
154
+ if planogram:
155
+ planogram_config = {
156
+ 'shelves': [
157
+ {
158
+ 'level': shelf.level,
159
+ 'height_ratio': getattr(shelf, 'height_ratio', None),
160
+ 'products': [
161
+ {
162
+ 'name': product.name,
163
+ 'product_type': product.product_type
164
+ } for product in shelf.products
165
+ ]
166
+ } for shelf in planogram.shelves
167
+ ]
168
+ }
169
+
170
+ # 3) shelves
171
+ shelf_lines, bands = self._find_shelves(
172
+ roi_box=roi_box,
173
+ ad_box=ad_box,
174
+ w=w,
175
+ h=h,
176
+ planogram_config=planogram_config
177
+ )
178
+ # header_limit_y = min(v[0] for v in bands.values()) if bands else int(0.4 * h)
179
+ # classification fallback limit = header bottom (or 40% of ROI height)
180
+ if bands and "header" in bands:
181
+ header_limit_y = bands["header"][1]
182
+ else:
183
+ roi_h = max(1, ry2 - ry1)
184
+ header_limit_y = ry1 + int(0.4 * roi_h)
185
+
186
+ if debug_yolo:
187
+ dbg = self._draw_phase_areas(image_array.copy(), yolo_props, roi_box)
188
+ if debug_phases:
189
+ cv2.imwrite(
190
+ debug_phases,
191
+ cv2.cvtColor(dbg, cv2.COLOR_RGB2BGR)
192
+ )
193
+ dbg = self._draw_yolo(image_array.copy(), yolo_props, roi_box, shelf_lines)
194
+ cv2.imwrite(
195
+ debug_yolo,
196
+ cv2.cvtColor(dbg, cv2.COLOR_RGB2BGR)
197
+ )
198
+
199
+ # 5) classify YOLO → proposals (works w/ bands={}, header_limit_y above)
200
+ proposals = await self._classify_proposals(
201
+ image_array,
202
+ yolo_props,
203
+ bands,
204
+ header_limit_y,
205
+ ad_box
206
+ )
207
+ # 6) shrink -> merge -> remove those fully inside the poster
208
+ proposals = self._merge(proposals, iou_same=0.45)
209
+
210
+ # shelves dict to satisfy callers; in flat mode keep it empty
211
+ shelves = {
212
+ name: DetectionBox(
213
+ x1=rx1, y1=y1, x2=rx2, y2=y2,
214
+ confidence=1.0,
215
+ class_id=190, class_name="shelf_region",
216
+ area=(rx2-rx1)*(y2-y1),
217
+ )
218
+ for name, (y1, y2) in bands.items()
219
+ }
220
+
221
+ # (OPTIONAL) draw Phase-1 debug
222
+ if debug_phase1:
223
+ dbg = self._draw_phase1(
224
+ image_array.copy(),
225
+ roi_box,
226
+ shelf_lines,
227
+ proposals,
228
+ ad_box
229
+ )
230
+ cv2.imwrite(
231
+ debug_phase1,
232
+ cv2.cvtColor(dbg, cv2.COLOR_RGB2BGR)
233
+ )
234
+
235
+ # 8) ensure the promo exists exactly once
236
+ if ad_box is not None and not any(d.class_name == "promotional_candidate" and self._iou_box_tuple(d, ad_box) > 0.7 for d in proposals):
237
+ x1, y1, x2, y2 = ad_box
238
+ proposals.append(
239
+ DetectionBox(
240
+ x1=x1, y1=y1, x2=x2, y2=y2,
241
+ confidence=0.95,
242
+ class_id=103,
243
+ class_name="promotional_candidate",
244
+ area=(x2-x1)*(y2-y1)
245
+ )
246
+ )
247
+
248
+ return {"shelves": shelves, "proposals": proposals}
249
+
250
+ # --------------------------- shelves -------------------------------------
251
+ def _find_shelves(
252
+ self,
253
+ roi_box: tuple[int, int, int, int],
254
+ ad_box: tuple[int, int, int, int],
255
+ h: int,
256
+ w: int,
257
+ planogram_config: dict = None
258
+ ) -> tuple[List[int], dict]:
259
+ """
260
+ Detects shelf bands based on planogram configuration, prioritizing the
261
+ dynamically detected ad_box for the header.
262
+ """
263
+ rx1, ry1, rx2, ry2 = map(int, roi_box)
264
+ _, ad_y1, _, ad_y2 = map(int, ad_box)
265
+ roi_h = max(1, ry2 - ry1)
266
+
267
+ # Fallback to the old proportional method if no planogram is provided
268
+ if not planogram_config or 'shelves' not in planogram_config:
269
+ return self._find_shelves_proportional(roi_box, rx1, ry1, rx2, ry2, h)
270
+
271
+ shelf_configs = planogram_config['shelves']
272
+ if not shelf_configs:
273
+ return [], {}
274
+
275
+ bands = {}
276
+ levels = []
277
+
278
+ # --- 1. Prioritize the Header based on ad_box ---
279
+ # The header starts at the top of the ROI and ends at the bottom of the ad_box
280
+ header_config = next((s for s in shelf_configs if s.get('level') == 'header'), None)
281
+ if header_config:
282
+ # Use the detected ad_box y-coordinates for the header band
283
+ header_top = ad_y1
284
+ header_bottom = ad_y2
285
+ bands[header_config['level']] = (header_top, header_bottom)
286
+ current_y = header_bottom
287
+ remaining_configs = [s for s in shelf_configs if s.get('level') != 'header']
288
+ else:
289
+ # If no header is defined, start from the top of the ROI
290
+ current_y = ry1
291
+ remaining_configs = shelf_configs
292
+
293
+ # --- 2. Calculate space for remaining shelves ---
294
+ remaining_roi_h = max(1, ry2 - current_y)
295
+
296
+ # Calculate space consumed by shelves with a fixed height_ratio
297
+ height_from_ratios = 0
298
+ shelves_without_ratio = []
299
+ for shelf_config in remaining_configs:
300
+ if 'height_ratio' in shelf_config and shelf_config['height_ratio'] is not None:
301
+ # height_ratio is a percentage of the TOTAL ROI height
302
+ height_from_ratios += int(shelf_config['height_ratio'] * roi_h)
303
+ else:
304
+ shelves_without_ratio.append(shelf_config)
305
+
306
+ # Calculate height for each shelf without a specified ratio
307
+ auto_size_h = max(0, remaining_roi_h - height_from_ratios)
308
+ auto_shelf_height = int(auto_size_h / len(shelves_without_ratio)) if shelves_without_ratio else 0
309
+
310
+ # --- 3. Build the bands for the remaining shelves ---
311
+ for i, shelf_config in enumerate(remaining_configs):
312
+ shelf_level = shelf_config['level']
313
+
314
+ if 'height_ratio' in shelf_config and shelf_config['height_ratio'] is not None:
315
+ shelf_pixel_height = int(shelf_config['height_ratio'] * roi_h)
316
+ else:
317
+ shelf_pixel_height = auto_shelf_height
318
+
319
+ shelf_bottom = current_y + shelf_pixel_height
320
+
321
+ # For the very last shelf, ensure it extends to the bottom of the ROI
322
+ if i == len(remaining_configs) - 1:
323
+ shelf_bottom = ry2
324
+
325
+ # VALIDATION: Ensure valid bounding box
326
+ if shelf_bottom <= current_y:
327
+ print(
328
+ f"WARNING: Invalid shelf {shelf_level}: y1={current_y}, y2={shelf_bottom}"
329
+ )
330
+ shelf_bottom = current_y + 50 # Minimum height
331
+
332
+ bands[shelf_level] = (current_y, shelf_bottom)
333
+ current_y = shelf_bottom
334
+
335
+ # --- 4. Create the levels list (separator lines) ---
336
+ # The levels are the bottom coordinate of each shelf band, except for the last one
337
+ if bands:
338
+ # Ensure order from top to bottom based on the planogram config
339
+ ordered_levels = [bands[s['level']][1] for s in shelf_configs if s['level'] in bands]
340
+ levels = ordered_levels[:-1]
341
+
342
+ self.logger.debug(
343
+ f"📊 Planogram Shelves: {len(shelf_configs)} shelves configured, "
344
+ f"ROI height={roi_h}, bands={bands}"
345
+ )
346
+
347
+ return levels, bands
348
+
349
+ def _find_shelves_proportional(self, roi: tuple, rx1, ry1, rx2, ry2, H, planogram_config: dict = None):
350
+ """
351
+ Fallback proportional layout using planogram config or default 3-shelf layout.
352
+ """
353
+ roi_h = max(1, ry2 - ry1)
354
+
355
+ # Use planogram config if available
356
+ if planogram_config and 'shelves' in planogram_config:
357
+ shelf_configs = planogram_config['shelves']
358
+ num_shelves = len(shelf_configs)
359
+
360
+ if num_shelves > 0:
361
+ # Equal division among configured shelves
362
+ shelf_height = roi_h // num_shelves
363
+
364
+ levels = []
365
+ bands = {}
366
+ current_y = ry1
367
+
368
+ for i, shelf_config in enumerate(shelf_configs):
369
+ shelf_level = shelf_config['level']
370
+ shelf_bottom = current_y + shelf_height
371
+
372
+ # For the last shelf, extend to ROI bottom
373
+ if i == len(shelf_configs) - 1:
374
+ shelf_bottom = ry2
375
+
376
+ bands[shelf_level] = (current_y, shelf_bottom)
377
+ if i < len(shelf_configs) - 1: # Don't add last boundary to levels
378
+ levels.append(shelf_bottom)
379
+
380
+ current_y = shelf_bottom
381
+
382
+ return levels, bands
383
+
384
+ # Default fallback: 3-shelf layout if no config
385
+ hdr_r, mid_r, bot_r = 0.40, 0.30, 0.30
386
+
387
+ header_bottom = ry1 + int(hdr_r * roi_h)
388
+ middle_bottom = header_bottom + int(mid_r * roi_h)
389
+
390
+ # Ensure boundaries don't exceed ROI
391
+ header_bottom = max(ry1 + 20, min(header_bottom, ry2 - 40))
392
+ middle_bottom = max(header_bottom + 20, min(middle_bottom, ry2 - 20))
393
+
394
+ levels = [header_bottom, middle_bottom]
395
+ bands = {
396
+ "header": (ry1, header_bottom),
397
+ "middle": (header_bottom, middle_bottom),
398
+ "bottom": (middle_bottom, ry2),
399
+ }
400
+
401
+ return levels, bands
402
+
403
+ # ---------------------------- YOLO ---------------------------------------
404
    def _preprocess_roi_for_detection(self, roi: np.ndarray) -> np.ndarray:
        """
        Ultra-minimal preprocessing - only applies when absolutely necessary.
        Use this version if you want maximum preservation of original image quality.

        Args:
            roi: image crop; 3-channel input is assumed to be BGR (cv2
                convention) — TODO confirm against callers.

        Returns:
            Processed image in the same channel order as the input. On any
            failure the original ``roi`` is returned unchanged (best-effort).
        """
        try:
            # Convert BGR to RGB if needed
            if len(roi.shape) == 3 and roi.shape[2] == 3:
                rgb_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
            else:
                rgb_roi = roi.copy()

            # Quick contrast check: grayscale std-dev as a cheap proxy.
            gray = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2GRAY)
            contrast = gray.std()

            # Only process if contrast is very low
            if contrast > 35:
                # Good contrast - return original with minimal sharpening
                result = rgb_roi.astype(np.float32)

                # Ultra-subtle sharpening (kernel weights sum to 1.0, so mean
                # brightness is preserved)
                kernel = np.array([[0, -0.05, 0],
                                   [-0.05, 1.2, -0.05],
                                   [0, -0.05, 0]])

                # Filter each channel independently.
                for i in range(3):
                    result[:,:,i] = cv2.filter2D(result[:,:,i], -1, kernel)

                result = np.clip(result, 0, 255).astype(np.uint8)
            else:
                # Low contrast - apply gentle CLAHE only (on the LAB lightness
                # channel, leaving chroma untouched)
                lab = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2LAB)
                clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(10,10))
                lab[:,:,0] = clahe.apply(lab[:,:,0])
                result = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

            # Convert back to BGR if needed (mirror of the input conversion)
            if len(roi.shape) == 3 and roi.shape[2] == 3:
                result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)

            return result

        except Exception as e:
            # Best-effort: never let preprocessing break the detection path.
            self.logger.warning(f"Minimal ROI preprocessing failed: {e}")
            return roi
450
+
451
+ def _yolo_props(self, roi: np.ndarray, rx1, ry1, detection_phases: Optional[List[Dict[str, Any]]] = None):
452
+ """
453
+ Multi-phase YOLO detection with configurable confidence levels and weighted scoring.
454
+ Returns proposals in the same format expected by existing _classify_proposals method.
455
+
456
+ Args:
457
+ roi: ROI image array
458
+ rx1, ry1: ROI offset coordinates
459
+ detection_phases: List of phase configurations. If None, uses default 2-phase approach.
460
+ """
461
+ # printer ≈ 5–9%, product_box ≈ 7–12%, promotional_graphic ≥ 20%
462
+ CLASS_LIMITS = {
463
+ # Base retail categories
464
+ "poster": {"min_area": 0.06, "max_area": 0.95, "min_ar": 0.5, "max_ar": 3.5},
465
+ "person": {"min_area": 0.02, "max_area": 0.60, "min_ar": 0.3, "max_ar": 3.5},
466
+ "printer": {"min_area": 0.010, "max_area": 0.28, "min_ar": 0.6, "max_ar": 2.8},
467
+ "product_box": {"min_area": 0.003, "max_area": 0.20, "min_ar": 0.4, "max_ar": 3.2},
468
+ "price_tag": {"min_area": 0.0006,"max_area": 0.010,"min_ar": 1.6, "max_ar": 8.0},
469
+
470
+ # YOLO classes mapped to retail categories with their own limits
471
+ "tv": {"min_area": 0.06, "max_area": 0.95, "min_ar": 0.5, "max_ar": 3.5}, # → poster
472
+ "monitor": {"min_area": 0.06, "max_area": 0.95, "min_ar": 0.5, "max_ar": 3.5}, # → poster
473
+ "laptop": {"min_area": 0.06, "max_area": 0.95, "min_ar": 0.5, "max_ar": 3.5}, # → poster
474
+ "microwave": {"min_area": 0.010, "max_area": 0.28, "min_ar": 0.6, "max_ar": 2.8}, # → printer
475
+ "book": {"min_area": 0.003, "max_area": 0.20, "min_ar": 0.4, "max_ar": 3.2}, # → product_box
476
+ "box": {"min_area": 0.003, "max_area": 0.20, "min_ar": 0.4, "max_ar": 3.2}, # → product_box
477
+ "suitcase": {"min_area": 0.003, "max_area": 0.20, "min_ar": 0.4, "max_ar": 3.2}, # → product_box
478
+ "bottle": {"min_area": 0.0006,"max_area": 0.010,"min_ar": 1.6, "max_ar": 8.0}, # → price_tag
479
+ "clock": {"min_area": 0.0006,"max_area": 0.010,"min_ar": 1.6, "max_ar": 8.0}, # → price_tag
480
+ "mouse": {"min_area": 0.0006,"max_area": 0.010,"min_ar": 1.6, "max_ar": 8.0}, # → price_tag
481
+ "remote": {"min_area": 0.0006,"max_area": 0.010,"min_ar": 1.6, "max_ar": 8.0}, # → price_tag
482
+ "cell phone": {"min_area": 0.0006,"max_area": 0.010,"min_ar": 1.6, "max_ar": 8.0}, # → price_tag
483
+ }
484
+
485
+ # Mapping from YOLO classes to retail categories
486
+ YOLO_TO_RETAIL = {
487
+ "tv": "poster",
488
+ "monitor": "poster",
489
+ "laptop": "poster",
490
+ "microwave": "printer",
491
+ "keyboard": "product_box",
492
+ "book": "product_box",
493
+ "box": "product_box",
494
+ "suitcase": "product_box",
495
+ "bottle": "price_tag",
496
+ "clock": "price_tag",
497
+ "mouse": "price_tag",
498
+ "remote": "price_tag",
499
+ "cell phone": "price_tag",
500
+ }
501
+
502
+ def _get_class_limits(yolo_class: str) -> Optional[Dict[str, float]]:
503
+ """Get class limits for a YOLO class"""
504
+ return CLASS_LIMITS.get(yolo_class, None)
505
+
506
+ def _get_retail_category(yolo_class: str) -> str:
507
+ """Map YOLO class to retail category"""
508
+ return YOLO_TO_RETAIL.get(yolo_class, yolo_class)
509
+
510
+ def _passes_class_limits(yolo_class: str, area_ratio: float, aspect_ratio: float) -> tuple[bool, str]:
511
+ """Check if detection passes class-specific limits"""
512
+ limits = _get_class_limits(yolo_class)
513
+ if not limits:
514
+ # Use generic fallback limits if no class-specific ones
515
+ generic_ok = (0.0008 <= area_ratio <= 0.9 and 0.1 <= aspect_ratio <= 10.0)
516
+ return generic_ok, "generic_limits"
517
+
518
+ area_ok = limits["min_area"] <= area_ratio <= limits["max_area"]
519
+ ar_ok = limits["min_ar"] <= aspect_ratio <= limits["max_ar"]
520
+
521
+ if area_ok and ar_ok:
522
+ retail_category = _get_retail_category(yolo_class)
523
+ return True, f"class_limits_{yolo_class}→{retail_category}"
524
+ else:
525
+ # Provide specific failure reason for debugging
526
+ reasons = []
527
+ if not area_ok:
528
+ reasons.append(
529
+ f"area={area_ratio:.4f} not in [{limits['min_area']:.4f}, {limits['max_area']:.4f}]"
530
+ )
531
+ if not ar_ok:
532
+ reasons.append(
533
+ f"ar={aspect_ratio:.2f} not in [{limits['min_ar']:.2f}, {limits['max_ar']:.2f}]"
534
+ )
535
+ return False, f"failed_{yolo_class}: {'; '.join(reasons)}"
536
+
537
+ # Preprocess ROI to enhance detection of similar-colored objects
538
+ enhanced_roi = self._preprocess_roi_for_detection(roi)
539
+
540
+ if detection_phases is None:
541
+ detection_phases = [
542
+ { # Coarse: quickly find large boxes (e.g., header, promo)
543
+ "name": "coarse",
544
+ "conf": 0.35,
545
+ "iou": 0.35,
546
+ "weight": 0.20,
547
+ "min_area": 0.05, # >= 5% of ROI
548
+ "description": "High confidence pass for large objects",
549
+ },
550
+ # Standard: main workhorse for printers & boxes
551
+ {
552
+ "name": "standard",
553
+ "conf": 0.05,
554
+ "iou": 0.20,
555
+ "weight": 0.70,
556
+ "min_area": 0.001,
557
+ "description": "High confidence pass for clear objects"
558
+ },
559
+ # Aggressive: recover misses but still bounded by class limits
560
+ {
561
+ "name": "aggressive",
562
+ "conf": 0.008,
563
+ "iou": 0.15,
564
+ "weight": 0.10,
565
+ "min_area": 0.0006,
566
+ "description": "Selective aggressive pass for missed objects only"
567
+ },
568
+ ]
569
+
570
+ try:
571
+ H, W = roi.shape[:2]
572
+ roi_area = H * W
573
+ all_proposals = []
574
+
575
+ print(f"\n🔄 Detection with Your Preferred Settings on ROI {W}x{H}")
576
+ print(" " + "="*70)
577
+
578
+ # Statistics tracking
579
+ stats = {
580
+ "total_detections": 0,
581
+ "passed_confidence": 0,
582
+ "passed_size": 0,
583
+ "passed_class_limits": 0,
584
+ "rejected_class_limits": 0
585
+ }
586
+
587
+ # Run both phases with your settings
588
+ for phase_idx, phase in enumerate(detection_phases):
589
+ phase_name = phase["name"]
590
+ conf_thresh = phase["conf"]
591
+ iou_thresh = phase["iou"]
592
+ weight = phase["weight"]
593
+
594
+ print(
595
+ f"\n📡 Phase {phase_idx + 1}: {phase_name}"
596
+ )
597
+ print(
598
+ f" Config: conf={conf_thresh}, iou={iou_thresh}, weight={weight}"
599
+ )
600
+
601
+ r = self.yolo(enhanced_roi, conf=conf_thresh, iou=iou_thresh, verbose=False)[0]
602
+
603
+ if not hasattr(r, 'boxes') or r.boxes is None:
604
+ print(f" 📊 No boxes detected in {phase_name}")
605
+ continue
606
+
607
+ xyxy = r.boxes.xyxy.cpu().numpy()
608
+ confs = r.boxes.conf.cpu().numpy()
609
+ classes = r.boxes.cls.cpu().numpy().astype(int)
610
+ names = r.names
611
+
612
+ print(
613
+ f" 📊 Raw YOLO output: {len(xyxy)} detections"
614
+ )
615
+
616
+ phase_count = 0
617
+ phase_rejected = 0
618
+
619
+ for _, ((x1, y1, x2, y2), conf, cls_id) in enumerate(zip(xyxy, confs, classes)):
620
+ gx1, gy1, gx2, gy2 = int(x1) + rx1, int(y1) + ry1, int(x2) + rx1, int(y2) + ry1
621
+
622
+ width, height = x2 - x1, y2 - y1
623
+ if width <= 0 or height <= 0 or width < 8 or height < 8:
624
+ continue
625
+
626
+ if conf < conf_thresh:
627
+ continue
628
+
629
+ stats["passed_confidence"] += 1
630
+
631
+ area = width * height
632
+ area_ratio = area / roi_area
633
+ aspect_ratio = width / max(height, 1)
634
+ yolo_class = names[cls_id]
635
+
636
+ min_area = phase.get("min_area")
637
+ if min_area and area_ratio < float(min_area):
638
+ continue
639
+
640
+ stats["passed_size"] += 1
641
+
642
+ # Apply class-specific limits
643
+ limits_passed, limit_reason = _passes_class_limits(yolo_class, area_ratio, aspect_ratio)
644
+
645
+ if not limits_passed:
646
+ phase_rejected += 1
647
+ stats["rejected_class_limits"] += 1
648
+ if phase_rejected <= 3: # Log first few rejections for debugging
649
+ print(f" ❌ Rejected {yolo_class}: {limit_reason}")
650
+ continue
651
+
652
+ ocr_text = None
653
+ orientation = self._detect_orientation(gx1, gy1, gx2, gy2)
654
+ if (area_ratio >= 0.0008 and area_ratio <= 0.9):
655
+ # Only run OCR on boxes with an area > 5% of the ROI
656
+ if area_ratio > 0.05:
657
+ try:
658
+ # Crop the specific proposal from the ROI image
659
+ # Use local coordinates (x1, y1, x2, y2) for this
660
+ proposal_img_crop = roi[int(y1):int(y2), int(x1):int(x2)]
661
+
662
+ # --- ROTATION LOGIC for VERTICAL BOXES ---
663
+ if orientation == 'vertical':
664
+ # Rotate the crop 90 degrees counter-clockwise to make text horizontal
665
+ proposal_img_crop = cv2.rotate(
666
+ proposal_img_crop,
667
+ cv2.ROTATE_90_CLOCKWISE
668
+ )
669
+ text = pytesseract.image_to_string(
670
+ proposal_img_crop,
671
+ # config='--psm 6'
672
+ config="--psm 6 -l eng"
673
+ )
674
+ proposal_img_crop = cv2.rotate(
675
+ proposal_img_crop,
676
+ cv2.ROTATE_90_COUNTERCLOCKWISE
677
+ )
678
+ vtext = pytesseract.image_to_string(
679
+ proposal_img_crop,
680
+ # config='--psm 6'
681
+ config="--psm 6 -l eng"
682
+ )
683
+ raw_text = text + ' | ' + vtext
684
+ else:
685
+ # Run Tesseract on the crop
686
+ raw_text = pytesseract.image_to_string(
687
+ proposal_img_crop,
688
+ # config='--psm 6'
689
+ config="--psm 6 -l eng"
690
+ )
691
+
692
+ # Clean up the text
693
+ ocr_text = " ".join(raw_text.strip().split())
694
+ except Exception as ocr_error:
695
+ self.logger.warning(
696
+ f"OCR failed for a proposal: {ocr_error}"
697
+ )
698
+
699
+ orientation = self._detect_orientation(gx1, gy1, gx2, gy2)
700
+ weighted_conf = float(conf) * weight
701
+ proposal = {
702
+ "yolo_label": yolo_class,
703
+ "yolo_conf": float(conf),
704
+ "weighted_conf": weighted_conf,
705
+ "box": (gx1, gy1, gx2, gy2),
706
+ "area_ratio": area_ratio,
707
+ "aspect_ratio": aspect_ratio,
708
+ "orientation": orientation,
709
+ "retail_candidates": self._get_retail_candidates(yolo_class),
710
+ "raw_index": len(all_proposals) + 1,
711
+ "ocr_text": ocr_text,
712
+ "phase": phase_name
713
+ }
714
+ # print('PROPOSAL > ', proposal)
715
+ all_proposals.append(proposal)
716
+ stats["total_detections"] += 1
717
+ phase_count += 1
718
+
719
+ print(f" ✅ Kept {phase_count} detections from {phase_name}")
720
+
721
+ # Light deduplication (let classification handle quality control)
722
+ deduplicated = self._object_deduplication(all_proposals)
723
+
724
+ print(f"\n📊 Detection Summary: {len(deduplicated)} total proposals")
725
+ print(" Focus: Let classification phase handle object type distinction")
726
+
727
+ # Print final statistics
728
+ print(f"\n📊 Detection Summary:")
729
+ print(f" Total YOLO detections: {stats['total_detections']}")
730
+ print(f" Passed confidence: {stats['passed_confidence']}")
731
+ print(f" Passed basic size: {stats['passed_size']}")
732
+ print(f" Passed class limits: {stats['passed_class_limits']}")
733
+ print(f" Rejected by class limits: {stats['rejected_class_limits']}")
734
+ print(f" Final after deduplication: {len(deduplicated)}")
735
+ return deduplicated
736
+
737
+ except Exception as e:
738
+ print(f"Detection failed: {e}")
739
+ traceback.print_exc()
740
+ return []
741
+
742
+ def _determine_shelf_level(self, center_y: float, bands: Dict[str, tuple]) -> str:
743
+ """Enhanced shelf level determination"""
744
+ if not bands:
745
+ return "unknown"
746
+
747
+ for level, (y1, y2) in bands.items():
748
+ if y1 <= center_y <= y2:
749
+ return level
750
+
751
+ # If not in any band, find closest
752
+ min_distance = float('inf')
753
+ closest_level = "unknown"
754
+ for level, (y1, y2) in bands.items():
755
+ band_center = (y1 + y2) / 2
756
+ distance = abs(center_y - band_center)
757
+ if distance < min_distance:
758
+ min_distance = distance
759
+ closest_level = level
760
+
761
+ return closest_level
762
+
763
+ def _detect_orientation(self, x1: int, y1: int, x2: int, y2: int) -> str:
764
+ """Detect orientation from bounding box dimensions"""
765
+ width = x2 - x1
766
+ height = y2 - y1
767
+ aspect_ratio = width / max(height, 1)
768
+
769
+ if aspect_ratio < 0.8:
770
+ return "vertical"
771
+ elif aspect_ratio > 1.5:
772
+ return "horizontal"
773
+ else:
774
+ return "square"
775
+
776
+ def _get_retail_candidates(self, yolo_class: str) -> List[str]:
777
+ """Light retail candidate mapping - let classification do the heavy work"""
778
+ mapping = {
779
+ "microwave": ["printer", "product_box"],
780
+ "tv": ["promotional_graphic", "tv"],
781
+ "television": ["tv"],
782
+ "monitor": ["promotional_graphic"],
783
+ "laptop": ["promotional_graphic"],
784
+ "book": ["product_box"],
785
+ "box": ["product_box"],
786
+ "suitcase": ["product_box", "printer"],
787
+ "bottle": ["ink_bottle", "price_tag"],
788
+ "person": ["promotional_graphic"],
789
+ "clock": ["small_object", "price_tag"],
790
+ "cell phone": ["small_object", "price_tag"],
791
+ }
792
+ return mapping.get(yolo_class, ["product_candidate"])
793
+
794
    def _object_deduplication(self, all_detections: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Enhanced deduplication with container/contained logic and better IoU thresholds

        Detections are processed highest ``weighted_conf`` first. A detection
        is dropped when it overlaps an already-kept one with IoU > 0.5, or is
        >80% contained inside a kept box more than 3x its size; conversely a
        kept box >80% contained inside a 3x-larger incoming detection is
        removed in its favor.

        Args:
            all_detections: proposal dicts carrying "box" (x1, y1, x2, y2)
                and "weighted_conf". Boxes are assumed non-degenerate
                (area > 0); upstream filtering enforces width/height >= 8.

        Returns:
            Filtered list of the original dict objects, highest confidence first.
        """
        if not all_detections:
            return []

        # Sort by weighted confidence (highest first)
        sorted_detections = sorted(all_detections, key=lambda x: x["weighted_conf"], reverse=True)

        deduplicated = []
        for detection in sorted_detections:
            detection_box = detection["box"]
            x1, y1, x2, y2 = detection_box
            detection_area = (x2 - x1) * (y2 - y1)

            is_duplicate = False
            is_contained = False

            # Iterate a shallow copy so kept entries can be removed mid-scan.
            for kept in deduplicated[:]:
                kept_box = kept["box"]
                kx1, ky1, kx2, ky2 = kept_box
                kept_area = (kx2 - kx1) * (ky2 - ky1)

                iou = self._calculate_iou_tuples(detection_box, kept_box)

                # Standard IoU-based deduplication (lowered threshold)
                if iou > 0.5:  # Reduced from 0.7 to 0.5
                    is_duplicate = True
                    break

                # (e.g., individual box vs. entire shelf detection)
                if kept_area > detection_area * 3:  # Kept is 3x larger
                    # Check if detection is substantially contained within kept
                    overlap_area = max(0, min(x2, kx2) - max(x1, kx1)) * max(0, min(y2, ky2) - max(y1, ky1))
                    contained_ratio = overlap_area / detection_area
                    if contained_ratio > 0.8:  # 80% of detection is inside kept
                        is_contained = True
                        break

                # Check if kept detection is contained within current (much larger) detection
                elif detection_area > kept_area * 3:  # Current is 3x larger
                    overlap_area = max(0, min(x2, kx2) - max(x1, kx1)) * max(0, min(y2, ky2) - max(y1, ky1))
                    contained_ratio = overlap_area / kept_area
                    if contained_ratio > 0.8:  # 80% of kept is inside current
                        # Remove the contained detection and replace with current
                        # NOTE(review): no break here — the scan continues and the
                        # current detection is only appended below if no other kept
                        # box flags it; confirm that behavior is intended.
                        deduplicated.remove(kept)

            if not is_duplicate and not is_contained:
                deduplicated.append(detection)

        print(
            f" 🔄 Deduplication: {len(sorted_detections)} → {len(deduplicated)} detections"
        )
        return deduplicated
849
+
850
+ # Additional helper method for phase configuration
851
+ def set_detection_phases(self, phases: List[Dict[str, Any]]):
852
+ """
853
+ Set custom detection phases for the RetailDetector
854
+
855
+ Args:
856
+ phases: List of phase configurations, each containing:
857
+ - name: Phase identifier
858
+ - conf: Confidence threshold
859
+ - iou: IoU threshold
860
+ - weight: Weight for this phase (should sum to 1.0 across all phases)
861
+ - description: Optional description
862
+
863
+ Example:
864
+ detector.set_detection_phases([
865
+ {
866
+ "name": "ultra_high_conf",
867
+ "conf": 0.5,
868
+ "iou": 0.6,
869
+ "weight": 0.5,
870
+ "description": "Ultra high confidence for definite objects"
871
+ },
872
+ {
873
+ "name": "medium_conf",
874
+ "conf": 0.15,
875
+ "iou": 0.4,
876
+ "weight": 0.3,
877
+ "description": "Medium confidence for likely objects"
878
+ },
879
+ {
880
+ "name": "aggressive",
881
+ "conf": 0.005,
882
+ "iou": 0.15,
883
+ "weight": 0.2,
884
+ "description": "Aggressive pass for missed objects"
885
+ }
886
+ ])
887
+ """
888
+ # Validate phase configuration
889
+ total_weight = sum(phase.get("weight", 0) for phase in phases)
890
+ if abs(total_weight - 1.0) > 0.01:
891
+ print(f"⚠️ Warning: Phase weights sum to {total_weight:.3f}, not 1.0")
892
+
893
+ # Validate required fields
894
+ for i, phase in enumerate(phases):
895
+ required_fields = ["name", "conf", "iou", "weight"]
896
+ missing = [field for field in required_fields if field not in phase]
897
+ if missing:
898
+ raise ValueError(f"Phase {i} missing required fields: {missing}")
899
+
900
+ self.detection_phases = phases
901
+ print(f"✅ Configured {len(phases)} detection phases")
902
+ for i, phase in enumerate(phases):
903
+ print(f" Phase {i+1}: {phase['name']} (conf={phase['conf']}, weight={phase['weight']})")
904
+
905
+ def _calculate_iou_tuples(self, box1: tuple, box2: tuple) -> float:
906
+ """Calculate IoU between two bounding boxes in tuple format"""
907
+ x1_1, y1_1, x2_1, y2_1 = box1
908
+ x1_2, y1_2, x2_2, y2_2 = box2
909
+
910
+ # Calculate intersection
911
+ ix1, iy1 = max(x1_1, x1_2), max(y1_1, y1_2)
912
+ ix2, iy2 = min(x2_1, x2_2), min(y2_1, y2_2)
913
+
914
+ if ix2 <= ix1 or iy2 <= iy1:
915
+ return 0.0
916
+
917
+ intersection = (ix2 - ix1) * (iy2 - iy1)
918
+ area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
919
+ area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
920
+ union = area1 + area2 - intersection
921
+
922
+ return intersection / max(union, 1)
923
+
924
+ # ------------------- OCR + CLIP preselection -----------------------------
925
    def _analyze_crop_visuals(self, crop_bgr: np.ndarray) -> dict:
        """Analyzes a crop for dominant color properties to distinguish printers from boxes.

        Args:
            crop_bgr: image crop in BGR channel order (cv2 convention).

        Returns:
            Dict with boolean flags ``is_mostly_white`` (>40% white/gray
            pixels) and ``is_mostly_blue`` (>35% blue pixels), plus the raw
            ``white_pct`` / ``blue_pct`` percentages. Empty crops return only
            the two flags, both False.
        """
        if crop_bgr.size == 0:
            return {"is_mostly_white": False, "is_mostly_blue": False}

        # Convert to HSV for better color analysis
        hsv = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2HSV)

        # --- White/Gray Detection ---
        # Define a broad range for white, light gray, and silver colors
        # (any hue, low saturation, high value).
        lower_white = np.array([0, 0, 150])
        upper_white = np.array([180, 50, 255])
        white_mask = cv2.inRange(hsv, lower_white, upper_white)

        # --- Blue Detection ---
        # Define a range for the Epson blue (hue ≈ 95–125 in OpenCV's 0–180 scale)
        lower_blue = np.array([95, 80, 40])
        upper_blue = np.array([125, 255, 255])
        blue_mask = cv2.inRange(hsv, lower_blue, upper_blue)

        # Calculate the percentage of the image that is white or blue
        total_pixels = crop_bgr.shape[0] * crop_bgr.shape[1]
        white_percentage = (cv2.countNonZero(white_mask) / total_pixels) * 100
        blue_percentage = (cv2.countNonZero(blue_mask) / total_pixels) * 100

        # Determine if the object is primarily one color
        # Thresholds can be tuned, but these are generally effective.
        is_mostly_white = white_percentage > 40
        is_mostly_blue = blue_percentage > 35

        return {
            "is_mostly_white": is_mostly_white,
            "is_mostly_blue": is_mostly_blue,
            "white_pct": white_percentage,
            "blue_pct": blue_percentage,
        }
961
+
962
    async def _classify_proposals(self, img, props, bands, header_limit_y, ad_box=None):
        """
        ENHANCED proposal classification with a robust, heuristic-first decision process.
        1. Identify price tags by size.
        2. Identify promotional graphics by position.
        3. For remaining objects, use strong visual heuristics (color) to classify.
        4. Use CLIP similarity only as a fallback for ambiguous cases.

        Args:
            img: full image array; crops are converted with COLOR_BGR2RGB, so
                BGR order is assumed — TODO confirm against the caller.
            props: proposal dicts from ``_yolo_props`` (need "box", and use
                optional "yolo_conf" / "ocr_text").
            bands: shelf level name -> (y1, y2) map; may be empty.
            header_limit_y: boxes whose vertical center lies above this y are
                classified as promotional candidates.
            ad_box: accepted for signature compatibility; not used in the body.

        Returns:
            List[DetectionBox] with final class names/ids; proposals whose
            classification raises are logged and dropped.
        """
        H, W = img.shape[:2]
        final_proposals = []
        PRICE_TAG_AREA_THRESHOLD = 0.005  # 0.5% of total image area

        print(f"\n🎯 Enhanced Classification: Running {len(props)} proposals...")
        print(" " + "="*60)

        for p in props:
            x1, y1, x2, y2 = p["box"]
            area = (x2 - x1) * (y2 - y1)
            area_ratio = area / (H * W)
            center_y = (y1 + y2) / 2

            # Helper to determine shelf level for context
            shelf_level = self._determine_shelf_level(center_y, bands)

            # --- 1. Price Tag Check (by size) ---
            if area_ratio < PRICE_TAG_AREA_THRESHOLD:
                final_proposals.append(
                    DetectionBox(
                        x1=x1, y1=y1, x2=x2, y2=y2,
                        confidence=p.get('yolo_conf', 0.8),
                        class_id=CID["price_tag"],
                        class_name="price_tag",
                        area=area,
                        ocr_text=p.get('ocr_text')
                    )
                )
                continue

            # --- 2. Promotional Graphic Check (by position) ---
            if center_y < header_limit_y:
                final_proposals.append(
                    DetectionBox(
                        x1=x1, y1=y1, x2=x2, y2=y2,
                        confidence=p.get('yolo_conf', 0.9),
                        class_id=CID["promotional_candidate"],
                        class_name="promotional_candidate",
                        area=area,
                        ocr_text=p.get('ocr_text')
                    )
                )
                continue

            # --- 3. Heuristic & CLIP Classification for Products/Boxes ---
            try:
                crop_bgr = img[y1:y2, x1:x2]
                if crop_bgr.size == 0:
                    continue

                # Get visual heuristics and CLIP scores
                visuals = self._analyze_crop_visuals(crop_bgr)

                crop_pil = Image.fromarray(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB))
                with torch.no_grad():
                    ip = self.proc(images=crop_pil, return_tensors="pt").to(self.device)
                    img_feat = self.clip.get_image_features(**ip)
                    img_feat /= img_feat.norm(dim=-1, keepdim=True)
                    text_sims = (img_feat @ self.text_feats.T).squeeze().tolist()
                    # NOTE(review): s_poster is currently unused; kept to
                    # document the text-prompt order (poster, printer, box).
                    s_poster, s_printer, s_box = text_sims[0], text_sims[1], text_sims[2]

                # --- New Decision Logic ---
                class_name = None
                confidence = 0.8  # Default confidence for heuristic-based decision

                # Priority 1: Strong color evidence overrides everything.
                if visuals["is_mostly_white"] and not visuals["is_mostly_blue"]:
                    class_name = "product_candidate"  # It's a white printer device
                    confidence = 0.95  # High confidence in color heuristic
                elif visuals["is_mostly_blue"]:
                    class_name = "box_candidate"  # It's a blue product box
                    confidence = 0.95

                # Priority 2: If color is ambiguous, use shelf location as a strong hint.
                if not class_name:
                    if shelf_level == "middle":
                        class_name = "product_candidate"
                        confidence = 0.85
                    elif shelf_level == "bottom":
                        class_name = "box_candidate"
                        confidence = 0.85

                # Priority 3 (Fallback): If still undecided, use the original CLIP score.
                if not class_name:
                    if s_printer > s_box:
                        class_name = "product_candidate"
                        confidence = s_printer
                    else:
                        class_name = "box_candidate"
                        confidence = s_box

                final_class_id = CID[class_name]
                final_proposals.append(
                    DetectionBox(
                        x1=x1, y1=y1, x2=x2, y2=y2,
                        confidence=confidence,
                        class_id=final_class_id,
                        class_name=class_name,
                        area=area,
                        ocr_text=p.get('ocr_text')
                    )
                )

            except Exception as e:
                # Best-effort: a failing proposal is dropped, not fatal.
                self.logger.error(f"Failed to classify proposal with heuristics/CLIP: {e}")

        return final_proposals
1077
+
1078
+ # --------------------- merge/cleanup ------------------------------
1079
+ def _merge(self, dets: List[DetectionBox], iou_same=0.3) -> List[DetectionBox]:
1080
+ """Enhanced merge with size-aware logic"""
1081
+ dets = sorted(dets, key=lambda d: (d.class_name, -d.confidence, -d.area))
1082
+ out = []
1083
+
1084
+ for d in dets:
1085
+ placed = False
1086
+ for m in out:
1087
+ if d.class_name == m.class_name:
1088
+ iou = self._iou(d, m)
1089
+
1090
+ # Different merge strategies based on class
1091
+ if d.class_name == "box_candidate":
1092
+ # More aggressive merging for boxes (they're often tightly packed)
1093
+ merge_threshold = 0.25
1094
+ elif d.class_name == "product_candidate":
1095
+ # Conservative merging for printers (they're usually separate)
1096
+ merge_threshold = 0.4
1097
+ else:
1098
+ merge_threshold = iou_same
1099
+
1100
+ if iou > merge_threshold:
1101
+ # Merge by taking the union
1102
+ m.x1 = min(m.x1, d.x1)
1103
+ m.y1 = min(m.y1, d.y1)
1104
+ m.x2 = max(m.x2, d.x2)
1105
+ m.y2 = max(m.y2, d.y2)
1106
+ m.area = (m.x2 - m.x1) * (m.y2 - m.y1)
1107
+ m.confidence = max(m.confidence, d.confidence)
1108
+ placed = True
1109
+ print(f" 🔄 Merged {d.class_name} with IoU={iou:.3f}")
1110
+ break
1111
+
1112
+ if not placed:
1113
+ out.append(d)
1114
+
1115
+ return out
1116
+
1117
+ # ------------------------------ debug ------------------------------------
1118
+ def _rectangle_dashed(self, img, pt1, pt2, color, thickness=2, gap=9):
1119
+ x1, y1 = pt1
1120
+ x2, y2 = pt2
1121
+ # top
1122
+ for x in range(x1, x2, gap * 2):
1123
+ cv2.line(img, (x, y1), (min(x + gap, x2), y1), color, thickness)
1124
+ # bottom
1125
+ for x in range(x1, x2, gap * 2):
1126
+ cv2.line(img, (x, y2), (min(x + gap, x2), y2), color, thickness)
1127
+ # left
1128
+ for y in range(y1, y2, gap * 2):
1129
+ cv2.line(img, (x1, y), (x1, min(y + gap, y2)), color, thickness)
1130
+ # right
1131
+ for y in range(y1, y2, gap * 2):
1132
+ cv2.line(img, (x2, y), (x2, min(y + gap, y2)), color, thickness)
1133
+
1134
+ def _draw_corners(self, img, pt1, pt2, color, length=12, thickness=2):
1135
+ x1, y1 = pt1
1136
+ x2, y2 = pt2
1137
+ # TL
1138
+ cv2.line(img, (x1, y1), (x1 + length, y1), color, thickness)
1139
+ cv2.line(img, (x1, y1), (x1, y1 + length), color, thickness)
1140
+ # TR
1141
+ cv2.line(img, (x2, y1), (x2 - length, y1), color, thickness)
1142
+ cv2.line(img, (x2, y1), (x2, y1 + length), color, thickness)
1143
+ # BL
1144
+ cv2.line(img, (x1, y2), (x1 + length, y2), color, thickness)
1145
+ cv2.line(img, (x1, y2), (x1, y2 - length), color, thickness)
1146
+ # BR
1147
+ cv2.line(img, (x2, y2), (x2 - length, y2), color, thickness)
1148
+ cv2.line(img, (x2, y2), (x2, y2 - length), color, thickness)
1149
+
1150
    def _draw_phase_areas(self, img, props, roi_box, show_labels=True):
        """
        Draw per-phase borders (no fill). Thickness encodes confidence.
        poster_high = magenta (solid), high_confidence = green (solid), aggressive = orange (dashed).

        Args:
            img: BGR image (annotated in place).
            props: Proposal dicts with "box", optional "phase" and "confidence" keys.
            roi_box: (x1, y1, x2, y2) region-of-interest rectangle.
            show_labels: When True, print a short phase/confidence label per box.

        Returns:
            The annotated image (same object as ``img``).
        """
        phase_colors = {
            "poster_high": (200, 0, 200),  # BGR
            "high_confidence": (0, 220, 0),
            "aggressive": (0, 165, 255),
        }
        # Only the "aggressive" phase is rendered with dashed borders.
        dashed = {"poster_high": False, "high_confidence": False, "aggressive": True}

        # --- legend counts (proposals without a phase count as "aggressive")
        counts = Counter(p.get("phase", "aggressive") for p in props)

        # --- draw ROI
        rx1, ry1, rx2, ry2 = roi_box
        cv2.rectangle(img, (rx1, ry1), (rx2, ry2), (0, 255, 0), 2)

        # --- per-proposal borders
        for p in props:
            x1, y1, x2, y2 = p["box"]
            phase = p.get("phase", "aggressive")
            conf = float(p.get("confidence", 0.0))
            color = phase_colors.get(phase, (180, 180, 180))

            # thickness: 1..5 with a gentle curve so small conf doesn't vanish
            t = max(1, min(5, int(round(1 + 4 * math.sqrt(max(0.0, min(conf, 1.0)))))))

            if dashed.get(phase, False):
                self._rectangle_dashed(img, (x1, y1), (x2, y2), color, thickness=t, gap=9)
            else:
                cv2.rectangle(img, (x1, y1), (x2, y2), color, t)

            # add subtle phase corners to help when borders overlap
            self._draw_corners(img, (x1, y1), (x2, y2), color, length=10, thickness=max(1, t - 1))

            if show_labels:
                # Label is the phase's first letter (e.g. "P 0.87"), clamped inside the image.
                lbl = f"{phase.split('_')[0][:1].upper()} {conf:.2f}"
                ty = max(12, y1 - 6)
                cv2.putText(img, lbl, (x1 + 2, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)

        # --- legend (top-left of ROI)
        legend_items = [("poster_high", "Poster"), ("high_confidence", "High"), ("aggressive", "Agg")]
        lx, ly = rx1 + 6, max(18, ry1 + 16)
        for key, name in legend_items:
            col = phase_colors[key]
            cv2.rectangle(img, (lx, ly - 10), (lx + 18, ly - 2), col, -1)
            text = f"{name}: {counts.get(key, 0)}"
            cv2.putText(img, text, (lx + 24, ly - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 1, cv2.LINE_AA)
            ly += 16

        return img
1203
+
1204
    def _draw_yolo(self, img, props, roi_box, shelf_lines):
        """
        Draw raw YOLO detections with detailed labels.

        Args:
            img: BGR image (annotated in place).
            props: Proposal dicts carrying "box", "retail_candidates",
                "raw_index", "yolo_label", "yolo_conf", and "area_ratio".
            roi_box: (x1, y1, x2, y2) region-of-interest rectangle.
            shelf_lines: y-coordinates of detected shelf separator lines.

        Returns:
            The annotated image (same object as ``img``).
        """
        rx1, ry1, rx2, ry2 = roi_box

        # Draw ROI box
        cv2.rectangle(img, (rx1, ry1), (rx2, ry2), (0, 255, 0), 3)
        cv2.putText(img, f"ROI: {rx2-rx1}x{ry2-ry1}", (rx1, ry1-10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        # Draw shelf lines
        for i, y in enumerate(shelf_lines):
            cv2.line(img, (rx1, y), (rx2, y), (0, 255, 255), 2)
            cv2.putText(img, f"Shelf{i+1}", (rx1+5, y-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 255), 1)

        # Color mapping for retail candidates (BGR tuples)
        candidate_colors = {
            "promotional_graphic": (255, 0, 255),   # Magenta
            "printer": (255, 140, 0),               # Orange
            "tv": (0, 200, 0),                      # Green
            "product_candidate": (200, 200, 0),     # Yellow
            "product_box": (0, 140, 255),           # Blue
            "small_object": (128, 128, 128),        # Gray
            "ink_bottle": (160, 0, 200),            # Purple
        }

        for p in props:
            (x1, y1, x2, y2) = p["box"]

            # Choose color based on primary retail candidate
            candidates = p.get("retail_candidates", ["unknown"])
            primary_candidate = candidates[0] if candidates else "unknown"
            color = candidate_colors.get(primary_candidate, (255, 255, 255))

            # Draw detection
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            # Enhanced label: raw index, YOLO class -> retail candidate,
            # plus confidence and percent-of-image area on a second line.
            idx = p["raw_index"]
            yolo_class = p["yolo_label"]
            conf = p["yolo_conf"]
            area_pct = p["area_ratio"] * 100

            label1 = f"#{idx} {yolo_class}→{primary_candidate}"
            label2 = f"conf:{conf:.3f} area:{area_pct:.1f}%"

            # max() keeps the labels inside the image when the box touches the top edge.
            cv2.putText(img, label1, (x1, max(15, y1 - 5)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)
            cv2.putText(img, label2, (x1, max(30, y1 + 15)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 1, cv2.LINE_AA)

        return img
1258
+
1259
+ def _draw_phase1(self, img, roi_box, shelf_lines, dets, ad_box=None):
1260
+ """
1261
+ FIXED: Phase-1 debug drawing with better info
1262
+ """
1263
+ rx1, ry1, rx2, ry2 = roi_box
1264
+ cv2.rectangle(img, (rx1, ry1), (rx2, ry2), (0, 255, 0), 2)
1265
+
1266
+ for y in shelf_lines:
1267
+ cv2.line(img, (rx1, y), (rx2, y), (0, 255, 255), 2)
1268
+
1269
+ colors = {
1270
+ "promotional_candidate": (0, 200, 0),
1271
+ "product_candidate": (255, 140, 0),
1272
+ "box_candidate": (0, 140, 255),
1273
+ "price_tag": (255, 0, 255),
1274
+ }
1275
+
1276
+ for i, d in enumerate(dets, 1):
1277
+ c = colors.get(d.class_name, (200, 200, 200))
1278
+ cv2.rectangle(img, (d.x1, d.y1), (d.x2, d.y2), c, 2)
1279
+
1280
+ # Enhanced label with detection info
1281
+ w, h = d.x2 - d.x1, d.y2 - d.y1
1282
+ area_pct = (d.area / (img.shape[0] * img.shape[1])) * 100
1283
+ aspect = w / max(h, 1)
1284
+ center_y = (d.y1 + d.y2) / 2
1285
+
1286
+ print(f" #{i:2d}: {d.class_name:20s} conf={d.confidence:.3f} "
1287
+ f"area={area_pct:.2f}% AR={aspect:.2f} center_y={center_y:.0f}")
1288
+
1289
+ label = f"#{i} {d.class_name} {d.confidence:.2f}"
1290
+ cv2.putText(img, label, (d.x1, max(15, d.y1 - 4)),
1291
+ cv2.FONT_HERSHEY_SIMPLEX, 0.45, c, 1, cv2.LINE_AA)
1292
+
1293
+ if ad_box is not None:
1294
+ cv2.rectangle(img, (ad_box[0], ad_box[1]), (ad_box[2], ad_box[3]), (0, 255, 128), 2)
1295
+ cv2.putText(
1296
+ img, "poster_roi",
1297
+ (ad_box[0], max(12, ad_box[1] - 4)),
1298
+ cv2.FONT_HERSHEY_SIMPLEX,
1299
+ 0.4, (0, 255, 128), 1, cv2.LINE_AA,
1300
+ )
1301
+
1302
+ return img
1303
+
1304
+
1305
+ class PlanogramCompliancePipeline(AbstractPipeline):
1306
+ """
1307
+ Pipeline for planogram compliance checking.
1308
+
1309
+ 3-Step planogram compliance pipeline:
1310
+ Step 1: Object Detection (YOLO/ResNet)
1311
+ Step 2: LLM Object Identification with Reference Images
1312
+ Step 3: Planogram Comparison and Compliance Verification
1313
+ """
1314
    def __init__(
        self,
        planogram_config: PlanogramConfig,
        llm: Any = None,
        llm_provider: str = "google",
        llm_model: Optional[str] = None,
        **kwargs: Any
    ):
        """
        Initialize the 3-step planogram compliance pipeline.

        Args:
            planogram_config: Planogram definition supplying endcap geometry,
                reference images, the detection model name, and the
                confidence threshold.
            llm: Pre-built LLM client; when None the parent class constructs
                one from ``llm_provider``/``llm_model``.
            llm_provider: LLM provider for identification (default "google").
            llm_model: Specific LLM model name, or None for the provider default.
            **kwargs: Forwarded to ``AbstractPipeline.__init__``.
        """
        # Endcap geometry defaults (can be tuned per program)
        geometry = planogram_config.endcap_geometry
        self.endcap_aspect_ratio = geometry.aspect_ratio
        self.left_margin_ratio = geometry.left_margin_ratio
        self.right_margin_ratio = geometry.right_margin_ratio
        self.top_margin_ratio = geometry.top_margin_ratio
        self.bottom_margin_ratio = geometry.bottom_margin_ratio
        self.inter_shelf_padding = geometry.inter_shelf_padding

        # saving the planogram config for later use
        self.planogram_config = planogram_config
        # Parent init must run before self.llm is available to the detector below.
        super().__init__(
            llm=llm,
            llm_provider=llm_provider,
            llm_model=llm_model,
            **kwargs
        )
        reference_images = planogram_config.reference_images
        references = list(reference_images.values()) if reference_images else None
        # Initialize the generic shape detector
        self.shape_detector = RetailDetector(
            yolo_model=planogram_config.detection_model,
            conf=planogram_config.confidence_threshold,
            llm=self.llm,
            device="cuda" if torch.cuda.is_available() else "cpu",
            reference_images=references
        )
        self.logger.debug(
            f"Initialized RetailDetector with {planogram_config.detection_model}"
        )
        self.reference_images = reference_images or {}
        self.confidence_threshold = planogram_config.confidence_threshold
1363
+
1364
+ async def detect_objects_and_shelves(
1365
+ self,
1366
+ image: Image,
1367
+ image_array: np.ndarray,
1368
+ endcap: Detection,
1369
+ ad: Optional[Detection] = None,
1370
+ brand: Optional[Detection] = None,
1371
+ panel_text: Optional[Detection] = None,
1372
+ planogram_description: Optional[PlanogramDescription] = None
1373
+ ):
1374
+ self.logger.debug(
1375
+ "Step 1: Detecting generic shapes and boundaries..."
1376
+ )
1377
+
1378
+ det_out = await self.shape_detector.detect(
1379
+ image=image,
1380
+ image_array=image_array,
1381
+ endcap=endcap,
1382
+ ad=ad,
1383
+ planogram=planogram_description,
1384
+ debug_yolo="/tmp/data/yolo_raw.png",
1385
+ debug_phase1="/tmp/data/yolo_phase1_debug.png",
1386
+ debug_phases="/tmp/data/yolo_phases_debug.png",
1387
+ )
1388
+
1389
+ shelves = det_out["shelves"] # {'top': DetectionBox(...), 'middle': ...}
1390
+ proposals = det_out["proposals"] # List[DetectionBox]
1391
+
1392
+ print("PROPOSALS:", proposals)
1393
+ print("SHELVES:", shelves)
1394
+
1395
+ h, w = image_array.shape[:2]
1396
+ if brand:
1397
+ bx1, by1, bx2, by2 = brand.bbox.get_pixel_coordinates(width=w, height=h)
1398
+ proposals.append(
1399
+ DetectionBox(
1400
+ x1=bx1, y1=by1, x2=bx2, y2=by2,
1401
+ confidence=brand.confidence,
1402
+ class_id=CID["brand_logo"],
1403
+ class_name="brand_logo",
1404
+ area=(bx2 - bx1) * (by2 - by1),
1405
+ ocr_text=brand.content
1406
+ )
1407
+ )
1408
+ print(f" + Injected brand_logo: '{brand.content}'")
1409
+
1410
+ if panel_text:
1411
+ tx1, ty1, tx2, ty2 = panel_text.bbox.get_pixel_coordinates(width=w, height=h)
1412
+ proposals.append(
1413
+ DetectionBox(
1414
+ x1=tx1, y1=ty1, x2=tx2, y2=ty2,
1415
+ confidence=panel_text.confidence,
1416
+ class_id=CID["poster_text"],
1417
+ class_name="poster_text",
1418
+ area=(tx2 - tx1) * (ty2 - ty1),
1419
+ ocr_text=panel_text.content.replace('.', ' ')
1420
+ )
1421
+ )
1422
+ print(f" + Injected poster_text: '{panel_text.content}'")
1423
+
1424
+ # --- IMPORTANT: use Phase-1 shelf bands (not %-of-image buckets) ---
1425
+ shelf_regions = self._materialize_shelf_regions(shelves, proposals, planogram_description)
1426
+
1427
+ detections = list(proposals)
1428
+
1429
+ self.logger.debug(
1430
+ "Found %d objects in %d shelf regions", len(detections), len(shelf_regions)
1431
+ )
1432
+
1433
+ self.logger.debug("Found %d objects in %d shelf regions",
1434
+ len(detections), len(shelf_regions))
1435
+ return shelf_regions, detections
1436
+
1437
+ def _materialize_shelf_regions(
1438
+ self,
1439
+ shelves_dict: Dict[str, DetectionBox],
1440
+ dets: List[DetectionBox],
1441
+ planogram_description: Optional[PlanogramDescription] = None
1442
+ ) -> List[ShelfRegion]:
1443
+ """Turn Phase-1 shelf bands into ShelfRegion objects and assign detections by y-overlap."""
1444
+ def y_overlap(a1, a2, b1, b2) -> int:
1445
+ return max(0, min(a2, b2) - max(a1, b1))
1446
+
1447
+ regions: List[ShelfRegion] = []
1448
+
1449
+ # Iterate through the shelves defined in the planogram config, in their specified order.
1450
+ for shelf_config in planogram_description.shelves:
1451
+ level = shelf_config.level
1452
+ band = shelves_dict.get(level)
1453
+ if not band:
1454
+ self.logger.warning(
1455
+ f"Shelf '{level}' is defined in the planogram but was not detected in the image."
1456
+ )
1457
+ continue
1458
+
1459
+ # Find all object proposals that vertically overlap with this shelf's detected band.
1460
+ # An object belongs to the shelf if any part of it is within the shelf's y-range.
1461
+ objs = [d for d in dets if y_overlap(d.y1, d.y2, band.y1, band.y2) > 0]
1462
+
1463
+ # If no objects were found on this shelf, we don't need to create a region for it.
1464
+ if objs:
1465
+ x1 = min(o.x1 for o in objs)
1466
+ x2 = max(o.x2 for o in objs)
1467
+ else:
1468
+ # Use band boundaries if no objects
1469
+ x1, x2 = band.x1, band.x2
1470
+
1471
+ # Create a new bounding box for the ShelfRegion.
1472
+ # The Y coordinates are fixed by the detected shelf band.
1473
+ # The X coordinates are calculated as the min/max extent of the objects on that shelf.
1474
+ y1 = band.y1
1475
+ y2 = band.y2
1476
+
1477
+ bbox = DetectionBox(
1478
+ x1=x1, y1=y1, x2=x2, y2=y2,
1479
+ confidence=1.0,
1480
+ class_id=CID["shelf_region"],
1481
+ class_name="shelf_region",
1482
+ area=(x2 - x1) * (y2 - y1)
1483
+ )
1484
+
1485
+ # Create the final ShelfRegion object.
1486
+ regions.append(
1487
+ ShelfRegion(
1488
+ shelf_id=f"{level}_shelf",
1489
+ bbox=bbox,
1490
+ level=level,
1491
+ objects=objs
1492
+ )
1493
+ )
1494
+
1495
+ return regions
1496
+
1497
    async def identify_objects_with_references(
        self,
        image: Union[str, Path, Image.Image],
        detections: List[DetectionBox],
        shelf_regions: List[ShelfRegion],
        # NOTE: annotation corrected — this value is unpacked with ** below,
        # so it must be a mapping of reference name -> image, not a list.
        reference_images: Dict[str, Any],
        prompt: str
    ) -> List[IdentifiedProduct]:
        """
        Step 2: Use LLM to identify detected objects using reference images.

        Args:
            image: Original endcap image (path or PIL image).
            detections: Object detections from Step 1.
            shelf_regions: Shelf regions from Step 1.
            reference_images: Mapping of reference name to product image;
                merged with the annotated overview image before the LLM call.
            prompt: Partial prompt text inserted into the identification prompt.

        Returns:
            List of identified products, augmented with OCR evidence.

        Raises:
            Exception: re-raises any failure from the LLM identification call
                (after logging and printing the traceback).
        """

        self.logger.debug(
            f"Starting identification with {len(detections)} detections"
        )
        # If no detections, return empty list
        if not detections:
            self.logger.warning("No detections to identify")
            return []

        pil_image = self._get_image(image)

        # Create annotated image showing detection boxes.
        # Helper classes (slots, shelf regions, tags) are excluded so the LLM
        # concentrates on actual products.
        effective_dets = [
            d for d in detections if d.class_name not in {"slot", "shelf_region", "price_tag", "fact_tag"}
        ]
        annotated_image = self._create_annotated_image(pil_image, effective_dets)

        async with self.llm as client:
            try:
                extra_refs = {
                    "annotated_image": annotated_image,
                    **reference_images
                }
                # temperature=0.0 for deterministic, structured JSON output.
                identified_products = await client.image_identification(
                    prompt=self._build_gemini_identification_prompt(
                        effective_dets,
                        shelf_regions,
                        partial_prompt=prompt
                    ),
                    image=image,
                    detections=effective_dets,
                    shelf_regions=shelf_regions,
                    reference_images=extra_refs,
                    temperature=0.0
                )
                # Enrich results with OCR text found on boxes/printers.
                identified_products = await self._augment_products_with_box_ocr(
                    image,
                    identified_products
                )
                # Posters also get an OCR snippet recorded as a visual feature.
                for product in identified_products:
                    if product.product_type == "promotional_graphic":
                        if lines := await self._extract_text_from_region(image, product.detection_box):
                            snippet = " ".join(lines)[:120]
                            product.visual_features = (product.visual_features or []) + [f"ocr:{snippet}"]
                return identified_products

            except Exception as e:
                self.logger.error(f"Error in structured identification: {e}")
                traceback.print_exc()
                raise
1569
+
1570
+ def _guess_et_model_from_text(self, text: str) -> Optional[str]:
1571
+ """
1572
+ Find Epson EcoTank model tokens in text.
1573
+ Returns normalized like 'et-4950' (device) or 'et-2980', etc.
1574
+ """
1575
+ if not text:
1576
+ return None
1577
+ t = text.lower().replace(" ", "")
1578
+ # common variants: et-4950, et4950, et – 4950, etc.
1579
+ m = re.search(r"et[-]?\s?(\d{4})", t)
1580
+ if not m:
1581
+ return None
1582
+ num = m.group(1)
1583
+ # Accept only models we care about (tighten if needed)
1584
+ if num in {"2980", "3950", "4950"}:
1585
+ return f"et-{num}"
1586
+ return None
1587
+
1588
+
1589
+ def _maybe_brand_from_text(self, text: str) -> Optional[str]:
1590
+ if not text:
1591
+ return None
1592
+ t = text.lower()
1593
+ if "epson" in t or "ecotank" in t:
1594
+ return "Epson"
1595
+ if 'hisense' in t or "canvastv" in t:
1596
+ return "Hisense"
1597
+ if "firetv" in t or "fire tv" in t:
1598
+ return "Amazon"
1599
+ if "google tv" in t or "chromecast" in t:
1600
+ return "Google"
1601
+ return None
1602
+
1603
+ def _normalize_ocr_text(self, s: str) -> str:
1604
+ """
1605
+ Make OCR text match-friendly:
1606
+ - Unicode normalize (NFKC), strip diacritics
1607
+ - Replace fancy dashes/quotes with spaces
1608
+ - Remove non-alnum except spaces, collapse whitespace
1609
+ - Lowercase
1610
+ """
1611
+ if not s:
1612
+ return ""
1613
+ s = unicodedata.normalize("NFKC", s)
1614
+ # strip accents
1615
+ s = "".join(ch for ch in unicodedata.normalize("NFKD", s) if not unicodedata.combining(ch))
1616
+ # unify punctuation to spaces
1617
+ s = re.sub(r"[—–‐-‒–—―…“”\"'·•••·•—–/\\|_=+^°™®©§]", " ", s)
1618
+ # keep letters/digits/spaces
1619
+ s = re.sub(r"[^A-Za-z0-9 ]+", " ", s)
1620
+ # collapse
1621
+ s = re.sub(r"\s+", " ", s).strip().lower()
1622
+ return s
1623
+
1624
    async def _augment_products_with_box_ocr(
        self,
        image: Union[str, Path, Image.Image],
        products: List[IdentifiedProduct]
    ) -> List[IdentifiedProduct]:
        """Add OCR-derived evidence to boxes/printers and fix product_model when we see ET-xxxx.

        Mutates the given products in place (visual_features, brand,
        product_model) and returns the same list.

        Args:
            image: Original endcap image (path or PIL image) used for OCR crops.
            products: Products identified in Step 2.

        Returns:
            The same ``products`` list, augmented.
        """
        for p in products:
            if not p.detection_box:
                continue
            # Normalize product brand logo with OCR text from the detection when brand is missing.
            if getattr(p.detection_box, 'class_name', None) == 'brand_logo' and not getattr(p, 'brand', None):
                if p.detection_box.ocr_text:
                    brand = self._maybe_brand_from_text(p.detection_box.ocr_text)
                    if brand:
                        try:
                            p.brand = brand  # only if IdentifiedProduct has 'brand'
                        except Exception:
                            # Model has no 'brand' attribute: keep it as a feature.
                            if not p.visual_features:
                                p.visual_features = []
                            p.visual_features.append(f"brand:{brand}")
            if p.product_type in {"product_box", "printer"}:
                # Model-tuned OCR pass (whitelisted charset for ET-xxxx tokens).
                lines = await self._extract_text_from_region(image, p.detection_box, mode="model")
                if lines:
                    # Keep some OCR as visual evidence (don't explode the list)
                    snippet = " ".join(lines)[:120]
                    if not p.visual_features:
                        p.visual_features = []
                    p.visual_features.append(f"ocr:{snippet}")

                    # Brand hint
                    brand = self._maybe_brand_from_text(snippet)
                    if brand and not getattr(p, "brand", None):
                        try:
                            p.brand = brand  # only if IdentifiedProduct has 'brand'
                        except Exception:
                            # If the model doesn't have brand, keep it as a feature.
                            p.visual_features.append(f"brand:{brand}")

                    # Model from OCR
                    model = self._guess_et_model_from_text(snippet)
                    if model:
                        # Normalize to your scheme:
                        target = model.upper()
                        # If missing or mismatched, replace
                        if not p.product_model:
                            p.product_model = target
                        else:
                            # If current looks generic/incorrect, fix it
                            cur = (p.product_model or "").lower()
                            # NOTE(review): 'and' binds tighter than 'or', so this is
                            # ("et-" not in cur) or ("box" in target and "box" not in cur)
                            # — confirm that grouping is intended.
                            if "et-" in target.lower() and ("et-" not in cur or "box" in target.lower() and "box" not in cur):
                                p.product_model = target
            elif p.product_type == "promotional_graphic":
                # Generic OCR pass for poster text.
                if lines := await self._extract_text_from_region(image, p.detection_box):
                    snippet = " ".join(lines)[:160]
                    p.visual_features = (p.visual_features or []) + [f"ocr:{snippet}"]
                    # keep a normalized text blob
                    joined = " ".join(lines)
                    if norm := self._normalize_ocr_text(joined):
                        p.visual_features.append(norm)
                    # Also keep per-line normalized variants, skipping duplicates.
                    for ln in lines:
                        if ln and (nln := self._normalize_ocr_text(ln)) and nln not in p.visual_features:
                            p.visual_features.append(nln)

                    # Infer brand from OCR/features if missing
                    if not getattr(p, "brand", None):
                        brand = self._maybe_brand_from_text(joined)
                        if not brand and p.visual_features:
                            vf_blob = " ".join(p.visual_features)
                            brand = self._maybe_brand_from_text(vf_blob)
                        if brand:
                            p.brand = brand
        return products
1696
+
1697
    async def _extract_text_from_region(
        self,
        image: Union[str, Path, Image.Image],
        detection_box: DetectionBox,
        mode: str = "generic",  # "generic" | "model"
    ) -> List[str]:
        """Extract text from a region with OCR.

        - generic: multi-pass (psm 6 & 4) + unsharp + binarize
        - model : tuned to catch ET-xxxx (character whitelist)

        Args:
            image: Source image (path or PIL image).
            detection_box: Region to crop (padded by 10 px on each side).
            mode: "model" for the whitelist pass, anything else for generic.

        Returns:
            OCR lines plus normalized variants so TextMatcher has more chances;
            empty list on any failure (errors are logged, never raised).
        """
        try:
            pil_image = Image.open(image) if isinstance(image, (str, Path)) else image
            # Pad the crop slightly; clamp to image bounds.
            pad = 10
            x1 = max(0, detection_box.x1 - pad)
            y1 = max(0, detection_box.y1 - pad)
            x2 = min(pil_image.width - 1, detection_box.x2 + pad)
            y2 = min(pil_image.height - 1, detection_box.y2 + pad)

            # ENSURE VALID CROP COORDINATES (degenerate boxes get a 10-px extent)
            if x1 >= x2:
                x2 = x1 + 10
            if y1 >= y2:
                y2 = y1 + 10

            crop_rgb = pil_image.crop((x1, y1, x2, y2)).convert("RGB")

            def _prep(arr):
                # Grayscale -> 1.5x upscale -> unsharp mask -> Otsu binarization.
                g = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
                g = cv2.resize(g, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_CUBIC)
                blur = cv2.GaussianBlur(g, (0, 0), sigmaX=1.0)
                sharp = cv2.addWeighted(g, 1.6, blur, -0.6, 0)
                _, th = cv2.threshold(sharp, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                return th

            if mode == "model":
                # Single pass with a character whitelist tuned for ET-xxxx tokens.
                th = _prep(np.array(crop_rgb))
                crop = Image.fromarray(th).convert("L")
                cfg = "--oem 3 --psm 6 -l eng -c tessedit_char_whitelist=ETet0123456789-ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                raw = pytesseract.image_to_string(crop, config=cfg)
                lines = [ln.strip() for ln in raw.splitlines() if ln.strip()]
            else:
                arr = np.array(crop_rgb)
                th = _prep(arr)
                # two passes help for 'Goodbye Cartridges' on light box
                raw1 = pytesseract.image_to_string(Image.fromarray(th), config="--psm 6 -l eng")
                raw2 = pytesseract.image_to_string(Image.fromarray(th), config="--psm 4 -l eng")
                raw = raw1 + "\n" + raw2
                lines = [ln.strip() for ln in raw.splitlines() if ln.strip()]

            # Add normalized variants to help TextMatcher:
            # - lowercase, punctuation stripped
            # - a single combined line
            def norm(s: str) -> str:
                s = s.lower()
                s = re.sub(r"[^a-z0-9\s]", " ", s)  # drop punctuation like colons
                s = re.sub(r"\s+", " ", s).strip()
                return s

            variants = [norm(ln) for ln in lines if ln]
            if variants:
                variants.append(norm(" ".join(lines)))

            # merge unique while preserving originals first
            out = lines[:]
            for v in variants:
                if v and v not in out:
                    out.append(v)

            return out

        except Exception as e:
            # Best-effort OCR: log and fall back to "no text found".
            self.logger.error(f"Text extraction failed: {e}")
            return []
1771
+
1772
+ def _get_image(
1773
+ self,
1774
+ image: Union[str, Path, Image.Image]
1775
+ ) -> Image.Image:
1776
+ """Load image from path or return copy if already PIL"""
1777
+
1778
+ if isinstance(image, (str, Path)):
1779
+ pil_image = Image.open(image).copy()
1780
+ else:
1781
+ pil_image = image.copy()
1782
+ return pil_image
1783
+
1784
+ def _create_annotated_image(
1785
+ self,
1786
+ image: Image.Image,
1787
+ detections: List[DetectionBox]
1788
+ ) -> Image.Image:
1789
+ """Create an annotated image with detection boxes and IDs"""
1790
+
1791
+ draw = ImageDraw.Draw(image)
1792
+
1793
+ for i, detection in enumerate(detections):
1794
+ # Draw bounding box
1795
+ draw.rectangle(
1796
+ [(detection.x1, detection.y1), (detection.x2, detection.y2)],
1797
+ outline="red", width=2
1798
+ )
1799
+
1800
+ # Add detection ID and confidence
1801
+ label = f"ID:{i+1} ({detection.confidence:.2f})"
1802
+ draw.text((detection.x1, detection.y1 - 20), label, fill="red")
1803
+
1804
+ return image
1805
+
1806
    def _build_gemini_identification_prompt(
        self,
        detections: List[DetectionBox],
        shelf_regions: List[ShelfRegion],
        partial_prompt: str
    ) -> str:
        """Builds a more detailed prompt to help Gemini differentiate similar products.

        Args:
            detections: Pre-detected objects; listed with 1-based IDs in the prompt.
            shelf_regions: Ground-truth shelf bands with y-ranges.
            partial_prompt: Caller-supplied middle section; formatted with
                ``num_detections`` and ``shelf_names`` placeholders.

        Returns:
            The full prompt string (intro + detections + shelves + partial + JSON rules).
        """
        detection_lines = ["\nDETECTED OBJECTS (with pre-assigned IDs):"]
        if detections:
            for i, detection in enumerate(detections, 1):
                # NOTE(review): these entries are later combined with ''.join, so the
                # IDs run together on one line — confirm '\n'.join wasn't intended.
                detection_lines.append(
                    f"ID {i}: Initial class '{detection.class_name}' at bbox ({detection.x1},{detection.y1},{detection.x2},{detection.y2})"
                )
        else:
            detection_lines.append("None")

        shelf_definitions = ["\n**VALID SHELF NAMES & LOCATIONS (Ground Truth):**"]
        valid_shelf_names = []
        num_detections = len(detections)
        for shelf in shelf_regions:
            # if shelf.level in ['header', 'middle', 'bottom']:
            valid_shelf_names.append(f"'{shelf.level}'")
            shelf_definitions.append(f"- Shelf '{shelf.level}': Covers the vertical pixel range from y={shelf.bbox.y1} to y={shelf.bbox.y2}.")
        shelf_definitions.append(f"\n**RULE:** For the `shelf_location` field, you MUST use one of these exact names: {', '.join(valid_shelf_names)}.")

        # REVISED: Enhanced prompt with new rules
        prompt = f"""
        You are an expert at identifying retail products in planogram displays.
        I have provided an image of a retail endcap, labeled reference images, and a list of {num_detections} pre-detected objects.

        {''.join(detection_lines)}
        {''.join(shelf_definitions)}

        **YOUR TASK:**
        For each distinct product, you must first analyze its visual features according to the guide, state your reasoning, and then provide the final identification.

        """
        # The caller's partial prompt may reference these two placeholders.
        partial_prompt = partial_prompt.strip().format(
            num_detections=num_detections,
            shelf_names=", ".join(valid_shelf_names)
        )
        prompt += partial_prompt
        prompt += f"""
        ---

        **JSON OUTPUT FORMAT:**
        Respond with a single JSON object. For each **distinct product** you identify, provide an entry in the 'detections' list.

        - **detection_id**: The pre-detected ID number, or `null` for newly found items.
        - **detection_box**: **REQUIRED** if `detection_id` is `null`. An array of four numbers `[x1, y1, x2, y2]`.
        - **product_type**: printer, tv, product_box, fact_tag, promotional_graphic, or ink_bottle.
        - **product_model**: Follow naming rules above.
        - **confidence**: Your confidence (0.0-1.0).
        - **visual_features**: List of key visual features as if device is turned on, color, size, brightness or any other visual features.
        - **reasoning**: A brief sentence explaining your identification based on the visual guide. Example: "Reasoning: The control panel has a physical key pad, which matches the ET-3950 guide."
        - **reference_match**: Which reference image name matches (or "none").
        - **shelf_location**: {', '.join(valid_shelf_names)}.
        - **position_on_shelf**: 'left', 'center', or 'right'.

        **!! FINAL CHECK !!**
        - Ensure your response contains **NO DUPLICATE** entries for the same physical object.
        - **CRITICAL**: Verify that any item with `detection_id: null` also includes a `detection_box`.

        Analyze all provided images and return the complete JSON response.
        """
        return prompt
1872
+
1873
+ def _calculate_visual_feature_match(self, expected_features: List[str], detected_features: List[str]) -> float:
1874
+ """
1875
+ Enhanced visual feature matching with semantic understanding
1876
+ """
1877
+ if not expected_features:
1878
+ return 1.0 # No requirements = full match
1879
+
1880
+ if not detected_features:
1881
+ return 0.0 # No detected features but requirements exist
1882
+
1883
+ # Normalize and create keyword sets for semantic matching
1884
+ def extract_keywords(text):
1885
+ """Extract meaningful keywords from feature text"""
1886
+ text = text.lower().strip()
1887
+ # Remove common words that don't add meaning
1888
+ stop_words = {'a', 'an', 'the', 'is', 'are', 'on', 'of', 'in', 'at', 'to', 'for', 'with', 'visible', 'displayed', 'showing'}
1889
+ words = [w for w in text.split() if w not in stop_words and len(w) > 1]
1890
+ return set(words)
1891
+
1892
+ # Special semantic mappings for common concepts
1893
+ semantic_mappings = {
1894
+ 'active': ['active', 'on', 'powered', 'illuminated', 'lit'],
1895
+ 'display': ['display', 'screen', 'tv', 'television', 'monitor'],
1896
+ 'illuminated': ['illuminated', 'backlit', 'lit', 'bright', 'glowing'],
1897
+ 'logo': ['logo', 'text', 'branding', 'brand'],
1898
+ 'dynamic': ['dynamic', 'colorful', 'graphics', 'content'],
1899
+ 'official': ['official', 'partner'],
1900
+ 'white': ['white', 'large']
1901
+ }
1902
+
1903
+ def semantic_match(expected_word, detected_keywords):
1904
+ """Check if expected word semantically matches any detected keywords"""
1905
+ if expected_word in detected_keywords:
1906
+ return True
1907
+
1908
+ # Check semantic mappings
1909
+ if expected_word in semantic_mappings:
1910
+ synonyms = semantic_mappings[expected_word]
1911
+ return any(syn in detected_keywords for syn in synonyms)
1912
+
1913
+ # Check if any detected keyword contains the expected word
1914
+ return any(expected_word in keyword for keyword in detected_keywords)
1915
+
1916
+ matches = 0
1917
+ for expected in expected_features:
1918
+ expected_keywords = extract_keywords(expected)
1919
+
1920
+ # Combine all detected feature keywords
1921
+ all_detected_keywords = set()
1922
+ for detected in detected_features:
1923
+ all_detected_keywords.update(extract_keywords(detected))
1924
+
1925
+ # Check if any expected keyword has a semantic match
1926
+ feature_matched = False
1927
+ for exp_keyword in expected_keywords:
1928
+ if semantic_match(exp_keyword, all_detected_keywords):
1929
+ feature_matched = True
1930
+ break
1931
+
1932
+ if feature_matched:
1933
+ matches += 1
1934
+
1935
+ score = matches / len(expected_features)
1936
+ return score
1937
+
1938
+ def check_planogram_compliance(
1939
+ self,
1940
+ identified_products: List[IdentifiedProduct],
1941
+ planogram_description: PlanogramDescription,
1942
+ ) -> List[ComplianceResult]:
1943
+ """Check compliance of identified products against the planogram."""
1944
+ def _matches(ek, fk) -> bool:
1945
+ (e_ptype, e_base), (f_ptype, f_base) = ek, fk
1946
+ if e_ptype != f_ptype:
1947
+ return False
1948
+ if not e_base or not f_base:
1949
+ return True
1950
+ # If no base model specified in planogram, accept type-only match
1951
+ if not e_base:
1952
+ return True
1953
+ if f_base == e_base or e_base in f_base or f_base in e_base:
1954
+ return True
1955
+ if f_base == e_base:
1956
+ return True
1957
+ # NEW: allow cross-slug promo matching if synonyms overlap
1958
+ if e_ptype == "promotional_graphic":
1959
+ fam = lambda s: "canvas-tv" if "canvas-tv" in s else s
1960
+ return fam(e_base) == fam(f_base)
1961
+ # containment: allow 'et-4950' inside 'epson et-4950 bundle' etc.
1962
+ return e_base in f_base or f_base in e_base
1963
+
1964
+ results: List[ComplianceResult] = []
1965
+
1966
+ planogram_brand = planogram_description.brand.lower()
1967
+ found_brand_product = next((
1968
+ p for p in identified_products if p.brand and p.brand.lower() == planogram_brand
1969
+ ), None)
1970
+
1971
+ brand = getattr(planogram_description, 'brand', planogram_brand)
1972
+
1973
+ brand_compliance_result = BrandComplianceResult(
1974
+ expected_brand=planogram_description.brand,
1975
+ found_brand=found_brand_product.brand if found_brand_product else None,
1976
+ found=bool(found_brand_product),
1977
+ confidence=found_brand_product.confidence if found_brand_product else 0.0
1978
+ )
1979
+ brand_check_ok = brand_compliance_result.found
1980
+ by_shelf = defaultdict(list)
1981
+
1982
+ for p in identified_products:
1983
+ by_shelf[p.shelf_location].append(p)
1984
+
1985
+ for shelf_cfg in planogram_description.shelves:
1986
+ shelf_level = shelf_cfg.level
1987
+ products_on_shelf = by_shelf.get(shelf_level, [])
1988
+ expected = []
1989
+ # --- 1. Main matching loop for expected products ---
1990
+ for sp in shelf_cfg.products:
1991
+ if sp.product_type in ("fact_tag", "price_tag", "slot"):
1992
+ continue
1993
+
1994
+ e_ptype, e_base = self._canonical_expected_key(sp, brand=brand)
1995
+ expected.append((e_ptype, e_base))
1996
+
1997
+ # --- Build canonical FOUND keys for this shelf (and keep refs for reporting) ---
1998
+ found_keys = [] # list[(ptype, base_model)]
1999
+ found_lookup = [] # parallel to found_keys to map back to strings for reporting
2000
+ promos = []
2001
+ for p in products_on_shelf:
2002
+ if p.product_type in ("fact_tag", "price_tag", "slot", "brand_logo"):
2003
+ continue
2004
+ f_ptype, f_base, f_conf = self._canonical_found_key(p, brand=brand)
2005
+ found_keys.append((f_ptype, f_base))
2006
+ if p.product_type == "promotional_graphic":
2007
+ promos.append(p)
2008
+
2009
+ # for human-readable 'found_products' list later:
2010
+ label = p.product_model or p.product_type or "unknown"
2011
+ found_lookup.append((f_ptype, f_base, label))
2012
+
2013
+ # --- Matching: (ptype must match) AND (base_model equal OR base_model contained in planogram name) ---
2014
+ matched = [False] * len(expected)
2015
+ consumed = [False] * len(found_keys)
2016
+ visual_feature_scores = [] # Track visual feature matching scores
2017
+
2018
+ # Greedy 1:1 matching
2019
+ for i, ek in enumerate(expected):
2020
+ for j, fk in enumerate(found_keys):
2021
+ if matched[i] or consumed[j]:
2022
+ continue
2023
+ if _matches(ek, fk):
2024
+ matched[i] = True
2025
+ consumed[j] = True
2026
+
2027
+ # ADD VISUAL FEATURE MATCHING HERE
2028
+ # Find the corresponding ShelfProduct and IdentifiedProduct
2029
+ shelf_product = shelf_cfg.products[i] # Get the shelf product config
2030
+ identified_product = products_on_shelf[j] # Get the identified product
2031
+
2032
+ # Calculate visual feature match score
2033
+ if hasattr(shelf_product, 'visual_features') and shelf_product.visual_features:
2034
+ detected_features = getattr(identified_product, 'visual_features', []) or []
2035
+ vf_score = self._calculate_visual_feature_match(
2036
+ shelf_product.visual_features,
2037
+ detected_features
2038
+ )
2039
+ visual_feature_scores.append(vf_score)
2040
+ break
2041
+
2042
+ # Compute lists for reporting/scoring
2043
+ expected_readable = [
2044
+ f"{e_ptype}:{e_base}" if e_base else f"{e_ptype}" for (e_ptype, e_base) in expected
2045
+ ]
2046
+ found_readable = []
2047
+ for (used, (f_ptype, f_base), (_, _, original_label)) in zip(consumed, found_keys, found_lookup):
2048
+ # Keep the original label for readability but also show our canonicalization
2049
+ tag = original_label
2050
+ if f_base:
2051
+ tag = f"{original_label} [{f_ptype}:{f_base}]"
2052
+ found_readable.append(tag)
2053
+
2054
+ missing = [expected_readable[i] for i, ok in enumerate(matched) if not ok]
2055
+ # If extras not allowed, mark unexpected any unconsumed found
2056
+ unexpected = []
2057
+ if not shelf_cfg.allow_extra_products:
2058
+ for used, (f_ptype, f_base), (_, _, original_label) in zip(consumed, found_keys, found_lookup):
2059
+ if not used:
2060
+ lbl = original_label
2061
+ if f_base:
2062
+ lbl = f"{original_label} [{f_ptype}:{f_base}]"
2063
+ unexpected.append(lbl)
2064
+
2065
+ # Product score = fraction of expected matched
2066
+ basic_score = (sum(1 for ok in matched if ok) / (len(expected) or 1.0))
2067
+
2068
+ # ADD VISUAL FEATURE SCORING
2069
+ visual_feature_score = 1.0
2070
+ if visual_feature_scores:
2071
+ visual_feature_score = sum(visual_feature_scores) / len(visual_feature_scores)
2072
+
2073
+ text_results, text_score, overall_text_ok = [], 1.0, True
2074
+
2075
+ endcap = planogram_description.advertisement_endcap
2076
+ if endcap and endcap.enabled and endcap.position == shelf_level:
2077
+ if endcap.text_requirements:
2078
+ # Combine visual features from all promotional items
2079
+ all_features = []
2080
+ ocr_blocks = []
2081
+ for promo in promos:
2082
+ if getattr(promo, "visual_features", None):
2083
+ all_features.extend(promo.visual_features)
2084
+ for feat in promo.visual_features:
2085
+ if isinstance(feat, str) and feat.startswith("ocr:"):
2086
+ ocr_blocks.append(feat[4:].strip())
2087
+ # if promo have ocr_text, add that too
2088
+ ocr_text = getattr(promo.detection_box, 'ocr_text', '')
2089
+ if ocr_text:
2090
+ ocr_blocks.append(ocr_text.strip())
2091
+
2092
+ if ocr_blocks:
2093
+ ocr_norm = self._normalize_ocr_text(" ".join(ocr_blocks))
2094
+ if ocr_norm:
2095
+ all_features.append(ocr_norm)
2096
+
2097
+ # If no promotional graphics found but text required, create default failure
2098
+ if not promos and shelf_level == "header":
2099
+ self.logger.warning(
2100
+ f"No promotional graphics found on {shelf_level} shelf but text requirements exist"
2101
+ )
2102
+ overall_text_ok = False
2103
+ for text_req in endcap.text_requirements:
2104
+ text_results.append(TextComplianceResult(
2105
+ required_text=text_req.required_text,
2106
+ found=False,
2107
+ matched_features=[],
2108
+ confidence=0.0,
2109
+ match_type=text_req.match_type
2110
+ ))
2111
+ else:
2112
+ # Check text requirements against found features
2113
+ for text_req in endcap.text_requirements:
2114
+ result = TextMatcher.check_text_match(
2115
+ required_text=text_req.required_text,
2116
+ visual_features=all_features,
2117
+ match_type=text_req.match_type,
2118
+ case_sensitive=text_req.case_sensitive,
2119
+ confidence_threshold=text_req.confidence_threshold
2120
+ )
2121
+ text_results.append(result)
2122
+
2123
+ if not result.found and text_req.mandatory:
2124
+ overall_text_ok = False
2125
+
2126
+ # Calculate text compliance score
2127
+ if text_results:
2128
+ text_score = sum(r.confidence for r in text_results if r.found) / len(text_results)
2129
+
2130
+ elif shelf_level != "header":
2131
+ overall_text_ok = True
2132
+ text_score = 1.0
2133
+
2134
+ threshold = getattr(
2135
+ shelf_cfg, "compliance_threshold", planogram_description.global_compliance_threshold or 0.8
2136
+ )
2137
+
2138
+ major_unexpected = [
2139
+ p for p in unexpected if "ink" not in p.lower() and "price tag" not in p.lower()
2140
+ ]
2141
+
2142
+ # MODIFIED: Status determination logic with brand check override
2143
+ status = ComplianceStatus.NON_COMPLIANT # Default status
2144
+ if shelf_level != "header":
2145
+ if basic_score >= threshold and not major_unexpected:
2146
+ status = ComplianceStatus.COMPLIANT
2147
+ elif basic_score == 0.0 and len(expected) > 0:
2148
+ status = ComplianceStatus.MISSING
2149
+ else: # Header shelf logic
2150
+ # The brand check is now a mandatory condition for compliance
2151
+ if not brand_check_ok:
2152
+ status = ComplianceStatus.NON_COMPLIANT # OVERRIDE: Brand check failed
2153
+ elif basic_score >= threshold and not major_unexpected and overall_text_ok:
2154
+ status = ComplianceStatus.COMPLIANT
2155
+ elif basic_score == 0.0 and len(expected) > 0:
2156
+ status = ComplianceStatus.MISSING
2157
+ else:
2158
+ status = ComplianceStatus.NON_COMPLIANT
2159
+
2160
+ # MODIFIED: Combined score calculation with visual features
2161
+ # Use the existing visual_features_weight from CategoryDetectionConfig
2162
+ visual_weight = getattr(
2163
+ planogram_description,
2164
+ 'visual_features_weight',
2165
+ 0.2
2166
+ ) # Default 20%
2167
+
2168
+ if shelf_level == "header" and endcap:
2169
+ # Adjust product weight to make room for visual features
2170
+ adjusted_product_weight = endcap.product_weight * (1 - visual_weight)
2171
+ visual_feature_weight = endcap.product_weight * visual_weight
2172
+ combined_score = (
2173
+ (basic_score * adjusted_product_weight) +
2174
+ (text_score * endcap.text_weight) +
2175
+ (brand_compliance_result.confidence * getattr(endcap, "brand_weight", 0.0)) +
2176
+ (visual_feature_score * visual_feature_weight)
2177
+ )
2178
+ else:
2179
+ combined_score = (
2180
+ basic_score * (1 - visual_weight) +
2181
+ text_score * 0.1 +
2182
+ visual_feature_score * visual_weight
2183
+ )
2184
+
2185
+ # Ensure score never exceeds 1.0
2186
+ combined_score = min(1.0, max(0.0, combined_score))
2187
+ text_score = min(1.0, max(0.0, text_score))
2188
+
2189
+ # Prepare human-readable outputs
2190
+ expected = expected_readable
2191
+ found = found_readable
2192
+ results.append(
2193
+ ComplianceResult(
2194
+ shelf_level=shelf_level,
2195
+ expected_products=expected,
2196
+ found_products=found,
2197
+ missing_products=missing,
2198
+ unexpected_products=unexpected,
2199
+ compliance_status=status,
2200
+ compliance_score=combined_score,
2201
+ text_compliance_results=text_results,
2202
+ text_compliance_score=text_score,
2203
+ overall_text_compliant=overall_text_ok,
2204
+ brand_compliance_result=brand_compliance_result
2205
+ )
2206
+ )
2207
+
2208
+ return results
2209
+
2210
+ def _base_model_from_str(self, s: str, brand: str = None) -> str:
2211
+ """
2212
+ Extract normalized base model from any text, supporting multiple brands.
2213
+
2214
+ Args:
2215
+ s: String to extract model from
2216
+ brand: Optional brand hint to improve extraction
2217
+
2218
+ Returns:
2219
+ Normalized model string or empty string if no model found
2220
+ """
2221
+ if not s:
2222
+ return ""
2223
+
2224
+ t = s.lower().strip()
2225
+ # normalize separators
2226
+ t = t.replace("—", "-").replace("–", "-").replace("_", "-")
2227
+
2228
+ # Brand-specific patterns
2229
+ if brand and brand.lower() == "epson":
2230
+ # EPSON EcoTank models: ET-2980, ET-3950, ET-4950
2231
+ m = re.search(r"(et)[- ]?(\d{4})", t)
2232
+ if m:
2233
+ return f"{m.group(1)}-{m.group(2)}"
2234
+
2235
+ elif brand and brand.lower() == "hisense":
2236
+ # HISENSE TV models: U6, U7, U8, plus potential series numbers
2237
+ # Patterns: U7, U8, U6, 55U8, U7K, etc.
2238
+ if re.search(r"canvas[\s-]*tv", t):
2239
+ return "canvas-tv"
2240
+ if re.search(r"canvas", t):
2241
+ return "canvas"
2242
+ patterns = [
2243
+ r"(\d*)(u\d+)([a-z]*)", # 55U8K, U7, U8K, etc.
2244
+ r"(u\d+)", # Simple U6, U7, U8
2245
+ ]
2246
+ for pattern in patterns:
2247
+ m = re.search(pattern, t)
2248
+ if m:
2249
+ if len(m.groups()) >= 2:
2250
+ # Extract size + series + variant if available
2251
+ size = m.group(1) if m.group(1) else ""
2252
+ series = m.group(2)
2253
+ variant = m.group(3) if len(m.groups()) > 2 and m.group(3) else ""
2254
+ return f"{size}{series}{variant}".lower()
2255
+ else:
2256
+ return m.group(1).lower()
2257
+
2258
+ # Generic patterns for any brand
2259
+ generic_patterns = [
2260
+ # Model with dashes: ABC-1234, XYZ-567
2261
+ r"([a-z]+)[- ]?(\d{3,4})",
2262
+ # Series patterns: U7, U8, A6, etc.
2263
+ r"([a-z]\d+)",
2264
+ # Number-letter combinations: 4950, 2980 (for fallback)
2265
+ r"(\d{4})",
2266
+ ]
2267
+
2268
+ for pattern in generic_patterns:
2269
+ m = re.search(pattern, t)
2270
+ if m:
2271
+ if len(m.groups()) >= 2:
2272
+ return f"{m.group(1)}-{m.group(2)}"
2273
+ else:
2274
+ return m.group(1).lower()
2275
+
2276
+ return ""
2277
+
2278
+ def _looks_like_box(self, visual_features: list[str] | None) -> bool:
2279
+ """Heuristic: does the detection look like packaging?"""
2280
+ if not visual_features:
2281
+ return False
2282
+ keywords = {"packaging", "package", "cardboard", "box", "blue packaging", "printer image on box"}
2283
+ norm = " ".join(visual_features).lower()
2284
+ return any(k in norm for k in keywords)
2285
+
2286
+ def _canonical_expected_key(self, sp: str, brand: str) -> tuple[str, str]:
2287
+ """
2288
+ From planogram product spec -> (product_type, base_model).
2289
+ Example: name='ET-4950', product_type='product_box' -> ('product_box','et-4950')
2290
+ """
2291
+ ptype = (sp.product_type or "").strip().lower()
2292
+ # Normalize product types
2293
+ type_mappings = {
2294
+ "tv_demonstration": "tv",
2295
+ "promotional_graphic": "promotional_graphic",
2296
+ "product_box": "product_box",
2297
+ "printer": "printer",
2298
+ "promotional_materials": "promotional_materials"
2299
+ }
2300
+ ptype = type_mappings.get(ptype, ptype)
2301
+ model_str = getattr(sp, "name", "") or getattr(sp, "product_model", "") or ""
2302
+ base = self._base_model_from_str(model_str, brand=brand)
2303
+ return ptype or "unknown", base or ""
2304
+
2305
+ def _canonical_found_key(self, p: str, brand: str) -> tuple[str, str, float]:
2306
+ """
2307
+ From IdentifiedProduct -> (resolved_product_type, base_model, adjusted_confidence).
2308
+ If visual features scream 'box', coerce/confirm product_type as 'product_box' and boost conf a bit.
2309
+ """
2310
+ ptype = (p.product_type or "").strip().lower()
2311
+ # Normalize product types
2312
+ type_mappings = {
2313
+ "tv_demonstration": "tv",
2314
+ "promotional_graphic": "promotional_graphic",
2315
+ "product_box": "product_box",
2316
+ "printer": "printer",
2317
+ "promotional_material": "promotional_material",
2318
+ "promotional_display": "promotional_display"
2319
+ }
2320
+ ptype = type_mappings.get(ptype, ptype)
2321
+ model_str = p.product_model or p.product_type or ""
2322
+ base = self._base_model_from_str(model_str, brand=brand)
2323
+ conf = float(getattr(p, "confidence", 0.0) or 0.0)
2324
+
2325
+ if self._looks_like_box(getattr(p, "visual_features", None)):
2326
+ if ptype != "product_box":
2327
+ ptype = "product_box"
2328
+ conf = min(1.0, conf + 0.05) # gentle nudge for box evidence
2329
+ return ptype or "unknown", base or "", conf
2330
+
2331
+ async def _find_poster(
2332
+ self,
2333
+ image: Image.Image,
2334
+ planogram: PlanogramDescription,
2335
+ partial_prompt: str
2336
+ ) -> tuple[Detections, Detections, Detections, Detections]:
2337
+ """
2338
+ Ask VISION Model to find the main promotional graphic for the given brand/tags.
2339
+ Returns (x1,y1,x2,y2) in absolute pixels, and the parsed JSON for logging.
2340
+ """
2341
+ brand = (getattr(planogram, "brand", "") or "").strip()
2342
+ tags = [t.strip() for t in getattr(planogram, "tags", []) or []]
2343
+ endcap = getattr(planogram, "advertisement_endcap", None)
2344
+ geometry = self.planogram_config.endcap_geometry
2345
+ if endcap and getattr(endcap, "text_requirements", None):
2346
+ for tr in endcap.text_requirements:
2347
+ if getattr(tr, "required_text", None):
2348
+ tags.append(tr.required_text)
2349
+ tag_hint = ", ".join(sorted(set(f"'{t}'" for t in tags if t)))
2350
+
2351
+ # downscale for LLM
2352
+ image_small = self._downscale_image(image, max_side=1024, quality=78)
2353
+ prompt = partial_prompt.format(
2354
+ brand=brand,
2355
+ tag_hint=tag_hint,
2356
+ image_size=image_small.size
2357
+ )
2358
+ max_attempts = 2 # Initial attempt + 1 retry
2359
+ retry_delay_seconds = 10
2360
+ msg = None
2361
+ for attempt in range(max_attempts):
2362
+ try:
2363
+ async with self.roi_client as client:
2364
+ msg = await client.ask_to_image(
2365
+ image=image_small,
2366
+ prompt=prompt,
2367
+ model="gemini-2.5-flash",
2368
+ no_memory=True,
2369
+ structured_output=Detections,
2370
+ max_tokens=8192
2371
+ )
2372
+ # If the call succeeds, break out of the loop
2373
+ break
2374
+ except ServerError as e:
2375
+ # Check if this was the last attempt
2376
+ if attempt < max_attempts - 1:
2377
+ print(
2378
+ f"WARNING: Model is overloaded. Retrying in {retry_delay_seconds} seconds... (Attempt {attempt + 1}/{max_attempts})"
2379
+ )
2380
+ await asyncio.sleep(retry_delay_seconds)
2381
+ else:
2382
+ print(
2383
+ f"ERROR: Model is still overloaded after {max_attempts} attempts. Failing."
2384
+ )
2385
+ # Re-raise the exception if the last attempt fails
2386
+ raise e
2387
+ # Evaluate the Output:
2388
+ # print('MSG >> ', msg)
2389
+ # print('OUTPUT > ', msg.output)
2390
+ data = msg.structured_output or msg.output or {}
2391
+ dets = data.detections or []
2392
+ if not dets:
2393
+ return None, data
2394
+ # pick detections
2395
+ panel_det = next(
2396
+ (d for d in dets if d.label == "poster_panel"), None) \
2397
+ or next((d for d in dets if d.label == "poster"), None) \
2398
+ or (max(dets, key=lambda x: float(x.confidence)) if dets else None
2399
+ )
2400
+ # poster text:
2401
+ text_det = next((d for d in dets if d.label == "poster_text"), None)
2402
+ # brand logo:
2403
+ brand_det = next((d for d in dets if d.label == "brand_logo"), None)
2404
+ if not panel_det:
2405
+ self.logger.error("Critical failure: Could not detect the poster_panel.")
2406
+ return None, None, None, None
2407
+
2408
+ # promotional graphic (inside the panel):
2409
+ promo_graphic_det = next(
2410
+ (d for d in dets if d.label == "promotional_graphic"), None
2411
+ )
2412
+
2413
+ # check if promo_graphic is contained by panel_det, if not, increase the panel:
2414
+ if promo_graphic_det and panel_det:
2415
+ # If promo graphic is outside panel, expand panel to include it
2416
+ if not (
2417
+ promo_graphic_det.bbox.x1 >= panel_det.bbox.x1 and
2418
+ promo_graphic_det.bbox.x2 <= panel_det.bbox.x2
2419
+ ):
2420
+ self.logger.info("Expanding poster_panel to include promotional_graphic.")
2421
+ panel_det.bbox.x1 = min(panel_det.bbox.x1, promo_graphic_det.bbox.x1)
2422
+ panel_det.bbox.x2 = max(panel_det.bbox.x2, promo_graphic_det.bbox.x2)
2423
+
2424
+ # Get planogram advertisement config with safe defaults
2425
+ advertisement_config = getattr(planogram, "advertisement_endcap", {})
2426
+ # # Default values if not in planogram, normalized to image (not ROI)
2427
+ # config_width_percent = advertisement_config.width_margin_percent
2428
+ # config_height_percent = advertisement_config.height_margin_percent
2429
+ # config_top_margin_percent = advertisement_config.top_margin_percent
2430
+ # # E.g., 5% of panel width
2431
+ # side_margin_percent = advertisement_config.side_margin_percent
2432
+
2433
+ config_width_percent = geometry.width_margin_percent
2434
+ config_height_percent = geometry.height_margin_percent
2435
+ config_top_margin_percent = geometry.top_margin_percent
2436
+ side_margin_percent = geometry.side_margin_percent
2437
+
2438
+ # --- Refined Panel Padding ---
2439
+ # Apply padding to the panel_det itself to ensure it captures the full visual area
2440
+ panel_det.bbox.x1 = max(0.0, panel_det.bbox.x1 - side_margin_percent)
2441
+ panel_det.bbox.x2 = min(1.0, panel_det.bbox.x2 + side_margin_percent)
2442
+
2443
+ if panel_det and text_det:
2444
+ text_bottom_y2 = text_det.bbox.y2
2445
+ padding = 0.08
2446
+ new_panel_y2 = min(text_bottom_y2 + padding, 1.0)
2447
+ panel_det.bbox.y2 = new_panel_y2
2448
+
2449
+ # --- endcap Detected:
2450
+ endcap_det = next((d for d in dets if d.label == "endcap"), None)
2451
+
2452
+ # panel
2453
+ px1, py1, px2, py2 = panel_det.bbox.x1, panel_det.bbox.y1, panel_det.bbox.x2, panel_det.bbox.y2
2454
+
2455
+ # Initial endcap box: Use the LLM's endcap detection if it exists, otherwise fall back to the panel
2456
+ if endcap_det:
2457
+ ex1, ey1, ex2, ey2 = endcap_det.bbox.x1, endcap_det.bbox.y1, endcap_det.bbox.x2, endcap_det.bbox.y2
2458
+ else:
2459
+ ex1, ey1, ex2, ey2 = px1, py1, px2, py2
2460
+
2461
+ if endcap_det is None:
2462
+ panel_h = py2 - py1
2463
+ ratio = max(1e-6, float(config_height_percent))
2464
+ top_margin = float(config_top_margin_percent)
2465
+ ey1 = max(0.0, py1 - top_margin)
2466
+ ey2 = min(1.0, ey1 + panel_h / ratio)
2467
+
2468
+ x_buffer = max(self.left_margin_ratio * (px2-px1), self.right_margin_ratio * (px2-px1))
2469
+ ex1 = min(ex1, px1 - x_buffer)
2470
+ ex2 = max(ex2, px2 + x_buffer)
2471
+
2472
+ # Clamp & monotonic
2473
+ ex1 = max(0.0, ex1)
2474
+ ex2 = min(1.0, ex2)
2475
+ if ex2 <= ex1:
2476
+ ex2 = ex1 + 1e-6
2477
+ ey1 = max(0.0, ey1)
2478
+ ey2 = min(1.0, ey2)
2479
+ if ey2 <= ey1:
2480
+ ey2 = ey1 + 1e-6
2481
+
2482
+ # Update the endcap_det bbox with the corrected values
2483
+ if endcap_det is None:
2484
+ endcap_det = DetectionBox(
2485
+ x1=ex1, y1=ey1, x2=ex2, y2=ey2,
2486
+ confidence=0.9, # Assign a default confidence
2487
+ label="endcap"
2488
+ )
2489
+ else:
2490
+ endcap_det.bbox.x1 = ex1
2491
+ endcap_det.bbox.x2 = ex2
2492
+ endcap_det.bbox.y1 = ey1
2493
+ endcap_det.bbox.y2 = ey2
2494
+
2495
+ return endcap_det, panel_det, brand_det, text_det, dets
2496
+
2497
+ # Complete Pipeline
2498
+ async def run(
2499
+ self,
2500
+ image: Union[str, Path, Image.Image],
2501
+ debug_raw="/tmp/data/yolo_raw_debug.png",
2502
+ return_overlay: Optional[str] = None, # "identified" | "detections" | "both" | None
2503
+ overlay_save_path: Optional[Union[str, Path]] = None,
2504
+ ) -> Dict[str, Any]:
2505
+ """
2506
+ Run the complete 3-step planogram compliance pipeline
2507
+
2508
+ Returns:
2509
+ Complete analysis results including all steps
2510
+ """
2511
+ self.logger.debug("Step 1: Find Region of Interest...")
2512
+ # Optimize Image for Classification:
2513
+ img = self.open_image(image)
2514
+
2515
+ # ROI detection:
2516
+ img_array = np.array(img) # RGB
2517
+
2518
+ # 1) Find the poster:
2519
+ planogram_description = self.planogram_config.get_planogram_description()
2520
+ endcap, ad, brand, panel_text, dets = await self._find_poster(
2521
+ img,
2522
+ planogram_description,
2523
+ partial_prompt=self.planogram_config.roi_detection_prompt
2524
+ )
2525
+ if return_overlay == 'detections' or return_overlay == 'both':
2526
+ debug_poster_path = debug_raw.replace(".png", "_poster_debug.png") if debug_raw else None
2527
+ panel_px = ad.bbox.get_coordinates()
2528
+ self._save_detections(
2529
+ image, panel_px, dets, debug_poster_path
2530
+ )
2531
+ # Check if detections are valid before proceeding
2532
+ if not endcap or not ad:
2533
+ print("ERROR: Failed to get required detections.")
2534
+ return # or raise an exception
2535
+
2536
+ # Locate Shelves and Objects:
2537
+ shelf_regions, detections = await self.detect_objects_and_shelves(
2538
+ image,
2539
+ img_array,
2540
+ endcap=endcap,
2541
+ ad=ad,
2542
+ brand=brand,
2543
+ panel_text=panel_text,
2544
+ planogram_description=planogram_description
2545
+ )
2546
+
2547
+ self.logger.debug(
2548
+ f"Found {len(detections)} objects in {len(shelf_regions)} shelf regions"
2549
+ )
2550
+
2551
+ self.logger.notice("Step 2: Identifying objects with LLM...")
2552
+ identified_products = await self.identify_objects_with_references(
2553
+ image,
2554
+ detections,
2555
+ shelf_regions,
2556
+ self.reference_images,
2557
+ prompt=self.planogram_config.object_identification_prompt
2558
+ )
2559
+
2560
+ self.logger.debug(
2561
+ f"Identified Products: {identified_products}"
2562
+ )
2563
+
2564
+ compliance_results = self.check_planogram_compliance(
2565
+ identified_products, planogram_description
2566
+ )
2567
+
2568
+ # Calculate overall compliance
2569
+ total_score = sum(
2570
+ r.compliance_score for r in compliance_results
2571
+ ) / len(compliance_results) if compliance_results else 0.0
2572
+ if total_score >= (planogram_description.global_compliance_threshold or 0.8):
2573
+ overall_compliant = True
2574
+ else:
2575
+ overall_compliant = all(
2576
+ r.compliance_status == ComplianceStatus.COMPLIANT for r in compliance_results
2577
+ )
2578
+ overlay_image = None
2579
+ overlay_path = None
2580
+ if return_overlay == 'identified' or return_overlay == 'both':
2581
+ try:
2582
+ overlay_image = self.render_evaluated_image(
2583
+ image,
2584
+ shelf_regions=shelf_regions,
2585
+ detections=detections,
2586
+ identified_products=identified_products,
2587
+ mode=return_overlay,
2588
+ show_shelves=True,
2589
+ save_to=overlay_save_path,
2590
+ )
2591
+ if overlay_save_path:
2592
+ overlay_path = str(Path(overlay_save_path))
2593
+ except Exception as e:
2594
+ self.logger.error(f"Failed to render overlay image: {e}")
2595
+ # is not mandatory to fail the whole pipeline
2596
+ overlay_image = None
2597
+ overlay_path = None
2598
+
2599
+ return {
2600
+ "step1_detections": detections,
2601
+ "step1_shelf_regions": shelf_regions,
2602
+ "step2_identified_products": identified_products,
2603
+ "step3_compliance_results": compliance_results,
2604
+ "overall_compliance_score": total_score,
2605
+ "overall_compliant": overall_compliant,
2606
+ "analysis_timestamp": datetime.now(),
2607
+ "overlay_image": overlay_image,
2608
+ "overlay_path": overlay_path,
2609
+ }
2610
+
2611
    def render_evaluated_image(
        self,
        image: Union[str, Path, Image.Image],
        *,
        shelf_regions: Optional[List[ShelfRegion]] = None,
        detections: Optional[List[DetectionBox]] = None,
        identified_products: Optional[List[IdentifiedProduct]] = None,
        mode: str = "identified",
        show_shelves: bool = True,
        save_to: Optional[Union[str, Path]] = None,
    ) -> Image.Image:
        """
        Enhanced render with safe coordinate handling.

        Draws shelf regions, raw detections and/or identified products onto a
        copy of the input image, plus a small color legend.  All drawing is
        best-effort: individual failures are reported via print() and skipped
        so a single bad box never aborts the render.

        Args:
            image: Path or PIL image to draw on (never modified in place).
            shelf_regions: Shelf boxes to outline when show_shelves is True.
            detections: Raw detection boxes (drawn in "detections"/"both" mode).
            identified_products: Identified products (drawn in
                "identified"/"both" mode).
            mode: Which layers to draw: "identified", "detections" or "both".
            show_shelves: Whether to outline shelf regions.
            save_to: Optional path; parent directories are created as needed.

        Returns:
            The annotated PIL image (also saved to save_to when given).
        """
        def _norm_box(x1, y1, x2, y2):
            """Normalize box coordinates to ensure valid rectangle"""
            x1, x2 = int(x1), int(x2)
            y1, y2 = int(y1), int(y2)

            # Ensure coordinates are in correct order
            if x1 > x2:
                x1, x2 = x2, x1
            if y1 > y2:
                y1, y2 = y2, y1

            # Ensure minimum size (PIL rejects zero-area rectangles)
            if x2 - x1 < 1:
                x2 = x1 + 1
            if y2 - y1 < 1:
                y2 = y1 + 1

            return x1, y1, x2, y2

        # Get base image (always work on a copy so the caller's image is safe)
        if isinstance(image, (str, Path)):
            base = Image.open(image).convert("RGB").copy()
        else:
            base = image.convert("RGB").copy()

        draw = ImageDraw.Draw(base)
        try:
            font = ImageFont.load_default()
        except Exception:
            # No usable font: _txt falls back to PIL's built-in rendering
            font = None

        W, H = base.size

        def _clip(x1, y1, x2, y2):
            """Clip coordinates to image bounds"""
            return max(0, x1), max(0, y1), min(W-1, x2), min(H-1, y2)

        def _txt(draw_obj, xy, text, fill, bg=None):
            """Safe text drawing with error handling"""
            try:
                if not font:
                    draw_obj.text(xy, text, fill=fill)
                    return
                # Optionally paint a solid background behind the label
                bbox = draw_obj.textbbox(xy, text, font=font)
                if bg is not None:
                    draw_obj.rectangle(bbox, fill=bg)
                draw_obj.text(xy, text, fill=fill, font=font)
            except Exception:
                # Fallback to simple text if there's any error
                try:
                    draw_obj.text(xy, text, fill=fill)
                except Exception:
                    pass  # Skip this text if it still fails

        # Colors per product type (RGB)
        colors = {
            "tv_demonstration": (0, 255, 0),       # green for TVs
            "promotional_graphic": (255, 0, 255),  # magenta for logos
            "promotional_base": (0, 0, 255),       # blue for partner branding
            "fact_tag": (255, 255, 0),             # yellow for info displays
            "product_box": (255, 128, 0),          # orange
            "printer": (255, 0, 0),                # red
            "unknown": (200, 200, 200),            # gray
        }

        # Draw shelves (yellow outlines with a level label)
        if show_shelves and shelf_regions:
            for sr in shelf_regions:
                try:
                    x1, y1, x2, y2 = _clip(sr.bbox.x1, sr.bbox.y1, sr.bbox.x2, sr.bbox.y2)
                    x1, y1, x2, y2 = _norm_box(x1, y1, x2, y2)
                    draw.rectangle([x1, y1, x2, y2], outline=(255, 255, 0), width=3)
                    _txt(draw, (x1+3, max(0, y1-14)), f"SHELF {sr.level}", fill=(0, 0, 0), bg=(255, 255, 0))
                except Exception as e:
                    print(f"Warning: Could not draw shelf {sr.level}: {e}")

        # Draw detections (thin red outlines, 1-based IDs)
        if mode in ("detections", "both") and detections:
            for i, d in enumerate(detections, start=1):
                try:
                    x1, y1, x2, y2 = _clip(d.x1, d.y1, d.x2, d.y2)
                    x1, y1, x2, y2 = _norm_box(x1, y1, x2, y2)
                    draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), width=2)
                    lbl = f"ID:{i} {d.class_name} {d.confidence:.2f}"
                    _txt(draw, (x1+2, max(0, y1-12)), lbl, fill=(0, 0, 0), bg=(255, 0, 0))
                except Exception as e:
                    print(f"Warning: Could not draw detection {i}: {e}")

        # Draw identified products (thick outlines, largest boxes first so
        # smaller boxes stay visible on top)
        if mode in ("identified", "both") and identified_products:
            for p in sorted(identified_products, key=lambda x: (x.detection_box.area if x.detection_box else 0), reverse=True):
                if not p.detection_box:
                    continue
                try:
                    x1, y1, x2, y2 = _clip(p.detection_box.x1, p.detection_box.y1, p.detection_box.x2, p.detection_box.y2)
                    x1, y1, x2, y2 = _norm_box(x1, y1, x2, y2)

                    c = colors.get(p.product_type, (255, 0, 255))
                    draw.rectangle([x1, y1, x2, y2], outline=c, width=5)

                    # Label: "#<id> <type> <model> (<confidence>)"
                    pid = p.detection_id if p.detection_id is not None else "NEW"
                    mm = f" {p.product_model}" if p.product_model else ""
                    lab = f"#{pid} {p.product_type}{mm} ({p.confidence:.2f})"
                    _txt(draw, (x1+3, max(0, y1-14)), lab, fill=(0, 0, 0), bg=c)

                except Exception as e:
                    print(f"Warning: Could not draw product {p.product_model}: {e}")

        # Add legend (top-left color swatches for the main product types)
        legend_y = 8
        for key in ("tv_demonstration", "promotional_graphic", "promotional_base", "fact_tag"):
            if key in colors:
                try:
                    c = colors[key]
                    draw.rectangle([8, legend_y, 28, legend_y+10], fill=c)
                    _txt(draw, (34, legend_y-2), key, fill=(255,255,255))
                    legend_y += 14
                except Exception:
                    pass

        # Save if requested (create parent directories, best-effort)
        if save_to:
            try:
                save_to = Path(save_to)
                save_to.parent.mkdir(parents=True, exist_ok=True)
                base.save(save_to, quality=90)
                print(f"Overlay saved to: {save_to}")
            except Exception as e:
                print(f"Warning: Could not save overlay: {e}")

        return base
2757
+
2758
+ def generate_compliance_json(self, results: Dict[str, Any]) -> Dict[str, Any]:
2759
+ """
2760
+ Generate comprehensive JSON report from pipeline results.
2761
+
2762
+ Args:
2763
+ results: Complete results object from pipeline.run()
2764
+
2765
+ Returns:
2766
+ Dictionary containing comprehensive compliance report
2767
+ """
2768
+ compliance_results = results['step3_compliance_results']
2769
+
2770
+ def serialize_compliance_result(result) -> Dict[str, Any]:
2771
+ """Convert ComplianceResult to serializable dictionary."""
2772
+ result_dict = {
2773
+ "shelf_level": result.shelf_level,
2774
+ "compliance_status": result.compliance_status.value,
2775
+ "compliance_score": round(result.compliance_score, 3),
2776
+ "expected_products": result.expected_products,
2777
+ "found_products": result.found_products,
2778
+ "missing_products": result.missing_products,
2779
+ "unexpected_products": result.unexpected_products,
2780
+ "text_compliance": {
2781
+ "score": round(result.text_compliance_score, 3),
2782
+ "overall_compliant": result.overall_text_compliant,
2783
+ "requirements": []
2784
+ }
2785
+ }
2786
+
2787
+ # Add text compliance details
2788
+ for text_result in result.text_compliance_results:
2789
+ text_dict = {
2790
+ "required_text": text_result.required_text,
2791
+ "found": text_result.found,
2792
+ "confidence": round(text_result.confidence, 3),
2793
+ "match_type": text_result.match_type,
2794
+ "matched_features": text_result.matched_features
2795
+ }
2796
+ result_dict["text_compliance"]["requirements"].append(text_dict)
2797
+
2798
+ # Add brand compliance if present
2799
+ if hasattr(result, 'brand_compliance_result') and result.brand_compliance_result:
2800
+ result_dict["brand_compliance"] = {
2801
+ "expected_brand": result.brand_compliance_result.expected_brand,
2802
+ "found_brand": result.brand_compliance_result.found_brand,
2803
+ "found": result.brand_compliance_result.found,
2804
+ "confidence": round(result.brand_compliance_result.confidence, 3)
2805
+ }
2806
+
2807
+ return result_dict
2808
+
2809
+ # Build the main report structure
2810
+ report = {
2811
+ "metadata": {
2812
+ "analysis_timestamp": results['analysis_timestamp'].isoformat(),
2813
+ "report_version": "1.0",
2814
+ "total_shelves_analyzed": len(compliance_results)
2815
+ },
2816
+ "overall_compliance": {
2817
+ "compliant": results['overall_compliant'],
2818
+ "score": round(results['overall_compliance_score'], 3),
2819
+ "percentage": f"{results['overall_compliance_score']:.1%}"
2820
+ },
2821
+ "shelf_results": [serialize_compliance_result(result) for result in compliance_results],
2822
+ "summary": {
2823
+ "compliant_shelves": sum(1 for r in compliance_results if r.compliance_status.value == "compliant"),
2824
+ "non_compliant_shelves": sum(1 for r in compliance_results if r.compliance_status.value == "non_compliant"),
2825
+ "missing_shelves": sum(1 for r in compliance_results if r.compliance_status.value == "missing"),
2826
+ "average_shelf_score": round(sum(r.compliance_score for r in compliance_results) / len(compliance_results), 3) if compliance_results else 0.0
2827
+ }
2828
+ }
2829
+
2830
+ # Add overlay path if provided
2831
+ if 'overlay_path' in results and results['overlay_path']:
2832
+ report["artifacts"] = {
2833
+ "overlay_image_path": str(results['overlay_path'])
2834
+ }
2835
+
2836
+ return report
2837
+
2838
+ def generate_compliance_markdown(
2839
+ self,
2840
+ results: Dict[str, Any],
2841
+ brand_name: Optional[str] = None,
2842
+ additional_notes: Optional[str] = None
2843
+ ) -> str:
2844
+ """
2845
+ Generate comprehensive Markdown report from pipeline results.
2846
+
2847
+ Args:
2848
+ results: Complete results object from pipeline.run()
2849
+ brand_name: Brand being analyzed (optional)
2850
+ additional_notes: Additional notes to include (optional)
2851
+
2852
+ Returns:
2853
+ Formatted Markdown string
2854
+ """
2855
+ compliance_results = results['step3_compliance_results']
2856
+ overall_compliance_score = results['overall_compliance_score']
2857
+ overall_compliant = results['overall_compliant']
2858
+ analysis_timestamp = results['analysis_timestamp']
2859
+ overlay_path = results.get('overlay_path')
2860
+
2861
+ def status_emoji(status: str) -> str:
2862
+ """Get emoji for compliance status."""
2863
+ status_map = {
2864
+ "compliant": "✅",
2865
+ "non_compliant": "❌",
2866
+ "missing": "⚠️",
2867
+ "misplaced": "🔄"
2868
+ }
2869
+ return status_map.get(status, "❓")
2870
+
2871
+ def format_percentage(score: float) -> str:
2872
+ """Format score as percentage."""
2873
+ return f"{score:.1%}"
2874
+
2875
+ # Start building the markdown
2876
+ lines = []
2877
+
2878
+ # Header
2879
+ brand_title = f" - {brand_name}" if brand_name else ""
2880
+ lines.append(f"# Planogram Compliance Report{brand_title}")
2881
+ lines.append("")
2882
+ lines.append(
2883
+ f"**Analysis Date:** {analysis_timestamp.strftime('%Y-%m-%d %H:%M:%S')}"
2884
+ )
2885
+ lines.append("")
2886
+
2887
+ # Overall Compliance Section
2888
+ overall_emoji = "✅" if overall_compliant else "❌"
2889
+ lines.append("## Overall Compliance")
2890
+ lines.append("")
2891
+ lines.append(f"**Status:** {overall_emoji} {'COMPLIANT' if overall_compliant else 'NON-COMPLIANT'}")
2892
+ lines.append(f"**Score:** {format_percentage(overall_compliance_score)}")
2893
+ lines.append("")
2894
+
2895
+ # Summary Statistics
2896
+ compliant_count = sum(1 for r in compliance_results if r.compliance_status.value == "compliant")
2897
+ total_count = len(compliance_results)
2898
+
2899
+ lines.append("## Summary")
2900
+ lines.append("")
2901
+ lines.append(f"- **Total Shelves:** {total_count}")
2902
+ lines.append(f"- **Compliant Shelves:** {compliant_count}/{total_count}")
2903
+ lines.append(f"- **Non-Compliant Shelves:** {total_count - compliant_count}/{total_count}")
2904
+
2905
+ if compliance_results:
2906
+ avg_score = sum(r.compliance_score for r in compliance_results) / len(compliance_results)
2907
+ lines.append(f"- **Average Shelf Score:** {format_percentage(avg_score)}")
2908
+ lines.append("")
2909
+
2910
+ # Detailed Shelf Results
2911
+ lines.append("## Detailed Results by Shelf")
2912
+ lines.append("")
2913
+
2914
+ for result in compliance_results:
2915
+ shelf_emoji = status_emoji(result.compliance_status.value)
2916
+ lines.append(f"### {result.shelf_level.upper().replace('_', ' ')}")
2917
+ lines.append("")
2918
+ lines.append(f"**Status:** {shelf_emoji} {result.compliance_status.value.upper()}")
2919
+ lines.append(f"**Score:** {format_percentage(result.compliance_score)}")
2920
+ lines.append("")
2921
+
2922
+ # Products
2923
+ lines.append("**Expected Products:**")
2924
+ for product in result.expected_products:
2925
+ lines.append(f"- {product}")
2926
+ lines.append("")
2927
+
2928
+ lines.append("**Found Products:**")
2929
+ if result.found_products:
2930
+ for product in result.found_products:
2931
+ lines.append(f"- {product}")
2932
+ else:
2933
+ lines.append("- *(None)*")
2934
+ lines.append("")
2935
+
2936
+ # Missing/Unexpected
2937
+ if result.missing_products:
2938
+ lines.append("**Missing Products:**")
2939
+ for product in result.missing_products:
2940
+ lines.append(f"- ❌ {product}")
2941
+ lines.append("")
2942
+
2943
+ if result.unexpected_products:
2944
+ lines.append("**Unexpected Products:**")
2945
+ for product in result.unexpected_products:
2946
+ lines.append(f"- ⚠️ {product}")
2947
+ lines.append("")
2948
+
2949
+ # Text Compliance
2950
+ if result.text_compliance_results:
2951
+ text_emoji = "✅" if result.overall_text_compliant else "❌"
2952
+ lines.append(f"**Text Compliance:** {text_emoji} {format_percentage(result.text_compliance_score)}")
2953
+ lines.append("")
2954
+
2955
+ for text_result in result.text_compliance_results:
2956
+ req_emoji = "✅" if text_result.found else "❌"
2957
+ lines.append(f"- {req_emoji} '{text_result.required_text}' (confidence: {text_result.confidence:.2f})")
2958
+ if text_result.matched_features:
2959
+ lines.append(f" - Matched: {', '.join(text_result.matched_features)}")
2960
+ lines.append("")
2961
+
2962
+ # Brand Compliance - only show on promotional graphic shelves
2963
+ if (hasattr(result, 'brand_compliance_result') and
2964
+ result.brand_compliance_result and
2965
+ 'promotional_graphic' in str(result.expected_products).lower()):
2966
+ brand_emoji = "✅" if result.brand_compliance_result.found else "❌"
2967
+ lines.append(f"**Brand Compliance:** {brand_emoji}")
2968
+ lines.append(f"- Expected: {result.brand_compliance_result.expected_brand}")
2969
+ if result.brand_compliance_result.found_brand:
2970
+ lines.append(f"- Found: {result.brand_compliance_result.found_brand}")
2971
+ lines.append(f"- Confidence: {result.brand_compliance_result.confidence:.2f}")
2972
+ else:
2973
+ lines.append("- Found: *(None)*")
2974
+ lines.append("")
2975
+
2976
+ lines.append("---")
2977
+ lines.append("")
2978
+
2979
+ # Artifacts Section
2980
+ if overlay_path:
2981
+ lines.append("## Analysis Artifacts")
2982
+ lines.append("")
2983
+ lines.append(f"**Overlay Image:** `{overlay_path}`")
2984
+ lines.append("")
2985
+
2986
+ # Add image link if it's a web-accessible path
2987
+ if str(overlay_path).startswith(('http://', 'https://')):
2988
+ lines.append(f"![Compliance Overlay]({overlay_path})")
2989
+ lines.append("")
2990
+
2991
+ # Additional Notes
2992
+ if additional_notes:
2993
+ lines.append("## Additional Notes")
2994
+ lines.append("")
2995
+ lines.append(additional_notes)
2996
+ lines.append("")
2997
+
2998
+ # Footer
2999
+ lines.append("---")
3000
+ lines.append("*Report generated by AI-Parrot Planogram Compliance Pipeline*")
3001
+
3002
+ return '\n'.join(lines)