ai-parrot 0.17.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentui/.prettierrc +15 -0
- agentui/QUICKSTART.md +272 -0
- agentui/README.md +59 -0
- agentui/env.example +16 -0
- agentui/jsconfig.json +14 -0
- agentui/package-lock.json +4242 -0
- agentui/package.json +34 -0
- agentui/scripts/postinstall/apply-patches.mjs +260 -0
- agentui/src/app.css +61 -0
- agentui/src/app.d.ts +13 -0
- agentui/src/app.html +12 -0
- agentui/src/components/LoadingSpinner.svelte +64 -0
- agentui/src/components/ThemeSwitcher.svelte +159 -0
- agentui/src/components/index.js +4 -0
- agentui/src/lib/api/bots.ts +60 -0
- agentui/src/lib/api/chat.ts +22 -0
- agentui/src/lib/api/http.ts +25 -0
- agentui/src/lib/components/BotCard.svelte +33 -0
- agentui/src/lib/components/ChatBubble.svelte +63 -0
- agentui/src/lib/components/Toast.svelte +21 -0
- agentui/src/lib/config.ts +20 -0
- agentui/src/lib/stores/auth.svelte.ts +73 -0
- agentui/src/lib/stores/theme.svelte.js +64 -0
- agentui/src/lib/stores/toast.svelte.ts +31 -0
- agentui/src/lib/utils/conversation.ts +39 -0
- agentui/src/routes/+layout.svelte +20 -0
- agentui/src/routes/+page.svelte +232 -0
- agentui/src/routes/login/+page.svelte +200 -0
- agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
- agentui/src/routes/talk/[agentId]/+page.ts +7 -0
- agentui/static/README.md +1 -0
- agentui/svelte.config.js +11 -0
- agentui/tailwind.config.ts +53 -0
- agentui/tsconfig.json +3 -0
- agentui/vite.config.ts +10 -0
- ai_parrot-0.17.2.dist-info/METADATA +472 -0
- ai_parrot-0.17.2.dist-info/RECORD +535 -0
- ai_parrot-0.17.2.dist-info/WHEEL +6 -0
- ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
- ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
- ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
- crew-builder/.prettierrc +15 -0
- crew-builder/QUICKSTART.md +259 -0
- crew-builder/README.md +113 -0
- crew-builder/env.example +17 -0
- crew-builder/jsconfig.json +14 -0
- crew-builder/package-lock.json +4182 -0
- crew-builder/package.json +37 -0
- crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
- crew-builder/src/app.css +62 -0
- crew-builder/src/app.d.ts +13 -0
- crew-builder/src/app.html +12 -0
- crew-builder/src/components/LoadingSpinner.svelte +64 -0
- crew-builder/src/components/ThemeSwitcher.svelte +149 -0
- crew-builder/src/components/index.js +9 -0
- crew-builder/src/lib/api/bots.ts +60 -0
- crew-builder/src/lib/api/chat.ts +80 -0
- crew-builder/src/lib/api/client.ts +56 -0
- crew-builder/src/lib/api/crew/crew.ts +136 -0
- crew-builder/src/lib/api/index.ts +5 -0
- crew-builder/src/lib/api/o365/auth.ts +65 -0
- crew-builder/src/lib/auth/auth.ts +54 -0
- crew-builder/src/lib/components/AgentNode.svelte +43 -0
- crew-builder/src/lib/components/BotCard.svelte +33 -0
- crew-builder/src/lib/components/ChatBubble.svelte +67 -0
- crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
- crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
- crew-builder/src/lib/components/JsonViewer.svelte +24 -0
- crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
- crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
- crew-builder/src/lib/components/Toast.svelte +67 -0
- crew-builder/src/lib/components/Toolbar.svelte +157 -0
- crew-builder/src/lib/components/index.ts +10 -0
- crew-builder/src/lib/config.ts +8 -0
- crew-builder/src/lib/stores/auth.svelte.ts +228 -0
- crew-builder/src/lib/stores/crewStore.ts +369 -0
- crew-builder/src/lib/stores/theme.svelte.js +145 -0
- crew-builder/src/lib/stores/toast.svelte.ts +69 -0
- crew-builder/src/lib/utils/conversation.ts +39 -0
- crew-builder/src/lib/utils/markdown.ts +122 -0
- crew-builder/src/lib/utils/talkHistory.ts +47 -0
- crew-builder/src/routes/+layout.svelte +20 -0
- crew-builder/src/routes/+page.svelte +539 -0
- crew-builder/src/routes/agents/+page.svelte +247 -0
- crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
- crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
- crew-builder/src/routes/builder/+page.svelte +204 -0
- crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
- crew-builder/src/routes/crew/ask/+page.ts +1 -0
- crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
- crew-builder/src/routes/login/+page.svelte +197 -0
- crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
- crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
- crew-builder/static/README.md +1 -0
- crew-builder/svelte.config.js +11 -0
- crew-builder/tailwind.config.ts +53 -0
- crew-builder/tsconfig.json +3 -0
- crew-builder/vite.config.ts +10 -0
- mcp_servers/calculator_server.py +309 -0
- parrot/__init__.py +27 -0
- parrot/__pycache__/__init__.cpython-310.pyc +0 -0
- parrot/__pycache__/version.cpython-310.pyc +0 -0
- parrot/_version.py +34 -0
- parrot/a2a/__init__.py +48 -0
- parrot/a2a/client.py +658 -0
- parrot/a2a/discovery.py +89 -0
- parrot/a2a/mixin.py +257 -0
- parrot/a2a/models.py +376 -0
- parrot/a2a/server.py +770 -0
- parrot/agents/__init__.py +29 -0
- parrot/bots/__init__.py +12 -0
- parrot/bots/a2a_agent.py +19 -0
- parrot/bots/abstract.py +3139 -0
- parrot/bots/agent.py +1129 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/chatbot.py +669 -0
- parrot/bots/data.py +1618 -0
- parrot/bots/database/__init__.py +5 -0
- parrot/bots/database/abstract.py +3071 -0
- parrot/bots/database/cache.py +286 -0
- parrot/bots/database/models.py +468 -0
- parrot/bots/database/prompts.py +154 -0
- parrot/bots/database/retries.py +98 -0
- parrot/bots/database/router.py +269 -0
- parrot/bots/database/sql.py +41 -0
- parrot/bots/db/__init__.py +6 -0
- parrot/bots/db/abstract.py +556 -0
- parrot/bots/db/bigquery.py +602 -0
- parrot/bots/db/cache.py +85 -0
- parrot/bots/db/documentdb.py +668 -0
- parrot/bots/db/elastic.py +1014 -0
- parrot/bots/db/influx.py +898 -0
- parrot/bots/db/mock.py +96 -0
- parrot/bots/db/multi.py +783 -0
- parrot/bots/db/prompts.py +185 -0
- parrot/bots/db/sql.py +1255 -0
- parrot/bots/db/tools.py +212 -0
- parrot/bots/document.py +680 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/kb.py +170 -0
- parrot/bots/mcp.py +36 -0
- parrot/bots/orchestration/README.md +463 -0
- parrot/bots/orchestration/__init__.py +1 -0
- parrot/bots/orchestration/agent.py +155 -0
- parrot/bots/orchestration/crew.py +3330 -0
- parrot/bots/orchestration/fsm.py +1179 -0
- parrot/bots/orchestration/hr.py +434 -0
- parrot/bots/orchestration/storage/__init__.py +4 -0
- parrot/bots/orchestration/storage/memory.py +100 -0
- parrot/bots/orchestration/storage/mixin.py +119 -0
- parrot/bots/orchestration/verify.py +202 -0
- parrot/bots/product.py +204 -0
- parrot/bots/prompts/__init__.py +96 -0
- parrot/bots/prompts/agents.py +155 -0
- parrot/bots/prompts/data.py +216 -0
- parrot/bots/prompts/output_generation.py +8 -0
- parrot/bots/scraper/__init__.py +3 -0
- parrot/bots/scraper/models.py +122 -0
- parrot/bots/scraper/scraper.py +1173 -0
- parrot/bots/scraper/templates.py +115 -0
- parrot/bots/stores/__init__.py +5 -0
- parrot/bots/stores/local.py +172 -0
- parrot/bots/webdev.py +81 -0
- parrot/cli.py +17 -0
- parrot/clients/__init__.py +16 -0
- parrot/clients/base.py +1491 -0
- parrot/clients/claude.py +1191 -0
- parrot/clients/factory.py +129 -0
- parrot/clients/google.py +4567 -0
- parrot/clients/gpt.py +1975 -0
- parrot/clients/grok.py +432 -0
- parrot/clients/groq.py +986 -0
- parrot/clients/hf.py +582 -0
- parrot/clients/models.py +18 -0
- parrot/conf.py +395 -0
- parrot/embeddings/__init__.py +9 -0
- parrot/embeddings/base.py +157 -0
- parrot/embeddings/google.py +98 -0
- parrot/embeddings/huggingface.py +74 -0
- parrot/embeddings/openai.py +84 -0
- parrot/embeddings/processor.py +88 -0
- parrot/exceptions.c +13868 -0
- parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/exceptions.pxd +22 -0
- parrot/exceptions.pxi +15 -0
- parrot/exceptions.pyx +44 -0
- parrot/generators/__init__.py +29 -0
- parrot/generators/base.py +200 -0
- parrot/generators/html.py +293 -0
- parrot/generators/react.py +205 -0
- parrot/generators/streamlit.py +203 -0
- parrot/generators/template.py +105 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agent.py +861 -0
- parrot/handlers/agents/__init__.py +1 -0
- parrot/handlers/agents/abstract.py +900 -0
- parrot/handlers/bots.py +338 -0
- parrot/handlers/chat.py +915 -0
- parrot/handlers/creation.sql +192 -0
- parrot/handlers/crew/ARCHITECTURE.md +362 -0
- parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
- parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
- parrot/handlers/crew/__init__.py +0 -0
- parrot/handlers/crew/handler.py +801 -0
- parrot/handlers/crew/models.py +229 -0
- parrot/handlers/crew/redis_persistence.py +523 -0
- parrot/handlers/jobs/__init__.py +10 -0
- parrot/handlers/jobs/job.py +384 -0
- parrot/handlers/jobs/mixin.py +627 -0
- parrot/handlers/jobs/models.py +115 -0
- parrot/handlers/jobs/worker.py +31 -0
- parrot/handlers/models.py +596 -0
- parrot/handlers/o365_auth.py +105 -0
- parrot/handlers/stream.py +337 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/aws.py +143 -0
- parrot/interfaces/credentials.py +113 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/google.py +1123 -0
- parrot/interfaces/hierarchy.py +1227 -0
- parrot/interfaces/http.py +651 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +24 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/analisys.py +148 -0
- parrot/interfaces/images/plugins/classify.py +150 -0
- parrot/interfaces/images/plugins/classifybase.py +182 -0
- parrot/interfaces/images/plugins/detect.py +150 -0
- parrot/interfaces/images/plugins/exif.py +1103 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/interfaces/o365.py +978 -0
- parrot/interfaces/onedrive.py +822 -0
- parrot/interfaces/sharepoint.py +1435 -0
- parrot/interfaces/soap.py +257 -0
- parrot/loaders/__init__.py +8 -0
- parrot/loaders/abstract.py +1131 -0
- parrot/loaders/audio.py +199 -0
- parrot/loaders/basepdf.py +53 -0
- parrot/loaders/basevideo.py +1568 -0
- parrot/loaders/csv.py +409 -0
- parrot/loaders/docx.py +116 -0
- parrot/loaders/epubloader.py +316 -0
- parrot/loaders/excel.py +199 -0
- parrot/loaders/factory.py +55 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/html.py +26 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/html.py +152 -0
- parrot/loaders/markdown.py +442 -0
- parrot/loaders/pdf.py +373 -0
- parrot/loaders/pdfmark.py +320 -0
- parrot/loaders/pdftables.py +506 -0
- parrot/loaders/ppt.py +476 -0
- parrot/loaders/qa.py +63 -0
- parrot/loaders/splitters/__init__.py +10 -0
- parrot/loaders/splitters/base.py +138 -0
- parrot/loaders/splitters/md.py +228 -0
- parrot/loaders/splitters/token.py +143 -0
- parrot/loaders/txt.py +26 -0
- parrot/loaders/video.py +89 -0
- parrot/loaders/videolocal.py +218 -0
- parrot/loaders/videounderstanding.py +377 -0
- parrot/loaders/vimeo.py +167 -0
- parrot/loaders/web.py +599 -0
- parrot/loaders/youtube.py +504 -0
- parrot/manager/__init__.py +5 -0
- parrot/manager/manager.py +1030 -0
- parrot/mcp/__init__.py +28 -0
- parrot/mcp/adapter.py +105 -0
- parrot/mcp/cli.py +174 -0
- parrot/mcp/client.py +119 -0
- parrot/mcp/config.py +75 -0
- parrot/mcp/integration.py +842 -0
- parrot/mcp/oauth.py +933 -0
- parrot/mcp/server.py +225 -0
- parrot/mcp/transports/__init__.py +3 -0
- parrot/mcp/transports/base.py +279 -0
- parrot/mcp/transports/grpc_session.py +163 -0
- parrot/mcp/transports/http.py +312 -0
- parrot/mcp/transports/mcp.proto +108 -0
- parrot/mcp/transports/quic.py +1082 -0
- parrot/mcp/transports/sse.py +330 -0
- parrot/mcp/transports/stdio.py +309 -0
- parrot/mcp/transports/unix.py +395 -0
- parrot/mcp/transports/websocket.py +547 -0
- parrot/memory/__init__.py +16 -0
- parrot/memory/abstract.py +209 -0
- parrot/memory/agent.py +32 -0
- parrot/memory/cache.py +175 -0
- parrot/memory/core.py +555 -0
- parrot/memory/file.py +153 -0
- parrot/memory/mem.py +131 -0
- parrot/memory/redis.py +613 -0
- parrot/models/__init__.py +46 -0
- parrot/models/basic.py +118 -0
- parrot/models/compliance.py +208 -0
- parrot/models/crew.py +395 -0
- parrot/models/detections.py +654 -0
- parrot/models/generation.py +85 -0
- parrot/models/google.py +223 -0
- parrot/models/groq.py +23 -0
- parrot/models/openai.py +30 -0
- parrot/models/outputs.py +285 -0
- parrot/models/responses.py +938 -0
- parrot/notifications/__init__.py +743 -0
- parrot/openapi/__init__.py +3 -0
- parrot/openapi/components.yaml +641 -0
- parrot/openapi/config.py +322 -0
- parrot/outputs/__init__.py +32 -0
- parrot/outputs/formats/__init__.py +108 -0
- parrot/outputs/formats/altair.py +359 -0
- parrot/outputs/formats/application.py +122 -0
- parrot/outputs/formats/base.py +351 -0
- parrot/outputs/formats/bokeh.py +356 -0
- parrot/outputs/formats/card.py +424 -0
- parrot/outputs/formats/chart.py +436 -0
- parrot/outputs/formats/d3.py +255 -0
- parrot/outputs/formats/echarts.py +310 -0
- parrot/outputs/formats/generators/__init__.py +0 -0
- parrot/outputs/formats/generators/abstract.py +61 -0
- parrot/outputs/formats/generators/panel.py +145 -0
- parrot/outputs/formats/generators/streamlit.py +86 -0
- parrot/outputs/formats/generators/terminal.py +63 -0
- parrot/outputs/formats/holoviews.py +310 -0
- parrot/outputs/formats/html.py +147 -0
- parrot/outputs/formats/jinja2.py +46 -0
- parrot/outputs/formats/json.py +87 -0
- parrot/outputs/formats/map.py +933 -0
- parrot/outputs/formats/markdown.py +172 -0
- parrot/outputs/formats/matplotlib.py +237 -0
- parrot/outputs/formats/mixins/__init__.py +0 -0
- parrot/outputs/formats/mixins/emaps.py +855 -0
- parrot/outputs/formats/plotly.py +341 -0
- parrot/outputs/formats/seaborn.py +310 -0
- parrot/outputs/formats/table.py +397 -0
- parrot/outputs/formats/template_report.py +138 -0
- parrot/outputs/formats/yaml.py +125 -0
- parrot/outputs/formatter.py +152 -0
- parrot/outputs/templates/__init__.py +95 -0
- parrot/pipelines/__init__.py +0 -0
- parrot/pipelines/abstract.py +210 -0
- parrot/pipelines/detector.py +124 -0
- parrot/pipelines/models.py +90 -0
- parrot/pipelines/planogram.py +3002 -0
- parrot/pipelines/table.sql +97 -0
- parrot/plugins/__init__.py +106 -0
- parrot/plugins/importer.py +80 -0
- parrot/py.typed +0 -0
- parrot/registry/__init__.py +18 -0
- parrot/registry/registry.py +594 -0
- parrot/scheduler/__init__.py +1189 -0
- parrot/scheduler/models.py +60 -0
- parrot/security/__init__.py +16 -0
- parrot/security/prompt_injection.py +268 -0
- parrot/security/security_events.sql +25 -0
- parrot/services/__init__.py +1 -0
- parrot/services/mcp/__init__.py +8 -0
- parrot/services/mcp/config.py +13 -0
- parrot/services/mcp/server.py +295 -0
- parrot/services/o365_remote_auth.py +235 -0
- parrot/stores/__init__.py +7 -0
- parrot/stores/abstract.py +352 -0
- parrot/stores/arango.py +1090 -0
- parrot/stores/bigquery.py +1377 -0
- parrot/stores/cache.py +106 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss_store.py +1157 -0
- parrot/stores/kb/__init__.py +9 -0
- parrot/stores/kb/abstract.py +68 -0
- parrot/stores/kb/cache.py +165 -0
- parrot/stores/kb/doc.py +325 -0
- parrot/stores/kb/hierarchy.py +346 -0
- parrot/stores/kb/local.py +457 -0
- parrot/stores/kb/prompt.py +28 -0
- parrot/stores/kb/redis.py +659 -0
- parrot/stores/kb/store.py +115 -0
- parrot/stores/kb/user.py +374 -0
- parrot/stores/models.py +59 -0
- parrot/stores/pgvector.py +3 -0
- parrot/stores/postgres.py +2853 -0
- parrot/stores/utils/__init__.py +0 -0
- parrot/stores/utils/chunking.py +197 -0
- parrot/telemetry/__init__.py +3 -0
- parrot/telemetry/mixin.py +111 -0
- parrot/template/__init__.py +3 -0
- parrot/template/engine.py +259 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +644 -0
- parrot/tools/agent.py +363 -0
- parrot/tools/arangodbsearch.py +537 -0
- parrot/tools/arxiv_tool.py +188 -0
- parrot/tools/calculator/__init__.py +3 -0
- parrot/tools/calculator/operations/__init__.py +38 -0
- parrot/tools/calculator/operations/calculus.py +80 -0
- parrot/tools/calculator/operations/statistics.py +76 -0
- parrot/tools/calculator/tool.py +150 -0
- parrot/tools/cloudwatch.py +988 -0
- parrot/tools/codeinterpreter/__init__.py +127 -0
- parrot/tools/codeinterpreter/executor.py +371 -0
- parrot/tools/codeinterpreter/internals.py +473 -0
- parrot/tools/codeinterpreter/models.py +643 -0
- parrot/tools/codeinterpreter/prompts.py +224 -0
- parrot/tools/codeinterpreter/tool.py +664 -0
- parrot/tools/company_info/__init__.py +6 -0
- parrot/tools/company_info/tool.py +1138 -0
- parrot/tools/correlationanalysis.py +437 -0
- parrot/tools/database/abstract.py +286 -0
- parrot/tools/database/bq.py +115 -0
- parrot/tools/database/cache.py +284 -0
- parrot/tools/database/models.py +95 -0
- parrot/tools/database/pg.py +343 -0
- parrot/tools/databasequery.py +1159 -0
- parrot/tools/db.py +1800 -0
- parrot/tools/ddgo.py +370 -0
- parrot/tools/decorators.py +271 -0
- parrot/tools/dftohtml.py +282 -0
- parrot/tools/document.py +549 -0
- parrot/tools/ecs.py +819 -0
- parrot/tools/edareport.py +368 -0
- parrot/tools/elasticsearch.py +1049 -0
- parrot/tools/employees.py +462 -0
- parrot/tools/epson/__init__.py +96 -0
- parrot/tools/excel.py +683 -0
- parrot/tools/file/__init__.py +13 -0
- parrot/tools/file/abstract.py +76 -0
- parrot/tools/file/gcs.py +378 -0
- parrot/tools/file/local.py +284 -0
- parrot/tools/file/s3.py +511 -0
- parrot/tools/file/tmp.py +309 -0
- parrot/tools/file/tool.py +501 -0
- parrot/tools/file_reader.py +129 -0
- parrot/tools/flowtask/__init__.py +19 -0
- parrot/tools/flowtask/tool.py +761 -0
- parrot/tools/gittoolkit.py +508 -0
- parrot/tools/google/__init__.py +18 -0
- parrot/tools/google/base.py +169 -0
- parrot/tools/google/tools.py +1251 -0
- parrot/tools/googlelocation.py +5 -0
- parrot/tools/googleroutes.py +5 -0
- parrot/tools/googlesearch.py +5 -0
- parrot/tools/googlesitesearch.py +5 -0
- parrot/tools/googlevoice.py +2 -0
- parrot/tools/gvoice.py +695 -0
- parrot/tools/ibisworld/README.md +225 -0
- parrot/tools/ibisworld/__init__.py +11 -0
- parrot/tools/ibisworld/tool.py +366 -0
- parrot/tools/jiratoolkit.py +1718 -0
- parrot/tools/manager.py +1098 -0
- parrot/tools/math.py +152 -0
- parrot/tools/metadata.py +476 -0
- parrot/tools/msteams.py +1621 -0
- parrot/tools/msword.py +635 -0
- parrot/tools/multidb.py +580 -0
- parrot/tools/multistoresearch.py +369 -0
- parrot/tools/networkninja.py +167 -0
- parrot/tools/nextstop/__init__.py +4 -0
- parrot/tools/nextstop/base.py +286 -0
- parrot/tools/nextstop/employee.py +733 -0
- parrot/tools/nextstop/store.py +462 -0
- parrot/tools/notification.py +435 -0
- parrot/tools/o365/__init__.py +42 -0
- parrot/tools/o365/base.py +295 -0
- parrot/tools/o365/bundle.py +522 -0
- parrot/tools/o365/events.py +554 -0
- parrot/tools/o365/mail.py +992 -0
- parrot/tools/o365/onedrive.py +497 -0
- parrot/tools/o365/sharepoint.py +641 -0
- parrot/tools/openapi_toolkit.py +904 -0
- parrot/tools/openweather.py +527 -0
- parrot/tools/pdfprint.py +1001 -0
- parrot/tools/powerbi.py +518 -0
- parrot/tools/powerpoint.py +1113 -0
- parrot/tools/pricestool.py +146 -0
- parrot/tools/products/__init__.py +246 -0
- parrot/tools/prophet_tool.py +171 -0
- parrot/tools/pythonpandas.py +630 -0
- parrot/tools/pythonrepl.py +910 -0
- parrot/tools/qsource.py +436 -0
- parrot/tools/querytoolkit.py +395 -0
- parrot/tools/quickeda.py +827 -0
- parrot/tools/resttool.py +553 -0
- parrot/tools/retail/__init__.py +0 -0
- parrot/tools/retail/bby.py +528 -0
- parrot/tools/sandboxtool.py +703 -0
- parrot/tools/sassie/__init__.py +352 -0
- parrot/tools/scraping/__init__.py +7 -0
- parrot/tools/scraping/docs/select.md +466 -0
- parrot/tools/scraping/documentation.md +1278 -0
- parrot/tools/scraping/driver.py +436 -0
- parrot/tools/scraping/models.py +576 -0
- parrot/tools/scraping/options.py +85 -0
- parrot/tools/scraping/orchestrator.py +517 -0
- parrot/tools/scraping/readme.md +740 -0
- parrot/tools/scraping/tool.py +3115 -0
- parrot/tools/seasonaldetection.py +642 -0
- parrot/tools/shell_tool/__init__.py +5 -0
- parrot/tools/shell_tool/actions.py +408 -0
- parrot/tools/shell_tool/engine.py +155 -0
- parrot/tools/shell_tool/models.py +322 -0
- parrot/tools/shell_tool/tool.py +442 -0
- parrot/tools/site_search.py +214 -0
- parrot/tools/textfile.py +418 -0
- parrot/tools/think.py +378 -0
- parrot/tools/toolkit.py +298 -0
- parrot/tools/webapp_tool.py +187 -0
- parrot/tools/whatif.py +1279 -0
- parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
- parrot/tools/workday/__init__.py +6 -0
- parrot/tools/workday/models.py +1389 -0
- parrot/tools/workday/tool.py +1293 -0
- parrot/tools/yfinance_tool.py +306 -0
- parrot/tools/zipcode.py +217 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/helpers.py +73 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.c +12078 -0
- parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/parsers/toml.pyx +21 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpp +20936 -0
- parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
- parrot/utils/types.pyx +213 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- parrot/yaml-rs/Cargo.lock +350 -0
- parrot/yaml-rs/Cargo.toml +19 -0
- parrot/yaml-rs/pyproject.toml +19 -0
- parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
- parrot/yaml-rs/src/lib.rs +222 -0
- requirements/docker-compose.yml +24 -0
- requirements/requirements-dev.txt +21 -0
|
@@ -0,0 +1,3002 @@
|
|
|
1
|
+
"""
|
|
2
|
+
3-Step Planogram Compliance Pipeline
|
|
3
|
+
Step 1: Object Detection (YOLO/ResNet)
|
|
4
|
+
Step 2: LLM Object Identification with Reference Images
|
|
5
|
+
Step 3: Planogram Comparison and Compliance Verification
|
|
6
|
+
"""
|
|
7
|
+
import asyncio
|
|
8
|
+
import os
|
|
9
|
+
from typing import List, Dict, Any, Optional, Union, Tuple
|
|
10
|
+
from collections import defaultdict, Counter
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
import unicodedata
|
|
13
|
+
import re
|
|
14
|
+
import traceback
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
import math
|
|
17
|
+
import pytesseract
|
|
18
|
+
from PIL import (
|
|
19
|
+
Image,
|
|
20
|
+
ImageDraw,
|
|
21
|
+
ImageFont,
|
|
22
|
+
ImageEnhance,
|
|
23
|
+
ImageOps
|
|
24
|
+
)
|
|
25
|
+
import numpy as np
|
|
26
|
+
from pydantic import BaseModel, Field
|
|
27
|
+
import cv2
|
|
28
|
+
import torch
|
|
29
|
+
from google.genai.errors import ServerError
|
|
30
|
+
from .abstract import AbstractPipeline
|
|
31
|
+
from ..models.detections import (
|
|
32
|
+
DetectionBox,
|
|
33
|
+
Detection,
|
|
34
|
+
Detections,
|
|
35
|
+
ShelfRegion,
|
|
36
|
+
IdentifiedProduct,
|
|
37
|
+
PlanogramDescription
|
|
38
|
+
)
|
|
39
|
+
from ..models.compliance import (
|
|
40
|
+
ComplianceResult,
|
|
41
|
+
ComplianceStatus,
|
|
42
|
+
TextComplianceResult,
|
|
43
|
+
TextMatcher,
|
|
44
|
+
BrandComplianceResult
|
|
45
|
+
)
|
|
46
|
+
from .detector import AbstractDetector
|
|
47
|
+
from .models import PlanogramConfig
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
CID = {
|
|
51
|
+
"promotional_candidate": 103,
|
|
52
|
+
"product_candidate": 100,
|
|
53
|
+
"box_candidate": 101,
|
|
54
|
+
"price_tag": 102,
|
|
55
|
+
"shelf_region": 190,
|
|
56
|
+
"brand_logo": 105,
|
|
57
|
+
"poster_text": 106,
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
class RetailDetector(AbstractDetector):
|
|
61
|
+
"""
|
|
62
|
+
Reference-guided Phase-1 detector.
|
|
63
|
+
|
|
64
|
+
1) Enhance image (contrast/brightness) to help OCR/YOLO/CLIP.
|
|
65
|
+
2) Localize the promotional poster using:
|
|
66
|
+
- OCR ('EPSON', 'Hello', 'Savings', etc.)
|
|
67
|
+
- CLIP similarity with your FIRST reference image.
|
|
68
|
+
3) Crop to poster width (+ margin) to form an endcap ROI (remember offsets).
|
|
69
|
+
4) Detect shelf lines within ROI (Hough) => top/middle/bottom bands.
|
|
70
|
+
5) YOLO proposals inside ROI (low conf, class-agnostic).
|
|
71
|
+
6) For each proposal: OCR + CLIP vs remaining reference images
|
|
72
|
+
=> label as promotional/product/box candidate.
|
|
73
|
+
7) Shrink, merge, suppress items that are inside the poster.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
def __init__(
|
|
77
|
+
self,
|
|
78
|
+
yolo_model: str = "yolo12l.pt",
|
|
79
|
+
conf: float = 0.15,
|
|
80
|
+
iou: float = 0.5,
|
|
81
|
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
|
82
|
+
reference_images: Optional[List[str]] = None, # first is the poster
|
|
83
|
+
**kwargs
|
|
84
|
+
):
|
|
85
|
+
super().__init__(
|
|
86
|
+
yolo_model=yolo_model,
|
|
87
|
+
conf=conf,
|
|
88
|
+
iou=iou,
|
|
89
|
+
device=device,
|
|
90
|
+
**kwargs
|
|
91
|
+
)
|
|
92
|
+
# Shelf split defaults: header/middle/bottom
|
|
93
|
+
self.shelf_split = (0.40, 0.25, 0.35) # sums to ~1.0
|
|
94
|
+
# Useful elsewhere (price tag guardrails, etc.)
|
|
95
|
+
self.label_strip_ratio = 0.06
|
|
96
|
+
self.ref_paths = reference_images or []
|
|
97
|
+
self.ref_ad = self.ref_paths[0] if self.ref_paths else None
|
|
98
|
+
self.ref_products = self.ref_paths[1:] if len(self.ref_paths) > 1 else []
|
|
99
|
+
self.ref_ad_feat = self._embed_image(self.ref_ad) if self.ref_ad else None
|
|
100
|
+
self.ref_prod_feats = [
|
|
101
|
+
self._embed_image(p) for p in self.ref_products
|
|
102
|
+
] if self.ref_products else []
|
|
103
|
+
|
|
104
|
+
# -------------------------- Main Detection Entry ---------------------------------
|
|
105
|
+
async def detect(
|
|
106
|
+
self,
|
|
107
|
+
image: Image.Image,
|
|
108
|
+
image_array: np.array,
|
|
109
|
+
endcap: Detection,
|
|
110
|
+
ad: Detection,
|
|
111
|
+
planogram: Optional[PlanogramDescription] = None,
|
|
112
|
+
debug_yolo: Optional[str] = None,
|
|
113
|
+
debug_phase1: Optional[str] = None,
|
|
114
|
+
debug_phases: Optional[str] = None,
|
|
115
|
+
):
|
|
116
|
+
h, w = image_array.shape[:2]
|
|
117
|
+
# text prompts (backup if no product refs)
|
|
118
|
+
text = [f"a photo of a {t}" for t in planogram.text_tokens if t]
|
|
119
|
+
if not text:
|
|
120
|
+
text = [
|
|
121
|
+
"a photo of a retail promotional poster lightbox",
|
|
122
|
+
"a photo of a product box",
|
|
123
|
+
"a photo of a product cartridge bottle",
|
|
124
|
+
"a photo of a price tag"
|
|
125
|
+
]
|
|
126
|
+
self.text_tokens = self.proc(
|
|
127
|
+
text=text,
|
|
128
|
+
return_tensors="pt",
|
|
129
|
+
padding=True
|
|
130
|
+
).to(self.device)
|
|
131
|
+
with torch.no_grad():
|
|
132
|
+
self.text_feats = self.clip.get_text_features(**self.text_tokens)
|
|
133
|
+
self.text_feats = self.text_feats / self.text_feats.norm(dim=-1, keepdim=True)
|
|
134
|
+
|
|
135
|
+
# Check if detections are valid before proceeding
|
|
136
|
+
if not endcap or not ad:
|
|
137
|
+
print("ERROR: Failed to get required detections.")
|
|
138
|
+
return # or raise an exception
|
|
139
|
+
|
|
140
|
+
# 2) endcap ROI
|
|
141
|
+
roi_box = endcap.bbox.get_pixel_coordinates(width=w, height=h)
|
|
142
|
+
ad_box = ad.bbox.get_pixel_coordinates(width=w, height=h)
|
|
143
|
+
|
|
144
|
+
# Unpack the Pixel coordinates
|
|
145
|
+
rx1, ry1, rx2, ry2 = roi_box
|
|
146
|
+
|
|
147
|
+
roi = image_array[ry1:ry2, rx1:rx2]
|
|
148
|
+
|
|
149
|
+
# 4) YOLO inside ROI
|
|
150
|
+
yolo_props = self._yolo_props(roi, rx1, ry1)
|
|
151
|
+
|
|
152
|
+
# Extract planogram config for shelf layout
|
|
153
|
+
planogram_config = None
|
|
154
|
+
if planogram:
|
|
155
|
+
planogram_config = {
|
|
156
|
+
'shelves': [
|
|
157
|
+
{
|
|
158
|
+
'level': shelf.level,
|
|
159
|
+
'height_ratio': getattr(shelf, 'height_ratio', None),
|
|
160
|
+
'products': [
|
|
161
|
+
{
|
|
162
|
+
'name': product.name,
|
|
163
|
+
'product_type': product.product_type
|
|
164
|
+
} for product in shelf.products
|
|
165
|
+
]
|
|
166
|
+
} for shelf in planogram.shelves
|
|
167
|
+
]
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
# 3) shelves
|
|
171
|
+
shelf_lines, bands = self._find_shelves(
|
|
172
|
+
roi_box=roi_box,
|
|
173
|
+
ad_box=ad_box,
|
|
174
|
+
w=w,
|
|
175
|
+
h=h,
|
|
176
|
+
planogram_config=planogram_config
|
|
177
|
+
)
|
|
178
|
+
# header_limit_y = min(v[0] for v in bands.values()) if bands else int(0.4 * h)
|
|
179
|
+
# classification fallback limit = header bottom (or 40% of ROI height)
|
|
180
|
+
if bands and "header" in bands:
|
|
181
|
+
header_limit_y = bands["header"][1]
|
|
182
|
+
else:
|
|
183
|
+
roi_h = max(1, ry2 - ry1)
|
|
184
|
+
header_limit_y = ry1 + int(0.4 * roi_h)
|
|
185
|
+
|
|
186
|
+
if debug_yolo:
|
|
187
|
+
dbg = self._draw_phase_areas(image_array.copy(), yolo_props, roi_box)
|
|
188
|
+
if debug_phases:
|
|
189
|
+
cv2.imwrite(
|
|
190
|
+
debug_phases,
|
|
191
|
+
cv2.cvtColor(dbg, cv2.COLOR_RGB2BGR)
|
|
192
|
+
)
|
|
193
|
+
dbg = self._draw_yolo(image_array.copy(), yolo_props, roi_box, shelf_lines)
|
|
194
|
+
cv2.imwrite(
|
|
195
|
+
debug_yolo,
|
|
196
|
+
cv2.cvtColor(dbg, cv2.COLOR_RGB2BGR)
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
# 5) classify YOLO → proposals (works w/ bands={}, header_limit_y above)
|
|
200
|
+
proposals = await self._classify_proposals(
|
|
201
|
+
image_array,
|
|
202
|
+
yolo_props,
|
|
203
|
+
bands,
|
|
204
|
+
header_limit_y,
|
|
205
|
+
ad_box
|
|
206
|
+
)
|
|
207
|
+
# 6) shrink -> merge -> remove those fully inside the poster
|
|
208
|
+
proposals = self._merge(proposals, iou_same=0.45)
|
|
209
|
+
|
|
210
|
+
# shelves dict to satisfy callers; in flat mode keep it empty
|
|
211
|
+
shelves = {
|
|
212
|
+
name: DetectionBox(
|
|
213
|
+
x1=rx1, y1=y1, x2=rx2, y2=y2,
|
|
214
|
+
confidence=1.0,
|
|
215
|
+
class_id=190, class_name="shelf_region",
|
|
216
|
+
area=(rx2-rx1)*(y2-y1),
|
|
217
|
+
)
|
|
218
|
+
for name, (y1, y2) in bands.items()
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
# (OPTIONAL) draw Phase-1 debug
|
|
222
|
+
if debug_phase1:
|
|
223
|
+
dbg = self._draw_phase1(
|
|
224
|
+
image_array.copy(),
|
|
225
|
+
roi_box,
|
|
226
|
+
shelf_lines,
|
|
227
|
+
proposals,
|
|
228
|
+
ad_box
|
|
229
|
+
)
|
|
230
|
+
cv2.imwrite(
|
|
231
|
+
debug_phase1,
|
|
232
|
+
cv2.cvtColor(dbg, cv2.COLOR_RGB2BGR)
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
# 8) ensure the promo exists exactly once
|
|
236
|
+
if ad_box is not None and not any(d.class_name == "promotional_candidate" and self._iou_box_tuple(d, ad_box) > 0.7 for d in proposals):
|
|
237
|
+
x1, y1, x2, y2 = ad_box
|
|
238
|
+
proposals.append(
|
|
239
|
+
DetectionBox(
|
|
240
|
+
x1=x1, y1=y1, x2=x2, y2=y2,
|
|
241
|
+
confidence=0.95,
|
|
242
|
+
class_id=103,
|
|
243
|
+
class_name="promotional_candidate",
|
|
244
|
+
area=(x2-x1)*(y2-y1)
|
|
245
|
+
)
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
return {"shelves": shelves, "proposals": proposals}
|
|
249
|
+
|
|
250
|
+
# --------------------------- shelves -------------------------------------
|
|
251
|
+
def _find_shelves(
|
|
252
|
+
self,
|
|
253
|
+
roi_box: tuple[int, int, int, int],
|
|
254
|
+
ad_box: tuple[int, int, int, int],
|
|
255
|
+
h: int,
|
|
256
|
+
w: int,
|
|
257
|
+
planogram_config: dict = None
|
|
258
|
+
) -> tuple[List[int], dict]:
|
|
259
|
+
"""
|
|
260
|
+
Detects shelf bands based on planogram configuration, prioritizing the
|
|
261
|
+
dynamically detected ad_box for the header.
|
|
262
|
+
"""
|
|
263
|
+
rx1, ry1, rx2, ry2 = map(int, roi_box)
|
|
264
|
+
_, ad_y1, _, ad_y2 = map(int, ad_box)
|
|
265
|
+
roi_h = max(1, ry2 - ry1)
|
|
266
|
+
|
|
267
|
+
# Fallback to the old proportional method if no planogram is provided
|
|
268
|
+
if not planogram_config or 'shelves' not in planogram_config:
|
|
269
|
+
return self._find_shelves_proportional(roi_box, rx1, ry1, rx2, ry2, h)
|
|
270
|
+
|
|
271
|
+
shelf_configs = planogram_config['shelves']
|
|
272
|
+
if not shelf_configs:
|
|
273
|
+
return [], {}
|
|
274
|
+
|
|
275
|
+
bands = {}
|
|
276
|
+
levels = []
|
|
277
|
+
|
|
278
|
+
# --- 1. Prioritize the Header based on ad_box ---
|
|
279
|
+
# The header starts at the top of the ROI and ends at the bottom of the ad_box
|
|
280
|
+
header_config = next((s for s in shelf_configs if s.get('level') == 'header'), None)
|
|
281
|
+
if header_config:
|
|
282
|
+
# Use the detected ad_box y-coordinates for the header band
|
|
283
|
+
header_top = ad_y1
|
|
284
|
+
header_bottom = ad_y2
|
|
285
|
+
bands[header_config['level']] = (header_top, header_bottom)
|
|
286
|
+
current_y = header_bottom
|
|
287
|
+
remaining_configs = [s for s in shelf_configs if s.get('level') != 'header']
|
|
288
|
+
else:
|
|
289
|
+
# If no header is defined, start from the top of the ROI
|
|
290
|
+
current_y = ry1
|
|
291
|
+
remaining_configs = shelf_configs
|
|
292
|
+
|
|
293
|
+
# --- 2. Calculate space for remaining shelves ---
|
|
294
|
+
remaining_roi_h = max(1, ry2 - current_y)
|
|
295
|
+
|
|
296
|
+
# Calculate space consumed by shelves with a fixed height_ratio
|
|
297
|
+
height_from_ratios = 0
|
|
298
|
+
shelves_without_ratio = []
|
|
299
|
+
for shelf_config in remaining_configs:
|
|
300
|
+
if 'height_ratio' in shelf_config and shelf_config['height_ratio'] is not None:
|
|
301
|
+
# height_ratio is a percentage of the TOTAL ROI height
|
|
302
|
+
height_from_ratios += int(shelf_config['height_ratio'] * roi_h)
|
|
303
|
+
else:
|
|
304
|
+
shelves_without_ratio.append(shelf_config)
|
|
305
|
+
|
|
306
|
+
# Calculate height for each shelf without a specified ratio
|
|
307
|
+
auto_size_h = max(0, remaining_roi_h - height_from_ratios)
|
|
308
|
+
auto_shelf_height = int(auto_size_h / len(shelves_without_ratio)) if shelves_without_ratio else 0
|
|
309
|
+
|
|
310
|
+
# --- 3. Build the bands for the remaining shelves ---
|
|
311
|
+
for i, shelf_config in enumerate(remaining_configs):
|
|
312
|
+
shelf_level = shelf_config['level']
|
|
313
|
+
|
|
314
|
+
if 'height_ratio' in shelf_config and shelf_config['height_ratio'] is not None:
|
|
315
|
+
shelf_pixel_height = int(shelf_config['height_ratio'] * roi_h)
|
|
316
|
+
else:
|
|
317
|
+
shelf_pixel_height = auto_shelf_height
|
|
318
|
+
|
|
319
|
+
shelf_bottom = current_y + shelf_pixel_height
|
|
320
|
+
|
|
321
|
+
# For the very last shelf, ensure it extends to the bottom of the ROI
|
|
322
|
+
if i == len(remaining_configs) - 1:
|
|
323
|
+
shelf_bottom = ry2
|
|
324
|
+
|
|
325
|
+
# VALIDATION: Ensure valid bounding box
|
|
326
|
+
if shelf_bottom <= current_y:
|
|
327
|
+
print(
|
|
328
|
+
f"WARNING: Invalid shelf {shelf_level}: y1={current_y}, y2={shelf_bottom}"
|
|
329
|
+
)
|
|
330
|
+
shelf_bottom = current_y + 50 # Minimum height
|
|
331
|
+
|
|
332
|
+
bands[shelf_level] = (current_y, shelf_bottom)
|
|
333
|
+
current_y = shelf_bottom
|
|
334
|
+
|
|
335
|
+
# --- 4. Create the levels list (separator lines) ---
|
|
336
|
+
# The levels are the bottom coordinate of each shelf band, except for the last one
|
|
337
|
+
if bands:
|
|
338
|
+
# Ensure order from top to bottom based on the planogram config
|
|
339
|
+
ordered_levels = [bands[s['level']][1] for s in shelf_configs if s['level'] in bands]
|
|
340
|
+
levels = ordered_levels[:-1]
|
|
341
|
+
|
|
342
|
+
self.logger.debug(
|
|
343
|
+
f"📊 Planogram Shelves: {len(shelf_configs)} shelves configured, "
|
|
344
|
+
f"ROI height={roi_h}, bands={bands}"
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
return levels, bands
|
|
348
|
+
|
|
349
|
+
def _find_shelves_proportional(self, roi: tuple, rx1, ry1, rx2, ry2, H, planogram_config: dict = None):
|
|
350
|
+
"""
|
|
351
|
+
Fallback proportional layout using planogram config or default 3-shelf layout.
|
|
352
|
+
"""
|
|
353
|
+
roi_h = max(1, ry2 - ry1)
|
|
354
|
+
|
|
355
|
+
# Use planogram config if available
|
|
356
|
+
if planogram_config and 'shelves' in planogram_config:
|
|
357
|
+
shelf_configs = planogram_config['shelves']
|
|
358
|
+
num_shelves = len(shelf_configs)
|
|
359
|
+
|
|
360
|
+
if num_shelves > 0:
|
|
361
|
+
# Equal division among configured shelves
|
|
362
|
+
shelf_height = roi_h // num_shelves
|
|
363
|
+
|
|
364
|
+
levels = []
|
|
365
|
+
bands = {}
|
|
366
|
+
current_y = ry1
|
|
367
|
+
|
|
368
|
+
for i, shelf_config in enumerate(shelf_configs):
|
|
369
|
+
shelf_level = shelf_config['level']
|
|
370
|
+
shelf_bottom = current_y + shelf_height
|
|
371
|
+
|
|
372
|
+
# For the last shelf, extend to ROI bottom
|
|
373
|
+
if i == len(shelf_configs) - 1:
|
|
374
|
+
shelf_bottom = ry2
|
|
375
|
+
|
|
376
|
+
bands[shelf_level] = (current_y, shelf_bottom)
|
|
377
|
+
if i < len(shelf_configs) - 1: # Don't add last boundary to levels
|
|
378
|
+
levels.append(shelf_bottom)
|
|
379
|
+
|
|
380
|
+
current_y = shelf_bottom
|
|
381
|
+
|
|
382
|
+
return levels, bands
|
|
383
|
+
|
|
384
|
+
# Default fallback: 3-shelf layout if no config
|
|
385
|
+
hdr_r, mid_r, bot_r = 0.40, 0.30, 0.30
|
|
386
|
+
|
|
387
|
+
header_bottom = ry1 + int(hdr_r * roi_h)
|
|
388
|
+
middle_bottom = header_bottom + int(mid_r * roi_h)
|
|
389
|
+
|
|
390
|
+
# Ensure boundaries don't exceed ROI
|
|
391
|
+
header_bottom = max(ry1 + 20, min(header_bottom, ry2 - 40))
|
|
392
|
+
middle_bottom = max(header_bottom + 20, min(middle_bottom, ry2 - 20))
|
|
393
|
+
|
|
394
|
+
levels = [header_bottom, middle_bottom]
|
|
395
|
+
bands = {
|
|
396
|
+
"header": (ry1, header_bottom),
|
|
397
|
+
"middle": (header_bottom, middle_bottom),
|
|
398
|
+
"bottom": (middle_bottom, ry2),
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
return levels, bands
|
|
402
|
+
|
|
403
|
+
# ---------------------------- YOLO ---------------------------------------
|
|
404
|
+
def _preprocess_roi_for_detection(self, roi: np.ndarray) -> np.ndarray:
|
|
405
|
+
"""
|
|
406
|
+
Ultra-minimal preprocessing - only applies when absolutely necessary.
|
|
407
|
+
Use this version if you want maximum preservation of original image quality.
|
|
408
|
+
"""
|
|
409
|
+
try:
|
|
410
|
+
# Convert BGR to RGB if needed
|
|
411
|
+
if len(roi.shape) == 3 and roi.shape[2] == 3:
|
|
412
|
+
rgb_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
|
|
413
|
+
else:
|
|
414
|
+
rgb_roi = roi.copy()
|
|
415
|
+
|
|
416
|
+
# Quick contrast check
|
|
417
|
+
gray = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2GRAY)
|
|
418
|
+
contrast = gray.std()
|
|
419
|
+
|
|
420
|
+
# Only process if contrast is very low
|
|
421
|
+
if contrast > 35:
|
|
422
|
+
# Good contrast - return original with minimal sharpening
|
|
423
|
+
result = rgb_roi.astype(np.float32)
|
|
424
|
+
|
|
425
|
+
# Ultra-subtle sharpening
|
|
426
|
+
kernel = np.array([[0, -0.05, 0],
|
|
427
|
+
[-0.05, 1.2, -0.05],
|
|
428
|
+
[0, -0.05, 0]])
|
|
429
|
+
|
|
430
|
+
for i in range(3):
|
|
431
|
+
result[:,:,i] = cv2.filter2D(result[:,:,i], -1, kernel)
|
|
432
|
+
|
|
433
|
+
result = np.clip(result, 0, 255).astype(np.uint8)
|
|
434
|
+
else:
|
|
435
|
+
# Low contrast - apply gentle CLAHE only
|
|
436
|
+
lab = cv2.cvtColor(rgb_roi, cv2.COLOR_RGB2LAB)
|
|
437
|
+
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(10,10))
|
|
438
|
+
lab[:,:,0] = clahe.apply(lab[:,:,0])
|
|
439
|
+
result = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
|
|
440
|
+
|
|
441
|
+
# Convert back to BGR if needed
|
|
442
|
+
if len(roi.shape) == 3 and roi.shape[2] == 3:
|
|
443
|
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
|
444
|
+
|
|
445
|
+
return result
|
|
446
|
+
|
|
447
|
+
except Exception as e:
|
|
448
|
+
self.logger.warning(f"Minimal ROI preprocessing failed: {e}")
|
|
449
|
+
return roi
|
|
450
|
+
|
|
451
|
+
def _yolo_props(self, roi: np.ndarray, rx1, ry1, detection_phases: Optional[List[Dict[str, Any]]] = None):
    """
    Multi-phase YOLO detection with configurable confidence levels and weighted scoring.
    Returns proposals in the same format expected by the existing _classify_proposals method.

    Args:
        roi: ROI image array.
        rx1, ry1: ROI offset coordinates (added to local box coords to
            produce global-image coordinates).
        detection_phases: List of phase configurations. If None, a default
            coarse/standard/aggressive 3-phase approach is used.

    Returns:
        Deduplicated list of proposal dicts; [] on any failure.

    Fixes vs. previous revision:
        * stats["passed_class_limits"] was printed but never incremented.
        * `orientation` was computed twice per detection.
        * The summary block was printed twice; consolidated into one.
        * Unused enumerate() index removed from the detection loop.
    """
    # printer ≈ 5–9%, product_box ≈ 7–12%, promotional_graphic ≥ 20%
    CLASS_LIMITS = {
        # Base retail categories
        "poster": {"min_area": 0.06, "max_area": 0.95, "min_ar": 0.5, "max_ar": 3.5},
        "person": {"min_area": 0.02, "max_area": 0.60, "min_ar": 0.3, "max_ar": 3.5},
        "printer": {"min_area": 0.010, "max_area": 0.28, "min_ar": 0.6, "max_ar": 2.8},
        "product_box": {"min_area": 0.003, "max_area": 0.20, "min_ar": 0.4, "max_ar": 3.2},
        "price_tag": {"min_area": 0.0006, "max_area": 0.010, "min_ar": 1.6, "max_ar": 8.0},

        # YOLO classes mapped to retail categories with their own limits
        "tv": {"min_area": 0.06, "max_area": 0.95, "min_ar": 0.5, "max_ar": 3.5},  # → poster
        "monitor": {"min_area": 0.06, "max_area": 0.95, "min_ar": 0.5, "max_ar": 3.5},  # → poster
        "laptop": {"min_area": 0.06, "max_area": 0.95, "min_ar": 0.5, "max_ar": 3.5},  # → poster
        "microwave": {"min_area": 0.010, "max_area": 0.28, "min_ar": 0.6, "max_ar": 2.8},  # → printer
        "book": {"min_area": 0.003, "max_area": 0.20, "min_ar": 0.4, "max_ar": 3.2},  # → product_box
        "box": {"min_area": 0.003, "max_area": 0.20, "min_ar": 0.4, "max_ar": 3.2},  # → product_box
        "suitcase": {"min_area": 0.003, "max_area": 0.20, "min_ar": 0.4, "max_ar": 3.2},  # → product_box
        "bottle": {"min_area": 0.0006, "max_area": 0.010, "min_ar": 1.6, "max_ar": 8.0},  # → price_tag
        "clock": {"min_area": 0.0006, "max_area": 0.010, "min_ar": 1.6, "max_ar": 8.0},  # → price_tag
        "mouse": {"min_area": 0.0006, "max_area": 0.010, "min_ar": 1.6, "max_ar": 8.0},  # → price_tag
        "remote": {"min_area": 0.0006, "max_area": 0.010, "min_ar": 1.6, "max_ar": 8.0},  # → price_tag
        "cell phone": {"min_area": 0.0006, "max_area": 0.010, "min_ar": 1.6, "max_ar": 8.0},  # → price_tag
    }

    # Mapping from YOLO classes to retail categories
    YOLO_TO_RETAIL = {
        "tv": "poster",
        "monitor": "poster",
        "laptop": "poster",
        "microwave": "printer",
        "keyboard": "product_box",
        "book": "product_box",
        "box": "product_box",
        "suitcase": "product_box",
        "bottle": "price_tag",
        "clock": "price_tag",
        "mouse": "price_tag",
        "remote": "price_tag",
        "cell phone": "price_tag",
    }

    def _get_class_limits(yolo_class: str) -> Optional[Dict[str, float]]:
        """Get class limits for a YOLO class (None when unlisted)."""
        return CLASS_LIMITS.get(yolo_class, None)

    def _get_retail_category(yolo_class: str) -> str:
        """Map YOLO class to retail category (identity when unmapped)."""
        return YOLO_TO_RETAIL.get(yolo_class, yolo_class)

    def _passes_class_limits(yolo_class: str, area_ratio: float, aspect_ratio: float) -> tuple[bool, str]:
        """Check if detection passes class-specific area/aspect-ratio limits."""
        limits = _get_class_limits(yolo_class)
        if not limits:
            # Use generic fallback limits if no class-specific ones
            generic_ok = (0.0008 <= area_ratio <= 0.9 and 0.1 <= aspect_ratio <= 10.0)
            return generic_ok, "generic_limits"

        area_ok = limits["min_area"] <= area_ratio <= limits["max_area"]
        ar_ok = limits["min_ar"] <= aspect_ratio <= limits["max_ar"]

        if area_ok and ar_ok:
            retail_category = _get_retail_category(yolo_class)
            return True, f"class_limits_{yolo_class}→{retail_category}"

        # Provide specific failure reason for debugging
        reasons = []
        if not area_ok:
            reasons.append(
                f"area={area_ratio:.4f} not in [{limits['min_area']:.4f}, {limits['max_area']:.4f}]"
            )
        if not ar_ok:
            reasons.append(
                f"ar={aspect_ratio:.2f} not in [{limits['min_ar']:.2f}, {limits['max_ar']:.2f}]"
            )
        return False, f"failed_{yolo_class}: {'; '.join(reasons)}"

    # Preprocess ROI to enhance detection of similar-colored objects
    enhanced_roi = self._preprocess_roi_for_detection(roi)

    if detection_phases is None:
        detection_phases = [
            {  # Coarse: quickly find large boxes (e.g., header, promo)
                "name": "coarse",
                "conf": 0.35,
                "iou": 0.35,
                "weight": 0.20,
                "min_area": 0.05,  # >= 5% of ROI
                "description": "High confidence pass for large objects",
            },
            {  # Standard: main workhorse for printers & boxes
                "name": "standard",
                "conf": 0.05,
                "iou": 0.20,
                "weight": 0.70,
                "min_area": 0.001,
                "description": "High confidence pass for clear objects"
            },
            {  # Aggressive: recover misses but still bounded by class limits
                "name": "aggressive",
                "conf": 0.008,
                "iou": 0.15,
                "weight": 0.10,
                "min_area": 0.0006,
                "description": "Selective aggressive pass for missed objects only"
            },
        ]

    try:
        H, W = roi.shape[:2]
        roi_area = H * W
        all_proposals = []

        print(f"\n🔄 Detection with Your Preferred Settings on ROI {W}x{H}")
        print(" " + "=" * 70)

        # Statistics tracking ("total_detections" counts KEPT proposals)
        stats = {
            "total_detections": 0,
            "passed_confidence": 0,
            "passed_size": 0,
            "passed_class_limits": 0,
            "rejected_class_limits": 0
        }

        for phase_idx, phase in enumerate(detection_phases):
            phase_name = phase["name"]
            conf_thresh = phase["conf"]
            iou_thresh = phase["iou"]
            weight = phase["weight"]

            print(f"\n📡 Phase {phase_idx + 1}: {phase_name}")
            print(f" Config: conf={conf_thresh}, iou={iou_thresh}, weight={weight}")

            r = self.yolo(enhanced_roi, conf=conf_thresh, iou=iou_thresh, verbose=False)[0]

            if not hasattr(r, 'boxes') or r.boxes is None:
                print(f" 📊 No boxes detected in {phase_name}")
                continue

            xyxy = r.boxes.xyxy.cpu().numpy()
            confs = r.boxes.conf.cpu().numpy()
            classes = r.boxes.cls.cpu().numpy().astype(int)
            names = r.names

            print(f" 📊 Raw YOLO output: {len(xyxy)} detections")

            phase_count = 0
            phase_rejected = 0

            # FIX: plain zip — the enumerate() index was never used
            for (x1, y1, x2, y2), conf, cls_id in zip(xyxy, confs, classes):
                # Translate ROI-local coordinates to global image coordinates
                gx1, gy1, gx2, gy2 = int(x1) + rx1, int(y1) + ry1, int(x2) + rx1, int(y2) + ry1

                width, height = x2 - x1, y2 - y1
                # Discard degenerate and tiny (< 8 px per side) boxes
                if width <= 0 or height <= 0 or width < 8 or height < 8:
                    continue

                if conf < conf_thresh:
                    continue

                stats["passed_confidence"] += 1

                area = width * height
                area_ratio = area / roi_area
                aspect_ratio = width / max(height, 1)
                yolo_class = names[cls_id]

                # Per-phase minimum-area gate
                min_area = phase.get("min_area")
                if min_area and area_ratio < float(min_area):
                    continue

                stats["passed_size"] += 1

                # Apply class-specific limits
                limits_passed, limit_reason = _passes_class_limits(yolo_class, area_ratio, aspect_ratio)

                if not limits_passed:
                    phase_rejected += 1
                    stats["rejected_class_limits"] += 1
                    if phase_rejected <= 3:  # Log first few rejections for debugging
                        print(f" ❌ Rejected {yolo_class}: {limit_reason}")
                    continue

                # FIX: this counter was reported but never incremented before
                stats["passed_class_limits"] += 1

                ocr_text = None
                orientation = self._detect_orientation(gx1, gy1, gx2, gy2)
                if 0.0008 <= area_ratio <= 0.9:
                    # Only run OCR on boxes with an area > 5% of the ROI
                    if area_ratio > 0.05:
                        try:
                            # Crop the specific proposal from the ROI image
                            # using local coordinates (x1, y1, x2, y2)
                            proposal_img_crop = roi[int(y1):int(y2), int(x1):int(x2)]

                            # --- ROTATION LOGIC for VERTICAL BOXES ---
                            if orientation == 'vertical':
                                # Rotate the crop 90° clockwise so vertical text
                                # reads horizontally, OCR it, then rotate back
                                # and OCR the opposite reading direction too.
                                proposal_img_crop = cv2.rotate(
                                    proposal_img_crop,
                                    cv2.ROTATE_90_CLOCKWISE
                                )
                                text = pytesseract.image_to_string(
                                    proposal_img_crop,
                                    config="--psm 6 -l eng"
                                )
                                proposal_img_crop = cv2.rotate(
                                    proposal_img_crop,
                                    cv2.ROTATE_90_COUNTERCLOCKWISE
                                )
                                vtext = pytesseract.image_to_string(
                                    proposal_img_crop,
                                    config="--psm 6 -l eng"
                                )
                                raw_text = text + ' | ' + vtext
                            else:
                                # Run Tesseract on the crop as-is
                                raw_text = pytesseract.image_to_string(
                                    proposal_img_crop,
                                    config="--psm 6 -l eng"
                                )

                            # Collapse whitespace runs into single spaces
                            ocr_text = " ".join(raw_text.strip().split())
                        except Exception as ocr_error:
                            self.logger.warning(
                                f"OCR failed for a proposal: {ocr_error}"
                            )

                # FIX: orientation was redundantly recomputed here before
                weighted_conf = float(conf) * weight
                proposal = {
                    "yolo_label": yolo_class,
                    "yolo_conf": float(conf),
                    "weighted_conf": weighted_conf,
                    "box": (gx1, gy1, gx2, gy2),
                    "area_ratio": area_ratio,
                    "aspect_ratio": aspect_ratio,
                    "orientation": orientation,
                    "retail_candidates": self._get_retail_candidates(yolo_class),
                    "raw_index": len(all_proposals) + 1,
                    "ocr_text": ocr_text,
                    "phase": phase_name
                }
                all_proposals.append(proposal)
                stats["total_detections"] += 1
                phase_count += 1

            print(f" ✅ Kept {phase_count} detections from {phase_name}")

        # Light deduplication (let classification handle quality control)
        deduplicated = self._object_deduplication(all_proposals)

        # FIX: single consolidated summary (it was printed twice before,
        # and "Total YOLO detections" mislabeled the kept-proposal count)
        print(f"\n📊 Detection Summary: {len(deduplicated)} total proposals")
        print(" Focus: Let classification phase handle object type distinction")
        print(f" Total kept detections: {stats['total_detections']}")
        print(f" Passed confidence: {stats['passed_confidence']}")
        print(f" Passed basic size: {stats['passed_size']}")
        print(f" Passed class limits: {stats['passed_class_limits']}")
        print(f" Rejected by class limits: {stats['rejected_class_limits']}")
        print(f" Final after deduplication: {len(deduplicated)}")
        return deduplicated

    except Exception as e:
        print(f"Detection failed: {e}")
        traceback.print_exc()
        return []
|
|
741
|
+
|
|
742
|
+
def _determine_shelf_level(self, center_y: float, bands: Dict[str, tuple]) -> str:
|
|
743
|
+
"""Enhanced shelf level determination"""
|
|
744
|
+
if not bands:
|
|
745
|
+
return "unknown"
|
|
746
|
+
|
|
747
|
+
for level, (y1, y2) in bands.items():
|
|
748
|
+
if y1 <= center_y <= y2:
|
|
749
|
+
return level
|
|
750
|
+
|
|
751
|
+
# If not in any band, find closest
|
|
752
|
+
min_distance = float('inf')
|
|
753
|
+
closest_level = "unknown"
|
|
754
|
+
for level, (y1, y2) in bands.items():
|
|
755
|
+
band_center = (y1 + y2) / 2
|
|
756
|
+
distance = abs(center_y - band_center)
|
|
757
|
+
if distance < min_distance:
|
|
758
|
+
min_distance = distance
|
|
759
|
+
closest_level = level
|
|
760
|
+
|
|
761
|
+
return closest_level
|
|
762
|
+
|
|
763
|
+
def _detect_orientation(self, x1: int, y1: int, x2: int, y2: int) -> str:
|
|
764
|
+
"""Detect orientation from bounding box dimensions"""
|
|
765
|
+
width = x2 - x1
|
|
766
|
+
height = y2 - y1
|
|
767
|
+
aspect_ratio = width / max(height, 1)
|
|
768
|
+
|
|
769
|
+
if aspect_ratio < 0.8:
|
|
770
|
+
return "vertical"
|
|
771
|
+
elif aspect_ratio > 1.5:
|
|
772
|
+
return "horizontal"
|
|
773
|
+
else:
|
|
774
|
+
return "square"
|
|
775
|
+
|
|
776
|
+
def _get_retail_candidates(self, yolo_class: str) -> List[str]:
|
|
777
|
+
"""Light retail candidate mapping - let classification do the heavy work"""
|
|
778
|
+
mapping = {
|
|
779
|
+
"microwave": ["printer", "product_box"],
|
|
780
|
+
"tv": ["promotional_graphic", "tv"],
|
|
781
|
+
"television": ["tv"],
|
|
782
|
+
"monitor": ["promotional_graphic"],
|
|
783
|
+
"laptop": ["promotional_graphic"],
|
|
784
|
+
"book": ["product_box"],
|
|
785
|
+
"box": ["product_box"],
|
|
786
|
+
"suitcase": ["product_box", "printer"],
|
|
787
|
+
"bottle": ["ink_bottle", "price_tag"],
|
|
788
|
+
"person": ["promotional_graphic"],
|
|
789
|
+
"clock": ["small_object", "price_tag"],
|
|
790
|
+
"cell phone": ["small_object", "price_tag"],
|
|
791
|
+
}
|
|
792
|
+
return mapping.get(yolo_class, ["product_candidate"])
|
|
793
|
+
|
|
794
|
+
def _object_deduplication(self, all_detections: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Enhanced deduplication with container/contained logic and better IoU thresholds.

    Detections are processed in descending weighted-confidence order. A new
    detection is dropped when it overlaps a kept one with IoU > 0.5, or when
    it is >= 80% contained inside a kept box at least 3x its area. Conversely,
    a kept box that is >= 80% contained in a much larger new detection is
    evicted in favor of the larger one.

    Args:
        all_detections: proposal dicts; each must carry "box"
            (x1, y1, x2, y2) and "weighted_conf".

    Returns:
        Surviving proposal dicts, in descending weighted-confidence order.
    """
    if not all_detections:
        return []

    # Sort by weighted confidence (highest first) so higher-scoring
    # detections get first claim on a region.
    sorted_detections = sorted(all_detections, key=lambda x: x["weighted_conf"], reverse=True)

    deduplicated = []
    for detection in sorted_detections:
        detection_box = detection["box"]
        x1, y1, x2, y2 = detection_box
        detection_area = (x2 - x1) * (y2 - y1)

        is_duplicate = False
        is_contained = False

        # Iterate over a shallow copy: the eviction branch below removes
        # entries from `deduplicated` while this loop is running.
        for kept in deduplicated[:]:
            kept_box = kept["box"]
            kx1, ky1, kx2, ky2 = kept_box
            kept_area = (kx2 - kx1) * (ky2 - ky1)

            iou = self._calculate_iou_tuples(detection_box, kept_box)

            # Standard IoU-based deduplication (lowered threshold)
            if iou > 0.5:  # Reduced from 0.7 to 0.5
                is_duplicate = True
                break

            # Container/contained handling
            # (e.g., individual box vs. entire shelf detection)
            if kept_area > detection_area * 3:  # Kept is 3x larger
                # Check if detection is substantially contained within kept
                # NOTE(review): a zero-area `detection_area` would divide by
                # zero here; upstream filtering appears to enforce w/h >= 8
                # — confirm before relying on it.
                overlap_area = max(0, min(x2, kx2) - max(x1, kx1)) * max(0, min(y2, ky2) - max(y1, ky1))
                contained_ratio = overlap_area / detection_area
                if contained_ratio > 0.8:  # 80% of detection is inside kept
                    is_contained = True
                    break

            # Check if kept detection is contained within current (much larger) detection
            elif detection_area > kept_area * 3:  # Current is 3x larger
                overlap_area = max(0, min(x2, kx2) - max(x1, kx1)) * max(0, min(y2, ky2) - max(y1, ky1))
                contained_ratio = overlap_area / kept_area
                if contained_ratio > 0.8:  # 80% of kept is inside current
                    # Remove the contained detection and replace with current
                    # (the current detection is appended after the loop, since
                    # neither is_duplicate nor is_contained is set here).
                    deduplicated.remove(kept)

        if not is_duplicate and not is_contained:
            deduplicated.append(detection)

    print(
        f" 🔄 Deduplication: {len(sorted_detections)} → {len(deduplicated)} detections"
    )
    return deduplicated
|
|
849
|
+
|
|
850
|
+
# Additional helper method for phase configuration
|
|
851
|
+
def set_detection_phases(self, phases: List[Dict[str, Any]]):
    """
    Configure custom detection phases for this detector.

    Each phase dict must provide "name", "conf" (confidence threshold),
    "iou" (IoU threshold) and "weight"; weights should sum to 1.0 across
    all phases (a warning is printed otherwise) and an optional
    "description" is accepted.

    Raises:
        ValueError: when any phase is missing a required key.
    """
    # Warn when the weights do not (approximately) sum to one.
    total_weight = sum(p.get("weight", 0) for p in phases)
    if abs(total_weight - 1.0) > 0.01:
        print(f"⚠️ Warning: Phase weights sum to {total_weight:.3f}, not 1.0")

    # Every phase needs the full set of required keys.
    for i, phase in enumerate(phases):
        missing = [field for field in ("name", "conf", "iou", "weight") if field not in phase]
        if missing:
            raise ValueError(f"Phase {i} missing required fields: {missing}")

    self.detection_phases = phases
    print(f"✅ Configured {len(phases)} detection phases")
    for i, phase in enumerate(phases):
        print(f" Phase {i+1}: {phase['name']} (conf={phase['conf']}, weight={phase['weight']})")
|
|
904
|
+
|
|
905
|
+
def _calculate_iou_tuples(self, box1: tuple, box2: tuple) -> float:
|
|
906
|
+
"""Calculate IoU between two bounding boxes in tuple format"""
|
|
907
|
+
x1_1, y1_1, x2_1, y2_1 = box1
|
|
908
|
+
x1_2, y1_2, x2_2, y2_2 = box2
|
|
909
|
+
|
|
910
|
+
# Calculate intersection
|
|
911
|
+
ix1, iy1 = max(x1_1, x1_2), max(y1_1, y1_2)
|
|
912
|
+
ix2, iy2 = min(x2_1, x2_2), min(y2_1, y2_2)
|
|
913
|
+
|
|
914
|
+
if ix2 <= ix1 or iy2 <= iy1:
|
|
915
|
+
return 0.0
|
|
916
|
+
|
|
917
|
+
intersection = (ix2 - ix1) * (iy2 - iy1)
|
|
918
|
+
area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
|
|
919
|
+
area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
|
|
920
|
+
union = area1 + area2 - intersection
|
|
921
|
+
|
|
922
|
+
return intersection / max(union, 1)
|
|
923
|
+
|
|
924
|
+
# ------------------- OCR + CLIP preselection -----------------------------
|
|
925
|
+
def _analyze_crop_visuals(self, crop_bgr: np.ndarray) -> dict:
|
|
926
|
+
"""Analyzes a crop for dominant color properties to distinguish printers from boxes."""
|
|
927
|
+
if crop_bgr.size == 0:
|
|
928
|
+
return {"is_mostly_white": False, "is_mostly_blue": False}
|
|
929
|
+
|
|
930
|
+
# Convert to HSV for better color analysis
|
|
931
|
+
hsv = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2HSV)
|
|
932
|
+
|
|
933
|
+
# --- White/Gray Detection ---
|
|
934
|
+
# Define a broad range for white, light gray, and silver colors
|
|
935
|
+
lower_white = np.array([0, 0, 150])
|
|
936
|
+
upper_white = np.array([180, 50, 255])
|
|
937
|
+
white_mask = cv2.inRange(hsv, lower_white, upper_white)
|
|
938
|
+
|
|
939
|
+
# --- Blue Detection ---
|
|
940
|
+
# Define a range for the Epson blue
|
|
941
|
+
lower_blue = np.array([95, 80, 40])
|
|
942
|
+
upper_blue = np.array([125, 255, 255])
|
|
943
|
+
blue_mask = cv2.inRange(hsv, lower_blue, upper_blue)
|
|
944
|
+
|
|
945
|
+
# Calculate the percentage of the image that is white or blue
|
|
946
|
+
total_pixels = crop_bgr.shape[0] * crop_bgr.shape[1]
|
|
947
|
+
white_percentage = (cv2.countNonZero(white_mask) / total_pixels) * 100
|
|
948
|
+
blue_percentage = (cv2.countNonZero(blue_mask) / total_pixels) * 100
|
|
949
|
+
|
|
950
|
+
# Determine if the object is primarily one color
|
|
951
|
+
# Thresholds can be tuned, but these are generally effective.
|
|
952
|
+
is_mostly_white = white_percentage > 40
|
|
953
|
+
is_mostly_blue = blue_percentage > 35
|
|
954
|
+
|
|
955
|
+
return {
|
|
956
|
+
"is_mostly_white": is_mostly_white,
|
|
957
|
+
"is_mostly_blue": is_mostly_blue,
|
|
958
|
+
"white_pct": white_percentage,
|
|
959
|
+
"blue_pct": blue_percentage,
|
|
960
|
+
}
|
|
961
|
+
|
|
962
|
+
    async def _classify_proposals(self, img, props, bands, header_limit_y, ad_box=None):
        """
        ENHANCED proposal classification with a robust, heuristic-first decision process.
        1. Identify price tags by size.
        2. Identify promotional graphics by position.
        3. For remaining objects, use strong visual heuristics (color) to classify.
        4. Use CLIP similarity only as a fallback for ambiguous cases.

        Args:
            img: BGR image array (H, W, 3); crops are sliced from it directly.
            props: Proposal dicts; each must carry a "box" (x1, y1, x2, y2) and may
                carry "yolo_conf" and "ocr_text".
            bands: Shelf bands, forwarded to self._determine_shelf_level.
            header_limit_y: Y coordinate below which objects are no longer
                considered promotional header graphics.
            ad_box: Unused here; kept for signature compatibility with callers.

        Returns:
            List[DetectionBox] with final class assignments. Proposals whose crop
            is empty or whose classification raises are silently dropped (the
            error is logged).
        """
        H, W = img.shape[:2]
        final_proposals = []
        PRICE_TAG_AREA_THRESHOLD = 0.005  # 0.5% of total image area

        print(f"\n🎯 Enhanced Classification: Running {len(props)} proposals...")
        print(" " + "="*60)

        for p in props:
            x1, y1, x2, y2 = p["box"]
            area = (x2 - x1) * (y2 - y1)
            area_ratio = area / (H * W)
            center_y = (y1 + y2) / 2

            # Helper to determine shelf level for context
            shelf_level = self._determine_shelf_level(center_y, bands)

            # --- 1. Price Tag Check (by size) ---
            # Anything tinier than 0.5% of the image is assumed to be a tag.
            if area_ratio < PRICE_TAG_AREA_THRESHOLD:
                final_proposals.append(
                    DetectionBox(
                        x1=x1, y1=y1, x2=x2, y2=y2,
                        confidence=p.get('yolo_conf', 0.8),
                        class_id=CID["price_tag"],
                        class_name="price_tag",
                        area=area,
                        ocr_text=p.get('ocr_text')
                    )
                )
                continue

            # --- 2. Promotional Graphic Check (by position) ---
            # Objects centered above the header line are header/poster graphics.
            if center_y < header_limit_y:
                final_proposals.append(
                    DetectionBox(
                        x1=x1, y1=y1, x2=x2, y2=y2,
                        confidence=p.get('yolo_conf', 0.9),
                        class_id=CID["promotional_candidate"],
                        class_name="promotional_candidate",
                        area=area,
                        ocr_text=p.get('ocr_text')
                    )
                )
                continue

            # --- 3. Heuristic & CLIP Classification for Products/Boxes ---
            try:
                crop_bgr = img[y1:y2, x1:x2]
                if crop_bgr.size == 0:
                    continue

                # Get visual heuristics and CLIP scores
                visuals = self._analyze_crop_visuals(crop_bgr)

                # CLIP expects RGB; crops come from a BGR array.
                crop_pil = Image.fromarray(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB))
                with torch.no_grad():
                    ip = self.proc(images=crop_pil, return_tensors="pt").to(self.device)
                    img_feat = self.clip.get_image_features(**ip)
                    img_feat /= img_feat.norm(dim=-1, keepdim=True)
                    # Cosine similarity against the pre-computed text prompts.
                    # NOTE(review): assumes self.text_feats rows are ordered
                    # poster, printer, box — confirm where text_feats is built.
                    text_sims = (img_feat @ self.text_feats.T).squeeze().tolist()
                    s_poster, s_printer, s_box = text_sims[0], text_sims[1], text_sims[2]

                # --- New Decision Logic ---
                class_name = None
                confidence = 0.8  # Default confidence for heuristic-based decision

                # Priority 1: Strong color evidence overrides everything.
                if visuals["is_mostly_white"] and not visuals["is_mostly_blue"]:
                    class_name = "product_candidate"  # It's a white printer device
                    confidence = 0.95  # High confidence in color heuristic
                elif visuals["is_mostly_blue"]:
                    class_name = "box_candidate"  # It's a blue product box
                    confidence = 0.95

                # Priority 2: If color is ambiguous, use shelf location as a strong hint.
                if not class_name:
                    if shelf_level == "middle":
                        class_name = "product_candidate"
                        confidence = 0.85
                    elif shelf_level == "bottom":
                        class_name = "box_candidate"
                        confidence = 0.85

                # Priority 3 (Fallback): If still undecided, use the original CLIP score.
                if not class_name:
                    if s_printer > s_box:
                        class_name = "product_candidate"
                        confidence = s_printer
                    else:
                        class_name = "box_candidate"
                        confidence = s_box

                final_class_id = CID[class_name]
                final_proposals.append(
                    DetectionBox(
                        x1=x1, y1=y1, x2=x2, y2=y2,
                        confidence=confidence,
                        class_id=final_class_id,
                        class_name=class_name,
                        area=area,
                        ocr_text=p.get('ocr_text')
                    )
                )

            except Exception as e:
                # Best-effort: a failed crop/classification drops the proposal
                # rather than aborting the whole pass.
                self.logger.error(f"Failed to classify proposal with heuristics/CLIP: {e}")

        return final_proposals
|
|
1077
|
+
|
|
1078
|
+
# --------------------- merge/cleanup ------------------------------
|
|
1079
|
+
def _merge(self, dets: List[DetectionBox], iou_same=0.3) -> List[DetectionBox]:
|
|
1080
|
+
"""Enhanced merge with size-aware logic"""
|
|
1081
|
+
dets = sorted(dets, key=lambda d: (d.class_name, -d.confidence, -d.area))
|
|
1082
|
+
out = []
|
|
1083
|
+
|
|
1084
|
+
for d in dets:
|
|
1085
|
+
placed = False
|
|
1086
|
+
for m in out:
|
|
1087
|
+
if d.class_name == m.class_name:
|
|
1088
|
+
iou = self._iou(d, m)
|
|
1089
|
+
|
|
1090
|
+
# Different merge strategies based on class
|
|
1091
|
+
if d.class_name == "box_candidate":
|
|
1092
|
+
# More aggressive merging for boxes (they're often tightly packed)
|
|
1093
|
+
merge_threshold = 0.25
|
|
1094
|
+
elif d.class_name == "product_candidate":
|
|
1095
|
+
# Conservative merging for printers (they're usually separate)
|
|
1096
|
+
merge_threshold = 0.4
|
|
1097
|
+
else:
|
|
1098
|
+
merge_threshold = iou_same
|
|
1099
|
+
|
|
1100
|
+
if iou > merge_threshold:
|
|
1101
|
+
# Merge by taking the union
|
|
1102
|
+
m.x1 = min(m.x1, d.x1)
|
|
1103
|
+
m.y1 = min(m.y1, d.y1)
|
|
1104
|
+
m.x2 = max(m.x2, d.x2)
|
|
1105
|
+
m.y2 = max(m.y2, d.y2)
|
|
1106
|
+
m.area = (m.x2 - m.x1) * (m.y2 - m.y1)
|
|
1107
|
+
m.confidence = max(m.confidence, d.confidence)
|
|
1108
|
+
placed = True
|
|
1109
|
+
print(f" 🔄 Merged {d.class_name} with IoU={iou:.3f}")
|
|
1110
|
+
break
|
|
1111
|
+
|
|
1112
|
+
if not placed:
|
|
1113
|
+
out.append(d)
|
|
1114
|
+
|
|
1115
|
+
return out
|
|
1116
|
+
|
|
1117
|
+
# ------------------------------ debug ------------------------------------
|
|
1118
|
+
def _rectangle_dashed(self, img, pt1, pt2, color, thickness=2, gap=9):
|
|
1119
|
+
x1, y1 = pt1
|
|
1120
|
+
x2, y2 = pt2
|
|
1121
|
+
# top
|
|
1122
|
+
for x in range(x1, x2, gap * 2):
|
|
1123
|
+
cv2.line(img, (x, y1), (min(x + gap, x2), y1), color, thickness)
|
|
1124
|
+
# bottom
|
|
1125
|
+
for x in range(x1, x2, gap * 2):
|
|
1126
|
+
cv2.line(img, (x, y2), (min(x + gap, x2), y2), color, thickness)
|
|
1127
|
+
# left
|
|
1128
|
+
for y in range(y1, y2, gap * 2):
|
|
1129
|
+
cv2.line(img, (x1, y), (x1, min(y + gap, y2)), color, thickness)
|
|
1130
|
+
# right
|
|
1131
|
+
for y in range(y1, y2, gap * 2):
|
|
1132
|
+
cv2.line(img, (x2, y), (x2, min(y + gap, y2)), color, thickness)
|
|
1133
|
+
|
|
1134
|
+
def _draw_corners(self, img, pt1, pt2, color, length=12, thickness=2):
|
|
1135
|
+
x1, y1 = pt1
|
|
1136
|
+
x2, y2 = pt2
|
|
1137
|
+
# TL
|
|
1138
|
+
cv2.line(img, (x1, y1), (x1 + length, y1), color, thickness)
|
|
1139
|
+
cv2.line(img, (x1, y1), (x1, y1 + length), color, thickness)
|
|
1140
|
+
# TR
|
|
1141
|
+
cv2.line(img, (x2, y1), (x2 - length, y1), color, thickness)
|
|
1142
|
+
cv2.line(img, (x2, y1), (x2, y1 + length), color, thickness)
|
|
1143
|
+
# BL
|
|
1144
|
+
cv2.line(img, (x1, y2), (x1 + length, y2), color, thickness)
|
|
1145
|
+
cv2.line(img, (x1, y2), (x1, y2 - length), color, thickness)
|
|
1146
|
+
# BR
|
|
1147
|
+
cv2.line(img, (x2, y2), (x2 - length, y2), color, thickness)
|
|
1148
|
+
cv2.line(img, (x2, y2), (x2, y2 - length), color, thickness)
|
|
1149
|
+
|
|
1150
|
+
    def _draw_phase_areas(self, img, props, roi_box, show_labels=True):
        """
        Draw per-phase borders (no fill). Thickness encodes confidence.
        poster_high = magenta (solid), high_confidence = green (solid), aggressive = orange (dashed).

        Args:
            img: BGR image to draw on (mutated in place and also returned).
            props: Proposal dicts with "box", and optionally "phase" and "confidence".
            roi_box: (x1, y1, x2, y2) region-of-interest rectangle.
            show_labels: When True, draw a short phase/confidence label per box.

        Returns:
            The same image array, annotated.
        """
        phase_colors = {
            "poster_high": (200, 0, 200),  # BGR
            "high_confidence": (0, 220, 0),
            "aggressive": (0, 165, 255),
        }
        # Only "aggressive" proposals are drawn dashed.
        dashed = {"poster_high": False, "high_confidence": False, "aggressive": True}

        # --- legend counts
        counts = Counter(p.get("phase", "aggressive") for p in props)

        # --- draw ROI
        rx1, ry1, rx2, ry2 = roi_box
        cv2.rectangle(img, (rx1, ry1), (rx2, ry2), (0, 255, 0), 2)

        # --- per-proposal borders
        for p in props:
            x1, y1, x2, y2 = p["box"]
            phase = p.get("phase", "aggressive")
            conf = float(p.get("confidence", 0.0))
            color = phase_colors.get(phase, (180, 180, 180))

            # thickness: 1..5 with a gentle curve so small conf doesn't vanish
            t = max(1, min(5, int(round(1 + 4 * math.sqrt(max(0.0, min(conf, 1.0)))))))

            if dashed.get(phase, False):
                self._rectangle_dashed(img, (x1, y1), (x2, y2), color, thickness=t, gap=9)
            else:
                cv2.rectangle(img, (x1, y1), (x2, y2), color, t)

            # add subtle phase corners to help when borders overlap
            self._draw_corners(img, (x1, y1), (x2, y2), color, length=10, thickness=max(1, t - 1))

            if show_labels:
                # e.g. "P 0.93" — first letter of the phase plus the confidence.
                lbl = f"{phase.split('_')[0][:1].upper()} {conf:.2f}"
                ty = max(12, y1 - 6)
                cv2.putText(img, lbl, (x1 + 2, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)

        # --- legend (top-left of ROI)
        legend_items = [("poster_high", "Poster"), ("high_confidence", "High"), ("aggressive", "Agg")]
        lx, ly = rx1 + 6, max(18, ry1 + 16)
        for key, name in legend_items:
            col = phase_colors[key]
            cv2.rectangle(img, (lx, ly - 10), (lx + 18, ly - 2), col, -1)
            text = f"{name}: {counts.get(key, 0)}"
            cv2.putText(img, text, (lx + 24, ly - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 1, cv2.LINE_AA)
            ly += 16

        return img
|
|
1203
|
+
|
|
1204
|
+
    def _draw_yolo(self, img, props, roi_box, shelf_lines):
        """
        Draw raw YOLO detections with detailed labels.

        Args:
            img: BGR image to annotate (mutated in place and returned).
            props: Proposal dicts; each must have "box", "raw_index",
                "yolo_label", "yolo_conf", "area_ratio", and may carry
                "retail_candidates".
            roi_box: (x1, y1, x2, y2) region of interest.
            shelf_lines: Y coordinates of detected shelf separators.

        Returns:
            The annotated image.
        """
        rx1, ry1, rx2, ry2 = roi_box

        # Draw ROI box
        cv2.rectangle(img, (rx1, ry1), (rx2, ry2), (0, 255, 0), 3)
        cv2.putText(img, f"ROI: {rx2-rx1}x{ry2-ry1}", (rx1, ry1-10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        # Draw shelf lines
        for i, y in enumerate(shelf_lines):
            cv2.line(img, (rx1, y), (rx2, y), (0, 255, 255), 2)
            cv2.putText(img, f"Shelf{i+1}", (rx1+5, y-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 255), 1)

        # Color mapping for retail candidates
        candidate_colors = {
            "promotional_graphic": (255, 0, 255),  # Magenta
            "printer": (255, 140, 0),              # Orange
            "tv": (0, 200, 0),                     # Green
            "product_candidate": (200, 200, 0),    # Yellow
            "product_box": (0, 140, 255),          # Blue
            "small_object": (128, 128, 128),       # Gray
            "ink_bottle": (160, 0, 200),           # Purple
        }

        for p in props:
            (x1, y1, x2, y2) = p["box"]

            # Choose color based on primary retail candidate
            candidates = p.get("retail_candidates", ["unknown"])
            primary_candidate = candidates[0] if candidates else "unknown"
            color = candidate_colors.get(primary_candidate, (255, 255, 255))

            # Draw detection
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            # Enhanced label: raw index, YOLO class, mapped candidate, and stats.
            idx = p["raw_index"]
            yolo_class = p["yolo_label"]
            conf = p["yolo_conf"]
            area_pct = p["area_ratio"] * 100

            label1 = f"#{idx} {yolo_class}→{primary_candidate}"
            label2 = f"conf:{conf:.3f} area:{area_pct:.1f}%"

            cv2.putText(img, label1, (x1, max(15, y1 - 5)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA)
            cv2.putText(img, label2, (x1, max(30, y1 + 15)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 1, cv2.LINE_AA)

        return img
|
|
1258
|
+
|
|
1259
|
+
    def _draw_phase1(self, img, roi_box, shelf_lines, dets, ad_box=None):
        """
        FIXED: Phase-1 debug drawing with better info.

        Draws the ROI, shelf lines, each classified detection (with a per-class
        color), and the optional poster ROI. Also prints one summary line per
        detection to stdout for debugging.

        Args:
            img: BGR image to annotate (mutated in place and returned).
            roi_box: (x1, y1, x2, y2) region of interest.
            shelf_lines: Y coordinates of shelf separators.
            dets: DetectionBox objects to render.
            ad_box: Optional (x1, y1, x2, y2) poster region to outline.

        Returns:
            The annotated image.
        """
        rx1, ry1, rx2, ry2 = roi_box
        cv2.rectangle(img, (rx1, ry1), (rx2, ry2), (0, 255, 0), 2)

        for y in shelf_lines:
            cv2.line(img, (rx1, y), (rx2, y), (0, 255, 255), 2)

        # Per-class border colors (BGR); unknown classes fall back to gray.
        colors = {
            "promotional_candidate": (0, 200, 0),
            "product_candidate": (255, 140, 0),
            "box_candidate": (0, 140, 255),
            "price_tag": (255, 0, 255),
        }

        for i, d in enumerate(dets, 1):
            c = colors.get(d.class_name, (200, 200, 200))
            cv2.rectangle(img, (d.x1, d.y1), (d.x2, d.y2), c, 2)

            # Enhanced label with detection info
            w, h = d.x2 - d.x1, d.y2 - d.y1
            area_pct = (d.area / (img.shape[0] * img.shape[1])) * 100
            aspect = w / max(h, 1)
            center_y = (d.y1 + d.y2) / 2

            print(f" #{i:2d}: {d.class_name:20s} conf={d.confidence:.3f} "
                  f"area={area_pct:.2f}% AR={aspect:.2f} center_y={center_y:.0f}")

            label = f"#{i} {d.class_name} {d.confidence:.2f}"
            cv2.putText(img, label, (d.x1, max(15, d.y1 - 4)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.45, c, 1, cv2.LINE_AA)

        if ad_box is not None:
            cv2.rectangle(img, (ad_box[0], ad_box[1]), (ad_box[2], ad_box[3]), (0, 255, 128), 2)
            cv2.putText(
                img, "poster_roi",
                (ad_box[0], max(12, ad_box[1] - 4)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.4, (0, 255, 128), 1, cv2.LINE_AA,
            )

        return img
|
|
1303
|
+
|
|
1304
|
+
|
|
1305
|
+
class PlanogramCompliancePipeline(AbstractPipeline):
|
|
1306
|
+
"""
|
|
1307
|
+
Pipeline for planogram compliance checking.
|
|
1308
|
+
|
|
1309
|
+
3-Step planogram compliance pipeline:
|
|
1310
|
+
Step 1: Object Detection (YOLO/ResNet)
|
|
1311
|
+
Step 2: LLM Object Identification with Reference Images
|
|
1312
|
+
Step 3: Planogram Comparison and Compliance Verification
|
|
1313
|
+
"""
|
|
1314
|
+
    def __init__(
        self,
        planogram_config: PlanogramConfig,
        llm: Any = None,
        llm_provider: str = "google",
        llm_model: Optional[str] = None,
        **kwargs: Any
    ):
        """
        Initialize the 3-step pipeline.

        Args:
            planogram_config: Planogram program configuration; supplies the
                endcap geometry, detection model name, confidence threshold,
                and reference images.
            llm: Pre-built LLM client; when None, the base class constructs
                one from llm_provider/llm_model.
            llm_provider: LLM provider for identification.
            llm_model: Specific LLM model.
            **kwargs: Forwarded to AbstractPipeline.__init__.
        """
        # Endcap geometry defaults (can be tuned per program)
        geometry = planogram_config.endcap_geometry
        self.endcap_aspect_ratio = geometry.aspect_ratio
        self.left_margin_ratio = geometry.left_margin_ratio
        self.right_margin_ratio = geometry.right_margin_ratio
        self.top_margin_ratio = geometry.top_margin_ratio
        self.bottom_margin_ratio = geometry.bottom_margin_ratio
        self.inter_shelf_padding = geometry.inter_shelf_padding

        # saving the planogram config for later use
        self.planogram_config = planogram_config
        super().__init__(
            llm=llm,
            llm_provider=llm_provider,
            llm_model=llm_model,
            **kwargs
        )
        reference_images = planogram_config.reference_images
        # The detector takes a flat list; the dict keys are kept on self below.
        references = list(reference_images.values()) if reference_images else None
        # Initialize the generic shape detector
        # NOTE: self.llm is used here, so this must run after super().__init__.
        self.shape_detector = RetailDetector(
            yolo_model=planogram_config.detection_model,
            conf=planogram_config.confidence_threshold,
            llm=self.llm,
            device="cuda" if torch.cuda.is_available() else "cpu",
            reference_images=references
        )
        self.logger.debug(
            f"Initialized RetailDetector with {planogram_config.detection_model}"
        )
        self.reference_images = reference_images or {}
        self.confidence_threshold = planogram_config.confidence_threshold
|
|
1363
|
+
|
|
1364
|
+
async def detect_objects_and_shelves(
|
|
1365
|
+
self,
|
|
1366
|
+
image: Image,
|
|
1367
|
+
image_array: np.ndarray,
|
|
1368
|
+
endcap: Detection,
|
|
1369
|
+
ad: Optional[Detection] = None,
|
|
1370
|
+
brand: Optional[Detection] = None,
|
|
1371
|
+
panel_text: Optional[Detection] = None,
|
|
1372
|
+
planogram_description: Optional[PlanogramDescription] = None
|
|
1373
|
+
):
|
|
1374
|
+
self.logger.debug(
|
|
1375
|
+
"Step 1: Detecting generic shapes and boundaries..."
|
|
1376
|
+
)
|
|
1377
|
+
|
|
1378
|
+
det_out = await self.shape_detector.detect(
|
|
1379
|
+
image=image,
|
|
1380
|
+
image_array=image_array,
|
|
1381
|
+
endcap=endcap,
|
|
1382
|
+
ad=ad,
|
|
1383
|
+
planogram=planogram_description,
|
|
1384
|
+
debug_yolo="/tmp/data/yolo_raw.png",
|
|
1385
|
+
debug_phase1="/tmp/data/yolo_phase1_debug.png",
|
|
1386
|
+
debug_phases="/tmp/data/yolo_phases_debug.png",
|
|
1387
|
+
)
|
|
1388
|
+
|
|
1389
|
+
shelves = det_out["shelves"] # {'top': DetectionBox(...), 'middle': ...}
|
|
1390
|
+
proposals = det_out["proposals"] # List[DetectionBox]
|
|
1391
|
+
|
|
1392
|
+
print("PROPOSALS:", proposals)
|
|
1393
|
+
print("SHELVES:", shelves)
|
|
1394
|
+
|
|
1395
|
+
h, w = image_array.shape[:2]
|
|
1396
|
+
if brand:
|
|
1397
|
+
bx1, by1, bx2, by2 = brand.bbox.get_pixel_coordinates(width=w, height=h)
|
|
1398
|
+
proposals.append(
|
|
1399
|
+
DetectionBox(
|
|
1400
|
+
x1=bx1, y1=by1, x2=bx2, y2=by2,
|
|
1401
|
+
confidence=brand.confidence,
|
|
1402
|
+
class_id=CID["brand_logo"],
|
|
1403
|
+
class_name="brand_logo",
|
|
1404
|
+
area=(bx2 - bx1) * (by2 - by1),
|
|
1405
|
+
ocr_text=brand.content
|
|
1406
|
+
)
|
|
1407
|
+
)
|
|
1408
|
+
print(f" + Injected brand_logo: '{brand.content}'")
|
|
1409
|
+
|
|
1410
|
+
if panel_text:
|
|
1411
|
+
tx1, ty1, tx2, ty2 = panel_text.bbox.get_pixel_coordinates(width=w, height=h)
|
|
1412
|
+
proposals.append(
|
|
1413
|
+
DetectionBox(
|
|
1414
|
+
x1=tx1, y1=ty1, x2=tx2, y2=ty2,
|
|
1415
|
+
confidence=panel_text.confidence,
|
|
1416
|
+
class_id=CID["poster_text"],
|
|
1417
|
+
class_name="poster_text",
|
|
1418
|
+
area=(tx2 - tx1) * (ty2 - ty1),
|
|
1419
|
+
ocr_text=panel_text.content.replace('.', ' ')
|
|
1420
|
+
)
|
|
1421
|
+
)
|
|
1422
|
+
print(f" + Injected poster_text: '{panel_text.content}'")
|
|
1423
|
+
|
|
1424
|
+
# --- IMPORTANT: use Phase-1 shelf bands (not %-of-image buckets) ---
|
|
1425
|
+
shelf_regions = self._materialize_shelf_regions(shelves, proposals, planogram_description)
|
|
1426
|
+
|
|
1427
|
+
detections = list(proposals)
|
|
1428
|
+
|
|
1429
|
+
self.logger.debug(
|
|
1430
|
+
"Found %d objects in %d shelf regions", len(detections), len(shelf_regions)
|
|
1431
|
+
)
|
|
1432
|
+
|
|
1433
|
+
self.logger.debug("Found %d objects in %d shelf regions",
|
|
1434
|
+
len(detections), len(shelf_regions))
|
|
1435
|
+
return shelf_regions, detections
|
|
1436
|
+
|
|
1437
|
+
def _materialize_shelf_regions(
|
|
1438
|
+
self,
|
|
1439
|
+
shelves_dict: Dict[str, DetectionBox],
|
|
1440
|
+
dets: List[DetectionBox],
|
|
1441
|
+
planogram_description: Optional[PlanogramDescription] = None
|
|
1442
|
+
) -> List[ShelfRegion]:
|
|
1443
|
+
"""Turn Phase-1 shelf bands into ShelfRegion objects and assign detections by y-overlap."""
|
|
1444
|
+
def y_overlap(a1, a2, b1, b2) -> int:
|
|
1445
|
+
return max(0, min(a2, b2) - max(a1, b1))
|
|
1446
|
+
|
|
1447
|
+
regions: List[ShelfRegion] = []
|
|
1448
|
+
|
|
1449
|
+
# Iterate through the shelves defined in the planogram config, in their specified order.
|
|
1450
|
+
for shelf_config in planogram_description.shelves:
|
|
1451
|
+
level = shelf_config.level
|
|
1452
|
+
band = shelves_dict.get(level)
|
|
1453
|
+
if not band:
|
|
1454
|
+
self.logger.warning(
|
|
1455
|
+
f"Shelf '{level}' is defined in the planogram but was not detected in the image."
|
|
1456
|
+
)
|
|
1457
|
+
continue
|
|
1458
|
+
|
|
1459
|
+
# Find all object proposals that vertically overlap with this shelf's detected band.
|
|
1460
|
+
# An object belongs to the shelf if any part of it is within the shelf's y-range.
|
|
1461
|
+
objs = [d for d in dets if y_overlap(d.y1, d.y2, band.y1, band.y2) > 0]
|
|
1462
|
+
|
|
1463
|
+
# If no objects were found on this shelf, we don't need to create a region for it.
|
|
1464
|
+
if objs:
|
|
1465
|
+
x1 = min(o.x1 for o in objs)
|
|
1466
|
+
x2 = max(o.x2 for o in objs)
|
|
1467
|
+
else:
|
|
1468
|
+
# Use band boundaries if no objects
|
|
1469
|
+
x1, x2 = band.x1, band.x2
|
|
1470
|
+
|
|
1471
|
+
# Create a new bounding box for the ShelfRegion.
|
|
1472
|
+
# The Y coordinates are fixed by the detected shelf band.
|
|
1473
|
+
# The X coordinates are calculated as the min/max extent of the objects on that shelf.
|
|
1474
|
+
y1 = band.y1
|
|
1475
|
+
y2 = band.y2
|
|
1476
|
+
|
|
1477
|
+
bbox = DetectionBox(
|
|
1478
|
+
x1=x1, y1=y1, x2=x2, y2=y2,
|
|
1479
|
+
confidence=1.0,
|
|
1480
|
+
class_id=CID["shelf_region"],
|
|
1481
|
+
class_name="shelf_region",
|
|
1482
|
+
area=(x2 - x1) * (y2 - y1)
|
|
1483
|
+
)
|
|
1484
|
+
|
|
1485
|
+
# Create the final ShelfRegion object.
|
|
1486
|
+
regions.append(
|
|
1487
|
+
ShelfRegion(
|
|
1488
|
+
shelf_id=f"{level}_shelf",
|
|
1489
|
+
bbox=bbox,
|
|
1490
|
+
level=level,
|
|
1491
|
+
objects=objs
|
|
1492
|
+
)
|
|
1493
|
+
)
|
|
1494
|
+
|
|
1495
|
+
return regions
|
|
1496
|
+
|
|
1497
|
+
    async def identify_objects_with_references(
        self,
        image: Union[str, Path, Image.Image],
        detections: List[DetectionBox],
        shelf_regions: List[ShelfRegion],
        reference_images: Dict[str, Union[str, Path, Image.Image]],
        prompt: str
    ) -> List[IdentifiedProduct]:
        """
        Step 2: Use LLM to identify detected objects using reference images.

        Args:
            image: Original endcap image
            detections: Object detections from Step 1
            shelf_regions: Shelf regions from Step 1
            reference_images: Named reference product images (the body unpacks
                this with **, so it must be a mapping, not a list — the
                original List annotation was wrong)
            prompt: Prompt for object identification

        Returns:
            List of identified products, augmented with box OCR and (for
            promotional graphics) extracted poster text.

        Raises:
            Whatever client.image_identification raises; errors are logged with
            a traceback and re-raised.
        """

        self.logger.debug(
            f"Starting identification with {len(detections)} detections"
        )
        # If no detections, return empty list
        if not detections:
            self.logger.warning("No detections to identify")
            return []


        pil_image = self._get_image(image)

        # Create annotated image showing detection boxes.
        # Structural/meta classes are excluded: the LLM only sees real objects.
        effective_dets = [
            d for d in detections if d.class_name not in {"slot", "shelf_region", "price_tag", "fact_tag"}
        ]
        annotated_image = self._create_annotated_image(pil_image, effective_dets)

        async with self.llm as client:
            try:
                # The annotated overview goes alongside the per-product refs.
                extra_refs = {
                    "annotated_image": annotated_image,
                    **reference_images
                }
                identified_products = await client.image_identification(
                    prompt=self._build_gemini_identification_prompt(
                        effective_dets,
                        shelf_regions,
                        partial_prompt=prompt
                    ),
                    image=image,
                    detections=effective_dets,
                    shelf_regions=shelf_regions,
                    reference_images=extra_refs,
                    temperature=0.0  # deterministic identification
                )
                # Enrich products with OCR evidence (brands, ET-xxxx models).
                identified_products = await self._augment_products_with_box_ocr(
                    image,
                    identified_products
                )
                # Attach raw OCR snippets to promotional graphics for matching.
                for product in identified_products:
                    if product.product_type == "promotional_graphic":
                        if lines := await self._extract_text_from_region(image, product.detection_box):
                            snippet = " ".join(lines)[:120]
                            product.visual_features = (product.visual_features or []) + [f"ocr:{snippet}"]
                return identified_products

            except Exception as e:
                self.logger.error(f"Error in structured identification: {e}")
                traceback.print_exc()
                raise
|
|
1569
|
+
|
|
1570
|
+
def _guess_et_model_from_text(self, text: str) -> Optional[str]:
|
|
1571
|
+
"""
|
|
1572
|
+
Find Epson EcoTank model tokens in text.
|
|
1573
|
+
Returns normalized like 'et-4950' (device) or 'et-2980', etc.
|
|
1574
|
+
"""
|
|
1575
|
+
if not text:
|
|
1576
|
+
return None
|
|
1577
|
+
t = text.lower().replace(" ", "")
|
|
1578
|
+
# common variants: et-4950, et4950, et – 4950, etc.
|
|
1579
|
+
m = re.search(r"et[-]?\s?(\d{4})", t)
|
|
1580
|
+
if not m:
|
|
1581
|
+
return None
|
|
1582
|
+
num = m.group(1)
|
|
1583
|
+
# Accept only models we care about (tighten if needed)
|
|
1584
|
+
if num in {"2980", "3950", "4950"}:
|
|
1585
|
+
return f"et-{num}"
|
|
1586
|
+
return None
|
|
1587
|
+
|
|
1588
|
+
|
|
1589
|
+
def _maybe_brand_from_text(self, text: str) -> Optional[str]:
|
|
1590
|
+
if not text:
|
|
1591
|
+
return None
|
|
1592
|
+
t = text.lower()
|
|
1593
|
+
if "epson" in t or "ecotank" in t:
|
|
1594
|
+
return "Epson"
|
|
1595
|
+
if 'hisense' in t or "canvastv" in t:
|
|
1596
|
+
return "Hisense"
|
|
1597
|
+
if "firetv" in t or "fire tv" in t:
|
|
1598
|
+
return "Amazon"
|
|
1599
|
+
if "google tv" in t or "chromecast" in t:
|
|
1600
|
+
return "Google"
|
|
1601
|
+
return None
|
|
1602
|
+
|
|
1603
|
+
def _normalize_ocr_text(self, s: str) -> str:
|
|
1604
|
+
"""
|
|
1605
|
+
Make OCR text match-friendly:
|
|
1606
|
+
- Unicode normalize (NFKC), strip diacritics
|
|
1607
|
+
- Replace fancy dashes/quotes with spaces
|
|
1608
|
+
- Remove non-alnum except spaces, collapse whitespace
|
|
1609
|
+
- Lowercase
|
|
1610
|
+
"""
|
|
1611
|
+
if not s:
|
|
1612
|
+
return ""
|
|
1613
|
+
s = unicodedata.normalize("NFKC", s)
|
|
1614
|
+
# strip accents
|
|
1615
|
+
s = "".join(ch for ch in unicodedata.normalize("NFKD", s) if not unicodedata.combining(ch))
|
|
1616
|
+
# unify punctuation to spaces
|
|
1617
|
+
s = re.sub(r"[—–‐-‒–—―…“”\"'·•••·•—–/\\|_=+^°™®©§]", " ", s)
|
|
1618
|
+
# keep letters/digits/spaces
|
|
1619
|
+
s = re.sub(r"[^A-Za-z0-9 ]+", " ", s)
|
|
1620
|
+
# collapse
|
|
1621
|
+
s = re.sub(r"\s+", " ", s).strip().lower()
|
|
1622
|
+
return s
|
|
1623
|
+
|
|
1624
|
+
    async def _augment_products_with_box_ocr(
        self,
        image: Union[str, Path, Image.Image],
        products: List[IdentifiedProduct]
    ) -> List[IdentifiedProduct]:
        """Add OCR-derived evidence to boxes/printers and fix product_model when we see ET-xxxx.

        Mutates the products in place (brand, product_model, visual_features)
        and returns the same list for convenience.
        """
        for p in products:
            if not p.detection_box:
                continue
            # Normalize the product's brand from the detection's OCR text when
            # the identified product has no brand yet (brand_logo boxes only).
            if getattr(p.detection_box, 'class_name', None) == 'brand_logo' and not getattr(p, 'brand', None):
                if p.detection_box.ocr_text:
                    brand = self._maybe_brand_from_text(p.detection_box.ocr_text)
                    if brand:
                        try:
                            p.brand = brand  # only if IdentifiedProduct has 'brand'
                        except Exception:
                            # Fallback: record the brand as a visual feature.
                            if not p.visual_features:
                                p.visual_features = []
                            p.visual_features.append(f"brand:{brand}")
            if p.product_type in {"product_box", "printer"}:
                # "model" mode tunes OCR toward catching ET-xxxx tokens.
                lines = await self._extract_text_from_region(image, p.detection_box, mode="model")
                if lines:
                    # Keep some OCR as visual evidence (don’t explode the list)
                    snippet = " ".join(lines)[:120]
                    if not p.visual_features:
                        p.visual_features = []
                    p.visual_features.append(f"ocr:{snippet}")

                    # Brand hint
                    brand = self._maybe_brand_from_text(snippet)
                    if brand and not getattr(p, "brand", None):
                        try:
                            p.brand = brand  # only if IdentifiedProduct has 'brand'
                        except Exception:
                            # If the model doesn’t have brand, keep it as a feature.
                            p.visual_features.append(f"brand:{brand}")

                    # Model from OCR
                    model = self._guess_et_model_from_text(snippet)
                    if model:
                        # Normalize to your scheme:
                        target = model.upper()
                        # If missing or mismatched, replace
                        if not p.product_model:
                            p.product_model = target
                        else:
                            # If current looks generic/incorrect, fix it.
                            # NOTE(review): `and` binds tighter than `or` here, and
                            # `target` comes from _guess_et_model_from_text so it
                            # never contains "box" — the second clause looks dead;
                            # confirm intent before touching it.
                            cur = (p.product_model or "").lower()
                            if "et-" in target.lower() and ("et-" not in cur or "box" in target.lower() and "box" not in cur):
                                p.product_model = target
            elif p.product_type == "promotional_graphic":
                if lines := await self._extract_text_from_region(image, p.detection_box):
                    snippet = " ".join(lines)[:160]
                    p.visual_features = (p.visual_features or []) + [f"ocr:{snippet}"]
                    # keep a normalized text blob
                    joined = " ".join(lines)
                    if norm := self._normalize_ocr_text(joined):
                        p.visual_features.append(norm)
                    # Also add each normalized line individually (deduplicated)
                    # so TextMatcher has more candidates to hit.
                    for ln in lines:
                        if ln and (nln := self._normalize_ocr_text(ln)) and nln not in p.visual_features:
                            p.visual_features.append(nln)

                    # NEW: infer brand from OCR/features if missing
                    if not getattr(p, "brand", None):
                        brand = self._maybe_brand_from_text(joined)
                        if not brand and p.visual_features:
                            vf_blob = " ".join(p.visual_features)
                            brand = self._maybe_brand_from_text(vf_blob)
                        if brand:
                            p.brand = brand
        return products
|
|
1696
|
+
|
|
1697
|
+
async def _extract_text_from_region(
    self,
    image: Union[str, Path, Image.Image],
    detection_box: DetectionBox,
    mode: str = "generic",  # "generic" | "model"
) -> List[str]:
    """Extract text from a region with OCR.
    - generic: multi-pass (psm 6 & 4) + unsharp + binarize
    - model : tuned to catch ET-xxxx
    Returns lines + normalized variants so TextMatcher has more chances.
    """
    try:
        # Accept either a filesystem path or an already-loaded PIL image.
        pil_image = Image.open(image) if isinstance(image, (str, Path)) else image
        # Pad the crop a little so glyphs touching the box edge are not clipped.
        pad = 10
        x1 = max(0, detection_box.x1 - pad)
        y1 = max(0, detection_box.y1 - pad)
        x2 = min(pil_image.width - 1, detection_box.x2 + pad)
        y2 = min(pil_image.height - 1, detection_box.y2 + pad)

        # ENSURE VALID CROP COORDINATES
        # (degenerate boxes are widened to a 10px strip; NOTE(review): this can
        # push x2/y2 past the image edge — PIL tolerates it, but worth confirming)
        if x1 >= x2:
            x2 = x1 + 10
        if y1 >= y2:
            y2 = y1 + 10

        crop_rgb = pil_image.crop((x1, y1, x2, y2)).convert("RGB")

        def _prep(arr):
            # Grayscale -> 1.5x cubic upscale -> unsharp mask -> Otsu binarize.
            g = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
            g = cv2.resize(g, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_CUBIC)
            blur = cv2.GaussianBlur(g, (0, 0), sigmaX=1.0)
            # addWeighted(g, 1.6, blur, -0.6) is the unsharp-mask step.
            sharp = cv2.addWeighted(g, 1.6, blur, -0.6, 0)
            _, th = cv2.threshold(sharp, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            return th

        if mode == "model":
            th = _prep(np.array(crop_rgb))
            crop = Image.fromarray(th).convert("L")
            # Character whitelist keeps tesseract focused on model codes (ET-xxxx).
            cfg = "--oem 3 --psm 6 -l eng -c tessedit_char_whitelist=ETet0123456789-ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            raw = pytesseract.image_to_string(crop, config=cfg)
            lines = [ln.strip() for ln in raw.splitlines() if ln.strip()]
        else:
            arr = np.array(crop_rgb)
            th = _prep(arr)
            # two passes help for 'Goodbye Cartridges' on light box
            raw1 = pytesseract.image_to_string(Image.fromarray(th), config="--psm 6 -l eng")
            raw2 = pytesseract.image_to_string(Image.fromarray(th), config="--psm 4 -l eng")
            raw = raw1 + "\n" + raw2
            lines = [ln.strip() for ln in raw.splitlines() if ln.strip()]

        # Add normalized variants to help TextMatcher:
        # - lowercase, punctuation stripped
        # - a single combined line
        def norm(s: str) -> str:
            s = s.lower()
            s = re.sub(r"[^a-z0-9\s]", " ", s)  # drop punctuation like colons
            s = re.sub(r"\s+", " ", s).strip()
            return s

        variants = [norm(ln) for ln in lines if ln]
        if variants:
            variants.append(norm(" ".join(lines)))

        # merge unique while preserving originals first
        out = lines[:]
        for v in variants:
            if v and v not in out:
                out.append(v)

        return out

    except Exception as e:
        # Best-effort OCR: any failure (bad crop, tesseract missing, ...) degrades
        # to "no text found" rather than aborting the caller's pipeline.
        self.logger.error(f"Text extraction failed: {e}")
        return []
|
|
1771
|
+
|
|
1772
|
+
def _get_image(
    self,
    image: Union[str, Path, Image.Image]
) -> Image.Image:
    """Return a detached PIL copy of *image*, loading it from disk when a path is given."""
    source = Image.open(image) if isinstance(image, (str, Path)) else image
    return source.copy()
|
|
1783
|
+
|
|
1784
|
+
def _create_annotated_image(
    self,
    image: Image.Image,
    detections: List[DetectionBox]
) -> Image.Image:
    """Annotate *image* in place with a red box and an "ID:n (conf)" caption per detection."""
    canvas = ImageDraw.Draw(image)

    for idx, det in enumerate(detections, start=1):
        # Outline the detection region.
        canvas.rectangle(
            [(det.x1, det.y1), (det.x2, det.y2)],
            outline="red", width=2
        )
        # Caption sits just above the box's top-left corner.
        caption = f"ID:{idx} ({det.confidence:.2f})"
        canvas.text((det.x1, det.y1 - 20), caption, fill="red")

    return image
|
|
1805
|
+
|
|
1806
|
+
def _build_gemini_identification_prompt(
|
|
1807
|
+
self,
|
|
1808
|
+
detections: List[DetectionBox],
|
|
1809
|
+
shelf_regions: List[ShelfRegion],
|
|
1810
|
+
partial_prompt: str
|
|
1811
|
+
) -> str:
|
|
1812
|
+
"""Builds a more detailed prompt to help Gemini differentiate similar products."""
|
|
1813
|
+
detection_lines = ["\nDETECTED OBJECTS (with pre-assigned IDs):"]
|
|
1814
|
+
if detections:
|
|
1815
|
+
for i, detection in enumerate(detections, 1):
|
|
1816
|
+
detection_lines.append(
|
|
1817
|
+
f"ID {i}: Initial class '{detection.class_name}' at bbox ({detection.x1},{detection.y1},{detection.x2},{detection.y2})"
|
|
1818
|
+
)
|
|
1819
|
+
else:
|
|
1820
|
+
detection_lines.append("None")
|
|
1821
|
+
|
|
1822
|
+
shelf_definitions = ["\n**VALID SHELF NAMES & LOCATIONS (Ground Truth):**"]
|
|
1823
|
+
valid_shelf_names = []
|
|
1824
|
+
num_detections = len(detections)
|
|
1825
|
+
for shelf in shelf_regions:
|
|
1826
|
+
# if shelf.level in ['header', 'middle', 'bottom']:
|
|
1827
|
+
valid_shelf_names.append(f"'{shelf.level}'")
|
|
1828
|
+
shelf_definitions.append(f"- Shelf '{shelf.level}': Covers the vertical pixel range from y={shelf.bbox.y1} to y={shelf.bbox.y2}.")
|
|
1829
|
+
shelf_definitions.append(f"\n**RULE:** For the `shelf_location` field, you MUST use one of these exact names: {', '.join(valid_shelf_names)}.")
|
|
1830
|
+
|
|
1831
|
+
# REVISED: Enhanced prompt with new rules
|
|
1832
|
+
prompt = f"""
|
|
1833
|
+
You are an expert at identifying retail products in planogram displays.
|
|
1834
|
+
I have provided an image of a retail endcap, labeled reference images, and a list of {num_detections} pre-detected objects.
|
|
1835
|
+
|
|
1836
|
+
{''.join(detection_lines)}
|
|
1837
|
+
{''.join(shelf_definitions)}
|
|
1838
|
+
|
|
1839
|
+
**YOUR TASK:**
|
|
1840
|
+
For each distinct product, you must first analyze its visual features according to the guide, state your reasoning, and then provide the final identification.
|
|
1841
|
+
|
|
1842
|
+
"""
|
|
1843
|
+
partial_prompt = partial_prompt.strip().format(
|
|
1844
|
+
num_detections=num_detections,
|
|
1845
|
+
shelf_names=", ".join(valid_shelf_names)
|
|
1846
|
+
)
|
|
1847
|
+
prompt += partial_prompt
|
|
1848
|
+
prompt += f"""
|
|
1849
|
+
---
|
|
1850
|
+
|
|
1851
|
+
**JSON OUTPUT FORMAT:**
|
|
1852
|
+
Respond with a single JSON object. For each **distinct product** you identify, provide an entry in the 'detections' list.
|
|
1853
|
+
|
|
1854
|
+
- **detection_id**: The pre-detected ID number, or `null` for newly found items.
|
|
1855
|
+
- **detection_box**: **REQUIRED** if `detection_id` is `null`. An array of four numbers `[x1, y1, x2, y2]`.
|
|
1856
|
+
- **product_type**: printer, tv, product_box, fact_tag, promotional_graphic, or ink_bottle.
|
|
1857
|
+
- **product_model**: Follow naming rules above.
|
|
1858
|
+
- **confidence**: Your confidence (0.0-1.0).
|
|
1859
|
+
- **visual_features**: List of key visual features as if device is turned on, color, size, brightness or any other visual features.
|
|
1860
|
+
- **reasoning**: A brief sentence explaining your identification based on the visual guide. Example: "Reasoning: The control panel has a physical key pad, which matches the ET-3950 guide."
|
|
1861
|
+
- **reference_match**: Which reference image name matches (or "none").
|
|
1862
|
+
- **shelf_location**: {', '.join(valid_shelf_names)}.
|
|
1863
|
+
- **position_on_shelf**: 'left', 'center', or 'right'.
|
|
1864
|
+
|
|
1865
|
+
**!! FINAL CHECK !!**
|
|
1866
|
+
- Ensure your response contains **NO DUPLICATE** entries for the same physical object.
|
|
1867
|
+
- **CRITICAL**: Verify that any item with `detection_id: null` also includes a `detection_box`.
|
|
1868
|
+
|
|
1869
|
+
Analyze all provided images and return the complete JSON response.
|
|
1870
|
+
"""
|
|
1871
|
+
return prompt
|
|
1872
|
+
|
|
1873
|
+
def _calculate_visual_feature_match(self, expected_features: List[str], detected_features: List[str]) -> float:
|
|
1874
|
+
"""
|
|
1875
|
+
Enhanced visual feature matching with semantic understanding
|
|
1876
|
+
"""
|
|
1877
|
+
if not expected_features:
|
|
1878
|
+
return 1.0 # No requirements = full match
|
|
1879
|
+
|
|
1880
|
+
if not detected_features:
|
|
1881
|
+
return 0.0 # No detected features but requirements exist
|
|
1882
|
+
|
|
1883
|
+
# Normalize and create keyword sets for semantic matching
|
|
1884
|
+
def extract_keywords(text):
|
|
1885
|
+
"""Extract meaningful keywords from feature text"""
|
|
1886
|
+
text = text.lower().strip()
|
|
1887
|
+
# Remove common words that don't add meaning
|
|
1888
|
+
stop_words = {'a', 'an', 'the', 'is', 'are', 'on', 'of', 'in', 'at', 'to', 'for', 'with', 'visible', 'displayed', 'showing'}
|
|
1889
|
+
words = [w for w in text.split() if w not in stop_words and len(w) > 1]
|
|
1890
|
+
return set(words)
|
|
1891
|
+
|
|
1892
|
+
# Special semantic mappings for common concepts
|
|
1893
|
+
semantic_mappings = {
|
|
1894
|
+
'active': ['active', 'on', 'powered', 'illuminated', 'lit'],
|
|
1895
|
+
'display': ['display', 'screen', 'tv', 'television', 'monitor'],
|
|
1896
|
+
'illuminated': ['illuminated', 'backlit', 'lit', 'bright', 'glowing'],
|
|
1897
|
+
'logo': ['logo', 'text', 'branding', 'brand'],
|
|
1898
|
+
'dynamic': ['dynamic', 'colorful', 'graphics', 'content'],
|
|
1899
|
+
'official': ['official', 'partner'],
|
|
1900
|
+
'white': ['white', 'large']
|
|
1901
|
+
}
|
|
1902
|
+
|
|
1903
|
+
def semantic_match(expected_word, detected_keywords):
|
|
1904
|
+
"""Check if expected word semantically matches any detected keywords"""
|
|
1905
|
+
if expected_word in detected_keywords:
|
|
1906
|
+
return True
|
|
1907
|
+
|
|
1908
|
+
# Check semantic mappings
|
|
1909
|
+
if expected_word in semantic_mappings:
|
|
1910
|
+
synonyms = semantic_mappings[expected_word]
|
|
1911
|
+
return any(syn in detected_keywords for syn in synonyms)
|
|
1912
|
+
|
|
1913
|
+
# Check if any detected keyword contains the expected word
|
|
1914
|
+
return any(expected_word in keyword for keyword in detected_keywords)
|
|
1915
|
+
|
|
1916
|
+
matches = 0
|
|
1917
|
+
for expected in expected_features:
|
|
1918
|
+
expected_keywords = extract_keywords(expected)
|
|
1919
|
+
|
|
1920
|
+
# Combine all detected feature keywords
|
|
1921
|
+
all_detected_keywords = set()
|
|
1922
|
+
for detected in detected_features:
|
|
1923
|
+
all_detected_keywords.update(extract_keywords(detected))
|
|
1924
|
+
|
|
1925
|
+
# Check if any expected keyword has a semantic match
|
|
1926
|
+
feature_matched = False
|
|
1927
|
+
for exp_keyword in expected_keywords:
|
|
1928
|
+
if semantic_match(exp_keyword, all_detected_keywords):
|
|
1929
|
+
feature_matched = True
|
|
1930
|
+
break
|
|
1931
|
+
|
|
1932
|
+
if feature_matched:
|
|
1933
|
+
matches += 1
|
|
1934
|
+
|
|
1935
|
+
score = matches / len(expected_features)
|
|
1936
|
+
return score
|
|
1937
|
+
|
|
1938
|
+
def check_planogram_compliance(
    self,
    identified_products: List[IdentifiedProduct],
    planogram_description: PlanogramDescription,
) -> List[ComplianceResult]:
    """Check compliance of identified products against the planogram.

    For each configured shelf: canonicalize expected and found products into
    (product_type, base_model) keys, greedily match them 1:1, evaluate endcap
    text requirements, and combine product / text / brand / visual-feature
    scores into one ComplianceResult per shelf.
    """
    def _matches(ek, fk) -> bool:
        # ek/fk are (product_type, base_model) canonical keys.
        (e_ptype, e_base), (f_ptype, f_base) = ek, fk
        if e_ptype != f_ptype:
            return False
        if not e_base or not f_base:
            return True
        # If no base model specified in planogram, accept type-only match
        # NOTE(review): this branch is unreachable — the check above already
        # returned True when e_base is falsy.
        if not e_base:
            return True
        if f_base == e_base or e_base in f_base or f_base in e_base:
            return True
        # NOTE(review): duplicate of the equality test above — dead code.
        if f_base == e_base:
            return True
        # NEW: allow cross-slug promo matching if synonyms overlap
        if e_ptype == "promotional_graphic":
            fam = lambda s: "canvas-tv" if "canvas-tv" in s else s
            return fam(e_base) == fam(f_base)
        # containment: allow 'et-4950' inside 'epson et-4950 bundle' etc.
        return e_base in f_base or f_base in e_base

    results: List[ComplianceResult] = []

    # Brand check: any identified product whose brand equals the planogram brand.
    planogram_brand = planogram_description.brand.lower()
    found_brand_product = next((
        p for p in identified_products if p.brand and p.brand.lower() == planogram_brand
    ), None)

    brand = getattr(planogram_description, 'brand', planogram_brand)

    brand_compliance_result = BrandComplianceResult(
        expected_brand=planogram_description.brand,
        found_brand=found_brand_product.brand if found_brand_product else None,
        found=bool(found_brand_product),
        confidence=found_brand_product.confidence if found_brand_product else 0.0
    )
    brand_check_ok = brand_compliance_result.found
    # Bucket identified products by the shelf they were located on.
    by_shelf = defaultdict(list)

    for p in identified_products:
        by_shelf[p.shelf_location].append(p)

    for shelf_cfg in planogram_description.shelves:
        shelf_level = shelf_cfg.level
        products_on_shelf = by_shelf.get(shelf_level, [])
        expected = []
        # --- 1. Main matching loop for expected products ---
        for sp in shelf_cfg.products:
            # Tags/slots are layout metadata, not products to match.
            if sp.product_type in ("fact_tag", "price_tag", "slot"):
                continue

            e_ptype, e_base = self._canonical_expected_key(sp, brand=brand)
            expected.append((e_ptype, e_base))

        # --- Build canonical FOUND keys for this shelf (and keep refs for reporting) ---
        found_keys = []  # list[(ptype, base_model)]
        found_lookup = []  # parallel to found_keys to map back to strings for reporting
        promos = []
        for p in products_on_shelf:
            if p.product_type in ("fact_tag", "price_tag", "slot", "brand_logo"):
                continue
            f_ptype, f_base, f_conf = self._canonical_found_key(p, brand=brand)
            found_keys.append((f_ptype, f_base))
            if p.product_type == "promotional_graphic":
                promos.append(p)

            # for human-readable 'found_products' list later:
            label = p.product_model or p.product_type or "unknown"
            found_lookup.append((f_ptype, f_base, label))

        # --- Matching: (ptype must match) AND (base_model equal OR base_model contained in planogram name) ---
        matched = [False] * len(expected)
        consumed = [False] * len(found_keys)
        visual_feature_scores = []  # Track visual feature matching scores

        # Greedy 1:1 matching
        for i, ek in enumerate(expected):
            for j, fk in enumerate(found_keys):
                if matched[i] or consumed[j]:
                    continue
                if _matches(ek, fk):
                    matched[i] = True
                    consumed[j] = True

                    # ADD VISUAL FEATURE MATCHING HERE
                    # Find the corresponding ShelfProduct and IdentifiedProduct
                    # NOTE(review): `expected` and `found_keys` skip tag/slot
                    # entries, so index i/j do NOT necessarily line up with
                    # shelf_cfg.products[i] / products_on_shelf[j] when such
                    # entries exist — verify these lookups against real configs.
                    shelf_product = shelf_cfg.products[i]  # Get the shelf product config
                    identified_product = products_on_shelf[j]  # Get the identified product

                    # Calculate visual feature match score
                    if hasattr(shelf_product, 'visual_features') and shelf_product.visual_features:
                        detected_features = getattr(identified_product, 'visual_features', []) or []
                        vf_score = self._calculate_visual_feature_match(
                            shelf_product.visual_features,
                            detected_features
                        )
                        visual_feature_scores.append(vf_score)
                    break

        # Compute lists for reporting/scoring
        expected_readable = [
            f"{e_ptype}:{e_base}" if e_base else f"{e_ptype}" for (e_ptype, e_base) in expected
        ]
        found_readable = []
        for (used, (f_ptype, f_base), (_, _, original_label)) in zip(consumed, found_keys, found_lookup):
            # Keep the original label for readability but also show our canonicalization
            tag = original_label
            if f_base:
                tag = f"{original_label} [{f_ptype}:{f_base}]"
            found_readable.append(tag)

        missing = [expected_readable[i] for i, ok in enumerate(matched) if not ok]
        # If extras not allowed, mark unexpected any unconsumed found
        unexpected = []
        if not shelf_cfg.allow_extra_products:
            for used, (f_ptype, f_base), (_, _, original_label) in zip(consumed, found_keys, found_lookup):
                if not used:
                    lbl = original_label
                    if f_base:
                        lbl = f"{original_label} [{f_ptype}:{f_base}]"
                    unexpected.append(lbl)

        # Product score = fraction of expected matched
        basic_score = (sum(1 for ok in matched if ok) / (len(expected) or 1.0))

        # ADD VISUAL FEATURE SCORING
        visual_feature_score = 1.0
        if visual_feature_scores:
            visual_feature_score = sum(visual_feature_scores) / len(visual_feature_scores)

        text_results, text_score, overall_text_ok = [], 1.0, True

        endcap = planogram_description.advertisement_endcap
        if endcap and endcap.enabled and endcap.position == shelf_level:
            if endcap.text_requirements:
                # Combine visual features from all promotional items
                all_features = []
                ocr_blocks = []
                for promo in promos:
                    if getattr(promo, "visual_features", None):
                        all_features.extend(promo.visual_features)
                        for feat in promo.visual_features:
                            # "ocr:" prefixed features carry raw OCR snippets.
                            if isinstance(feat, str) and feat.startswith("ocr:"):
                                ocr_blocks.append(feat[4:].strip())
                    # if promo have ocr_text, add that too
                    ocr_text = getattr(promo.detection_box, 'ocr_text', '')
                    if ocr_text:
                        ocr_blocks.append(ocr_text.strip())

                if ocr_blocks:
                    ocr_norm = self._normalize_ocr_text(" ".join(ocr_blocks))
                    if ocr_norm:
                        all_features.append(ocr_norm)

                # If no promotional graphics found but text required, create default failure
                if not promos and shelf_level == "header":
                    self.logger.warning(
                        f"No promotional graphics found on {shelf_level} shelf but text requirements exist"
                    )
                    overall_text_ok = False
                    for text_req in endcap.text_requirements:
                        text_results.append(TextComplianceResult(
                            required_text=text_req.required_text,
                            found=False,
                            matched_features=[],
                            confidence=0.0,
                            match_type=text_req.match_type
                        ))
                else:
                    # Check text requirements against found features
                    for text_req in endcap.text_requirements:
                        result = TextMatcher.check_text_match(
                            required_text=text_req.required_text,
                            visual_features=all_features,
                            match_type=text_req.match_type,
                            case_sensitive=text_req.case_sensitive,
                            confidence_threshold=text_req.confidence_threshold
                        )
                        text_results.append(result)

                        if not result.found and text_req.mandatory:
                            overall_text_ok = False

                # Calculate text compliance score
                if text_results:
                    text_score = sum(r.confidence for r in text_results if r.found) / len(text_results)

        elif shelf_level != "header":
            # Non-endcap, non-header shelves have no text requirements at all.
            overall_text_ok = True
            text_score = 1.0

        threshold = getattr(
            shelf_cfg, "compliance_threshold", planogram_description.global_compliance_threshold or 0.8
        )

        # Minor extras (ink bottles, price tags) don't break compliance.
        major_unexpected = [
            p for p in unexpected if "ink" not in p.lower() and "price tag" not in p.lower()
        ]

        # MODIFIED: Status determination logic with brand check override
        status = ComplianceStatus.NON_COMPLIANT  # Default status
        if shelf_level != "header":
            if basic_score >= threshold and not major_unexpected:
                status = ComplianceStatus.COMPLIANT
            elif basic_score == 0.0 and len(expected) > 0:
                status = ComplianceStatus.MISSING
        else:  # Header shelf logic
            # The brand check is now a mandatory condition for compliance
            if not brand_check_ok:
                status = ComplianceStatus.NON_COMPLIANT  # OVERRIDE: Brand check failed
            elif basic_score >= threshold and not major_unexpected and overall_text_ok:
                status = ComplianceStatus.COMPLIANT
            elif basic_score == 0.0 and len(expected) > 0:
                status = ComplianceStatus.MISSING
            else:
                status = ComplianceStatus.NON_COMPLIANT

        # MODIFIED: Combined score calculation with visual features
        # Use the existing visual_features_weight from CategoryDetectionConfig
        visual_weight = getattr(
            planogram_description,
            'visual_features_weight',
            0.2
        )  # Default 20%

        if shelf_level == "header" and endcap:
            # Adjust product weight to make room for visual features
            adjusted_product_weight = endcap.product_weight * (1 - visual_weight)
            visual_feature_weight = endcap.product_weight * visual_weight
            combined_score = (
                (basic_score * adjusted_product_weight) +
                (text_score * endcap.text_weight) +
                (brand_compliance_result.confidence * getattr(endcap, "brand_weight", 0.0)) +
                (visual_feature_score * visual_feature_weight)
            )
        else:
            combined_score = (
                basic_score * (1 - visual_weight) +
                text_score * 0.1 +
                visual_feature_score * visual_weight
            )

        # Ensure score never exceeds 1.0
        combined_score = min(1.0, max(0.0, combined_score))
        text_score = min(1.0, max(0.0, text_score))

        # Prepare human-readable outputs
        expected = expected_readable
        found = found_readable
        results.append(
            ComplianceResult(
                shelf_level=shelf_level,
                expected_products=expected,
                found_products=found,
                missing_products=missing,
                unexpected_products=unexpected,
                compliance_status=status,
                compliance_score=combined_score,
                text_compliance_results=text_results,
                text_compliance_score=text_score,
                overall_text_compliant=overall_text_ok,
                brand_compliance_result=brand_compliance_result
            )
        )

    return results
|
|
2209
|
+
|
|
2210
|
+
def _base_model_from_str(self, s: str, brand: str = None) -> str:
|
|
2211
|
+
"""
|
|
2212
|
+
Extract normalized base model from any text, supporting multiple brands.
|
|
2213
|
+
|
|
2214
|
+
Args:
|
|
2215
|
+
s: String to extract model from
|
|
2216
|
+
brand: Optional brand hint to improve extraction
|
|
2217
|
+
|
|
2218
|
+
Returns:
|
|
2219
|
+
Normalized model string or empty string if no model found
|
|
2220
|
+
"""
|
|
2221
|
+
if not s:
|
|
2222
|
+
return ""
|
|
2223
|
+
|
|
2224
|
+
t = s.lower().strip()
|
|
2225
|
+
# normalize separators
|
|
2226
|
+
t = t.replace("—", "-").replace("–", "-").replace("_", "-")
|
|
2227
|
+
|
|
2228
|
+
# Brand-specific patterns
|
|
2229
|
+
if brand and brand.lower() == "epson":
|
|
2230
|
+
# EPSON EcoTank models: ET-2980, ET-3950, ET-4950
|
|
2231
|
+
m = re.search(r"(et)[- ]?(\d{4})", t)
|
|
2232
|
+
if m:
|
|
2233
|
+
return f"{m.group(1)}-{m.group(2)}"
|
|
2234
|
+
|
|
2235
|
+
elif brand and brand.lower() == "hisense":
|
|
2236
|
+
# HISENSE TV models: U6, U7, U8, plus potential series numbers
|
|
2237
|
+
# Patterns: U7, U8, U6, 55U8, U7K, etc.
|
|
2238
|
+
if re.search(r"canvas[\s-]*tv", t):
|
|
2239
|
+
return "canvas-tv"
|
|
2240
|
+
if re.search(r"canvas", t):
|
|
2241
|
+
return "canvas"
|
|
2242
|
+
patterns = [
|
|
2243
|
+
r"(\d*)(u\d+)([a-z]*)", # 55U8K, U7, U8K, etc.
|
|
2244
|
+
r"(u\d+)", # Simple U6, U7, U8
|
|
2245
|
+
]
|
|
2246
|
+
for pattern in patterns:
|
|
2247
|
+
m = re.search(pattern, t)
|
|
2248
|
+
if m:
|
|
2249
|
+
if len(m.groups()) >= 2:
|
|
2250
|
+
# Extract size + series + variant if available
|
|
2251
|
+
size = m.group(1) if m.group(1) else ""
|
|
2252
|
+
series = m.group(2)
|
|
2253
|
+
variant = m.group(3) if len(m.groups()) > 2 and m.group(3) else ""
|
|
2254
|
+
return f"{size}{series}{variant}".lower()
|
|
2255
|
+
else:
|
|
2256
|
+
return m.group(1).lower()
|
|
2257
|
+
|
|
2258
|
+
# Generic patterns for any brand
|
|
2259
|
+
generic_patterns = [
|
|
2260
|
+
# Model with dashes: ABC-1234, XYZ-567
|
|
2261
|
+
r"([a-z]+)[- ]?(\d{3,4})",
|
|
2262
|
+
# Series patterns: U7, U8, A6, etc.
|
|
2263
|
+
r"([a-z]\d+)",
|
|
2264
|
+
# Number-letter combinations: 4950, 2980 (for fallback)
|
|
2265
|
+
r"(\d{4})",
|
|
2266
|
+
]
|
|
2267
|
+
|
|
2268
|
+
for pattern in generic_patterns:
|
|
2269
|
+
m = re.search(pattern, t)
|
|
2270
|
+
if m:
|
|
2271
|
+
if len(m.groups()) >= 2:
|
|
2272
|
+
return f"{m.group(1)}-{m.group(2)}"
|
|
2273
|
+
else:
|
|
2274
|
+
return m.group(1).lower()
|
|
2275
|
+
|
|
2276
|
+
return ""
|
|
2277
|
+
|
|
2278
|
+
def _looks_like_box(self, visual_features: list[str] | None) -> bool:
|
|
2279
|
+
"""Heuristic: does the detection look like packaging?"""
|
|
2280
|
+
if not visual_features:
|
|
2281
|
+
return False
|
|
2282
|
+
keywords = {"packaging", "package", "cardboard", "box", "blue packaging", "printer image on box"}
|
|
2283
|
+
norm = " ".join(visual_features).lower()
|
|
2284
|
+
return any(k in norm for k in keywords)
|
|
2285
|
+
|
|
2286
|
+
def _canonical_expected_key(self, sp: str, brand: str) -> tuple[str, str]:
    """
    Map a planogram product spec to a canonical (product_type, base_model) pair.

    Example: name='ET-4950', product_type='product_box' -> ('product_box','et-4950')
    """
    raw_type = (sp.product_type or "").strip().lower()
    # Collapse known aliases onto canonical type names.
    aliases = {
        "tv_demonstration": "tv",
        "promotional_graphic": "promotional_graphic",
        "product_box": "product_box",
        "printer": "printer",
        "promotional_materials": "promotional_materials",
    }
    ptype = aliases.get(raw_type, raw_type)

    # Prefer the spec's display name, then its model field, for model parsing.
    name = getattr(sp, "name", "") or getattr(sp, "product_model", "") or ""
    base = self._base_model_from_str(name, brand=brand)
    return ptype or "unknown", base or ""
|
|
2304
|
+
|
|
2305
|
+
def _canonical_found_key(self, p: str, brand: str) -> tuple[str, str, float]:
    """
    Map an IdentifiedProduct to (resolved_product_type, base_model, adjusted_confidence).

    If visual features scream 'box', coerce/confirm product_type as
    'product_box' and boost confidence a bit.
    """
    raw_type = (p.product_type or "").strip().lower()
    # Collapse known aliases onto canonical type names.
    aliases = {
        "tv_demonstration": "tv",
        "promotional_graphic": "promotional_graphic",
        "product_box": "product_box",
        "printer": "printer",
        "promotional_material": "promotional_material",
        "promotional_display": "promotional_display",
    }
    ptype = aliases.get(raw_type, raw_type)

    base = self._base_model_from_str(p.product_model or p.product_type or "", brand=brand)
    conf = float(getattr(p, "confidence", 0.0) or 0.0)

    if self._looks_like_box(getattr(p, "visual_features", None)):
        if ptype != "product_box":
            ptype = "product_box"
        conf = min(1.0, conf + 0.05)  # gentle nudge for box evidence
    return ptype or "unknown", base or "", conf
|
|
2330
|
+
|
|
2331
|
+
async def _find_poster(
    self,
    image: Image.Image,
    planogram: PlanogramDescription,
    partial_prompt: str
) -> tuple:
    """
    Ask the VISION model to locate the main promotional graphic and
    related endcap elements for the given brand/tags.

    Args:
        image: Full-resolution input image.
        planogram: Planogram description supplying brand and tag hints.
        partial_prompt: Prompt template with ``{brand}``, ``{tag_hint}``
            and ``{image_size}`` placeholders.

    Returns:
        5-tuple ``(endcap_det, panel_det, brand_det, text_det, dets)``.
        Individual detections may be None; ``dets`` is the raw detection
        list. Box coordinates are normalized (clamped to [0, 1] below),
        which the caller converts to pixels.

    Raises:
        ServerError: If the vision model is still overloaded after all
            retry attempts.
    """
    brand = (getattr(planogram, "brand", "") or "").strip()
    tags = [t.strip() for t in getattr(planogram, "tags", []) or []]
    endcap = getattr(planogram, "advertisement_endcap", None)
    geometry = self.planogram_config.endcap_geometry
    # Required texts from the endcap config double as search hints.
    if endcap and getattr(endcap, "text_requirements", None):
        for tr in endcap.text_requirements:
            if getattr(tr, "required_text", None):
                tags.append(tr.required_text)
    tag_hint = ", ".join(sorted(set(f"'{t}'" for t in tags if t)))

    # downscale for LLM (cost/latency)
    image_small = self._downscale_image(image, max_side=1024, quality=78)
    prompt = partial_prompt.format(
        brand=brand,
        tag_hint=tag_hint,
        image_size=image_small.size
    )
    max_attempts = 2  # Initial attempt + 1 retry
    retry_delay_seconds = 10
    msg = None
    for attempt in range(max_attempts):
        try:
            async with self.roi_client as client:
                msg = await client.ask_to_image(
                    image=image_small,
                    prompt=prompt,
                    model="gemini-2.5-flash",
                    no_memory=True,
                    structured_output=Detections,
                    max_tokens=8192
                )
            # If the call succeeds, break out of the loop
            break
        except ServerError:
            if attempt < max_attempts - 1:
                # Consistency: use the class logger instead of print().
                self.logger.warning(
                    f"Model is overloaded. Retrying in {retry_delay_seconds} seconds... "
                    f"(Attempt {attempt + 1}/{max_attempts})"
                )
                await asyncio.sleep(retry_delay_seconds)
            else:
                self.logger.error(
                    f"Model is still overloaded after {max_attempts} attempts. Failing."
                )
                # Re-raise the exception if the last attempt fails
                raise

    # Evaluate the output. `data` may be a plain dict (or {}) when
    # structured output failed, so guard attribute access.
    data = msg.structured_output or msg.output or {}
    dets = getattr(data, "detections", None) or []
    if not dets:
        # BUGFIX: previously returned a 2-tuple here, which broke the
        # caller's 5-value unpacking.
        return None, None, None, None, []

    # Pick the poster panel: explicit label first, then the generic
    # "poster" label, finally the highest-confidence detection.
    panel_det = next(
        (d for d in dets if d.label == "poster_panel"), None) \
        or next((d for d in dets if d.label == "poster"), None) \
        or max(dets, key=lambda x: float(x.confidence))
    # poster text:
    text_det = next((d for d in dets if d.label == "poster_text"), None)
    # brand logo:
    brand_det = next((d for d in dets if d.label == "brand_logo"), None)
    if not panel_det:
        self.logger.error("Critical failure: Could not detect the poster_panel.")
        # BUGFIX: keep the 5-tuple arity expected by callers.
        return None, None, None, None, dets

    # promotional graphic (inside the panel):
    promo_graphic_det = next(
        (d for d in dets if d.label == "promotional_graphic"), None
    )

    # If the promo graphic spills horizontally outside the panel,
    # expand the panel so it contains the graphic.
    if promo_graphic_det and panel_det:
        if not (
            promo_graphic_det.bbox.x1 >= panel_det.bbox.x1 and
            promo_graphic_det.bbox.x2 <= panel_det.bbox.x2
        ):
            self.logger.info("Expanding poster_panel to include promotional_graphic.")
            panel_det.bbox.x1 = min(panel_det.bbox.x1, promo_graphic_det.bbox.x1)
            panel_det.bbox.x2 = max(panel_det.bbox.x2, promo_graphic_det.bbox.x2)

    # Endcap geometry parameters (normalized to the image, not the ROI).
    config_height_percent = geometry.height_margin_percent
    config_top_margin_percent = geometry.top_margin_percent
    side_margin_percent = geometry.side_margin_percent

    # Pad the panel horizontally so it captures the full visual area.
    panel_det.bbox.x1 = max(0.0, panel_det.bbox.x1 - side_margin_percent)
    panel_det.bbox.x2 = min(1.0, panel_det.bbox.x2 + side_margin_percent)

    # Extend the panel bottom slightly past the detected poster text.
    if panel_det and text_det:
        padding = 0.08
        panel_det.bbox.y2 = min(text_det.bbox.y2 + padding, 1.0)

    # --- endcap detected by the model (may be absent):
    endcap_det = next((d for d in dets if d.label == "endcap"), None)

    # panel coordinates
    px1, py1, px2, py2 = panel_det.bbox.x1, panel_det.bbox.y1, panel_det.bbox.x2, panel_det.bbox.y2

    # Seed the endcap box from the model detection, else from the panel.
    if endcap_det:
        ex1, ey1, ex2, ey2 = endcap_det.bbox.x1, endcap_det.bbox.y1, endcap_det.bbox.x2, endcap_det.bbox.y2
    else:
        ex1, ey1, ex2, ey2 = px1, py1, px2, py2

    if endcap_det is None:
        # Synthesize the endcap vertical extent from the panel height
        # and the configured panel/endcap height ratio.
        panel_h = py2 - py1
        ratio = max(1e-6, float(config_height_percent))
        top_margin = float(config_top_margin_percent)
        ey1 = max(0.0, py1 - top_margin)
        ey2 = min(1.0, ey1 + panel_h / ratio)

    # Widen the endcap by the configured side-margin ratios.
    x_buffer = max(self.left_margin_ratio * (px2 - px1), self.right_margin_ratio * (px2 - px1))
    ex1 = min(ex1, px1 - x_buffer)
    ex2 = max(ex2, px2 + x_buffer)

    # Clamp to [0, 1] and keep coordinates strictly increasing.
    ex1 = max(0.0, ex1)
    ex2 = min(1.0, ex2)
    if ex2 <= ex1:
        ex2 = ex1 + 1e-6
    ey1 = max(0.0, ey1)
    ey2 = min(1.0, ey2)
    if ey2 <= ey1:
        ey2 = ey1 + 1e-6

    # Materialize or update the endcap detection with corrected values.
    if endcap_det is None:
        endcap_det = DetectionBox(
            x1=ex1, y1=ey1, x2=ex2, y2=ey2,
            confidence=0.9,  # default confidence for a synthesized box
            label="endcap"
        )
    else:
        endcap_det.bbox.x1 = ex1
        endcap_det.bbox.x2 = ex2
        endcap_det.bbox.y1 = ey1
        endcap_det.bbox.y2 = ey2

    return endcap_det, panel_det, brand_det, text_det, dets
|
|
2496
|
+
|
|
2497
|
+
# Complete Pipeline
|
|
2498
|
+
# Complete Pipeline
async def run(
    self,
    image: Union[str, Path, Image.Image],
    debug_raw="/tmp/data/yolo_raw_debug.png",
    return_overlay: Optional[str] = None,  # "identified" | "detections" | "both" | None
    overlay_save_path: Optional[Union[str, Path]] = None,
) -> Dict[str, Any]:
    """
    Run the complete 3-step planogram compliance pipeline.

    Steps: (1) ROI / poster detection, (2) object identification with
    the LLM and reference images, (3) compliance check against the
    planogram description.

    Args:
        image: Input image (path or PIL image).
        debug_raw: Base path for debug renderings of raw detections.
        return_overlay: Which overlay to render, if any.
        overlay_save_path: Where to save the rendered overlay.

    Returns:
        Complete analysis results including all steps, or None when the
        required ROI detections could not be obtained.
    """
    self.logger.debug("Step 1: Find Region of Interest...")
    # Optimize image for classification:
    img = self.open_image(image)

    # ROI detection works on the numpy view of the image:
    img_array = np.array(img)  # RGB

    # 1) Find the poster:
    planogram_description = self.planogram_config.get_planogram_description()
    endcap, ad, brand, panel_text, dets = await self._find_poster(
        img,
        planogram_description,
        partial_prompt=self.planogram_config.roi_detection_prompt
    )
    # BUGFIX: validate the detections BEFORE dereferencing `ad.bbox`
    # below — previously `ad.bbox.get_coordinates()` ran first and
    # raised AttributeError when `ad` was None and an overlay was
    # requested. Also use the class logger instead of print().
    if not endcap or not ad:
        self.logger.error("Failed to get required detections.")
        return None

    if return_overlay in ('detections', 'both'):
        debug_poster_path = debug_raw.replace(".png", "_poster_debug.png") if debug_raw else None
        panel_px = ad.bbox.get_coordinates()
        self._save_detections(
            image, panel_px, dets, debug_poster_path
        )

    # Locate shelves and objects within the validated endcap/panel:
    shelf_regions, detections = await self.detect_objects_and_shelves(
        image,
        img_array,
        endcap=endcap,
        ad=ad,
        brand=brand,
        panel_text=panel_text,
        planogram_description=planogram_description
    )

    self.logger.debug(
        f"Found {len(detections)} objects in {len(shelf_regions)} shelf regions"
    )

    self.logger.notice("Step 2: Identifying objects with LLM...")
    identified_products = await self.identify_objects_with_references(
        image,
        detections,
        shelf_regions,
        self.reference_images,
        prompt=self.planogram_config.object_identification_prompt
    )

    self.logger.debug(
        f"Identified Products: {identified_products}"
    )

    compliance_results = self.check_planogram_compliance(
        identified_products, planogram_description
    )

    # Overall compliance: average of per-shelf scores (0.0 when empty).
    total_score = sum(
        r.compliance_score for r in compliance_results
    ) / len(compliance_results) if compliance_results else 0.0
    if total_score >= (planogram_description.global_compliance_threshold or 0.8):
        overall_compliant = True
    else:
        # Below the global threshold, require every shelf to be compliant.
        overall_compliant = all(
            r.compliance_status == ComplianceStatus.COMPLIANT for r in compliance_results
        )
    overlay_image = None
    overlay_path = None
    if return_overlay in ('identified', 'both'):
        try:
            overlay_image = self.render_evaluated_image(
                image,
                shelf_regions=shelf_regions,
                detections=detections,
                identified_products=identified_products,
                mode=return_overlay,
                show_shelves=True,
                save_to=overlay_save_path,
            )
            if overlay_save_path:
                overlay_path = str(Path(overlay_save_path))
        except Exception as e:
            self.logger.error(f"Failed to render overlay image: {e}")
            # Overlay rendering is best-effort; don't fail the pipeline.
            overlay_image = None
            overlay_path = None

    return {
        "step1_detections": detections,
        "step1_shelf_regions": shelf_regions,
        "step2_identified_products": identified_products,
        "step3_compliance_results": compliance_results,
        "overall_compliance_score": total_score,
        "overall_compliant": overall_compliant,
        "analysis_timestamp": datetime.now(),
        "overlay_image": overlay_image,
        "overlay_path": overlay_path,
    }
|
|
2610
|
+
|
|
2611
|
+
def render_evaluated_image(
    self,
    image: Union[str, Path, Image.Image],
    *,
    shelf_regions: Optional[List[ShelfRegion]] = None,
    detections: Optional[List[DetectionBox]] = None,
    identified_products: Optional[List[IdentifiedProduct]] = None,
    mode: str = "identified",
    show_shelves: bool = True,
    save_to: Optional[Union[str, Path]] = None,
) -> Image.Image:
    """
    Render an annotated copy of the image with shelves, raw detections
    and/or identified products drawn on top, with safe coordinate
    handling (boxes are clipped to the image and normalized to a valid
    rectangle before drawing).

    Args:
        image: Source image (path or PIL image); a copy is annotated.
        shelf_regions: Shelf regions to outline in yellow.
            NOTE(review): coordinates appear to be pixels here —
            confirm against the producer of ShelfRegion.
        detections: Raw detections, drawn thin/red when `mode` is
            "detections" or "both".
        identified_products: Identified products, drawn thick with a
            per-type color when `mode` is "identified" or "both".
        mode: "identified" | "detections" | "both".
        show_shelves: Whether shelf regions are drawn at all.
        save_to: Optional path to also save the rendered image to
            (parent directories are created).

    Returns:
        The annotated PIL image (RGB).
    """
    def _norm_box(x1, y1, x2, y2):
        """Normalize box coordinates to ensure valid rectangle"""
        x1, x2 = int(x1), int(x2)
        y1, y2 = int(y1), int(y2)

        # Ensure coordinates are in correct order
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1

        # Ensure minimum size (at least 1px in each dimension)
        if x2 - x1 < 1:
            x2 = x1 + 1
        if y2 - y1 < 1:
            y2 = y1 + 1

        return x1, y1, x2, y2

    # Get base image (copy so the caller's image is never mutated)
    if isinstance(image, (str, Path)):
        base = Image.open(image).convert("RGB").copy()
    else:
        base = image.convert("RGB").copy()

    draw = ImageDraw.Draw(base)
    try:
        font = ImageFont.load_default()
    except Exception:
        # Without a font, labels fall back to the default text call.
        font = None

    W, H = base.size

    def _clip(x1, y1, x2, y2):
        """Clip coordinates to image bounds"""
        return max(0, x1), max(0, y1), min(W-1, x2), min(H-1, y2)

    def _txt(draw_obj, xy, text, fill, bg=None):
        """Safe text drawing with error handling"""
        try:
            if not font:
                draw_obj.text(xy, text, fill=fill)
                return
            # Optionally paint a solid background behind the label for
            # readability on busy imagery.
            bbox = draw_obj.textbbox(xy, text, font=font)
            if bg is not None:
                draw_obj.rectangle(bbox, fill=bg)
            draw_obj.text(xy, text, fill=fill, font=font)
        except Exception:
            # Fallback to simple text if there's any error
            try:
                draw_obj.text(xy, text, fill=fill)
            except Exception:
                pass  # Skip this text if it still fails

    # Colors per product type
    colors = {
        "tv_demonstration": (0, 255, 0),  # green for TVs
        "promotional_graphic": (255, 0, 255),  # magenta for logos
        "promotional_base": (0, 0, 255),  # blue for partner branding
        "fact_tag": (255, 255, 0),  # yellow for info displays
        "product_box": (255, 128, 0),  # orange
        "printer": (255, 0, 0),  # red
        "unknown": (200, 200, 200),  # gray
    }

    # Draw shelves (yellow outline + "SHELF <level>" label)
    if show_shelves and shelf_regions:
        for sr in shelf_regions:
            try:
                x1, y1, x2, y2 = _clip(sr.bbox.x1, sr.bbox.y1, sr.bbox.x2, sr.bbox.y2)
                x1, y1, x2, y2 = _norm_box(x1, y1, x2, y2)
                draw.rectangle([x1, y1, x2, y2], outline=(255, 255, 0), width=3)
                _txt(draw, (x1+3, max(0, y1-14)), f"SHELF {sr.level}", fill=(0, 0, 0), bg=(255, 255, 0))
            except Exception as e:
                # Best-effort rendering: one bad shelf must not abort the overlay.
                print(f"Warning: Could not draw shelf {sr.level}: {e}")

    # Draw raw detections (thin red outlines)
    if mode in ("detections", "both") and detections:
        for i, d in enumerate(detections, start=1):
            try:
                x1, y1, x2, y2 = _clip(d.x1, d.y1, d.x2, d.y2)
                x1, y1, x2, y2 = _norm_box(x1, y1, x2, y2)
                draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), width=2)
                lbl = f"ID:{i} {d.class_name} {d.confidence:.2f}"
                _txt(draw, (x1+2, max(0, y1-12)), lbl, fill=(0, 0, 0), bg=(255, 0, 0))
            except Exception as e:
                print(f"Warning: Could not draw detection {i}: {e}")

    # Draw identified products (thick, type-colored outlines).
    # Largest boxes are drawn first so smaller ones stay visible on top.
    if mode in ("identified", "both") and identified_products:
        for p in sorted(identified_products, key=lambda x: (x.detection_box.area if x.detection_box else 0), reverse=True):
            if not p.detection_box:
                continue
            try:
                x1, y1, x2, y2 = _clip(p.detection_box.x1, p.detection_box.y1, p.detection_box.x2, p.detection_box.y2)
                x1, y1, x2, y2 = _norm_box(x1, y1, x2, y2)

                c = colors.get(p.product_type, (255, 0, 255))
                draw.rectangle([x1, y1, x2, y2], outline=c, width=5)

                # Label: "#<id> <type> <model> (<confidence>)"
                pid = p.detection_id if p.detection_id is not None else "NEW"
                mm = f" {p.product_model}" if p.product_model else ""
                lab = f"#{pid} {p.product_type}{mm} ({p.confidence:.2f})"
                _txt(draw, (x1+3, max(0, y1-14)), lab, fill=(0, 0, 0), bg=c)

            except Exception as e:
                print(f"Warning: Could not draw product {p.product_model}: {e}")

    # Add a small color legend in the top-left corner
    legend_y = 8
    for key in ("tv_demonstration", "promotional_graphic", "promotional_base", "fact_tag"):
        if key in colors:
            try:
                c = colors[key]
                draw.rectangle([8, legend_y, 28, legend_y+10], fill=c)
                _txt(draw, (34, legend_y-2), key, fill=(255,255,255))
                legend_y += 14
            except Exception:
                pass

    # Save if requested (best-effort; failure does not discard the image)
    if save_to:
        try:
            save_to = Path(save_to)
            save_to.parent.mkdir(parents=True, exist_ok=True)
            base.save(save_to, quality=90)
            print(f"Overlay saved to: {save_to}")
        except Exception as e:
            print(f"Warning: Could not save overlay: {e}")

    return base
|
|
2757
|
+
|
|
2758
|
+
def generate_compliance_json(self, results: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a comprehensive, JSON-serializable report from pipeline results.

    Args:
        results: Complete results object from pipeline.run()

    Returns:
        Dictionary containing comprehensive compliance report
    """
    shelf_results = results['step3_compliance_results']

    def _as_dict(res) -> Dict[str, Any]:
        """Flatten a single ComplianceResult into plain-dict form."""
        text_block = {
            "score": round(res.text_compliance_score, 3),
            "overall_compliant": res.overall_text_compliant,
            "requirements": [
                {
                    "required_text": tr.required_text,
                    "found": tr.found,
                    "confidence": round(tr.confidence, 3),
                    "match_type": tr.match_type,
                    "matched_features": tr.matched_features,
                }
                for tr in res.text_compliance_results
            ],
        }
        out = {
            "shelf_level": res.shelf_level,
            "compliance_status": res.compliance_status.value,
            "compliance_score": round(res.compliance_score, 3),
            "expected_products": res.expected_products,
            "found_products": res.found_products,
            "missing_products": res.missing_products,
            "unexpected_products": res.unexpected_products,
            "text_compliance": text_block,
        }
        # Brand compliance is optional on the result object.
        brand = getattr(res, 'brand_compliance_result', None)
        if brand:
            out["brand_compliance"] = {
                "expected_brand": brand.expected_brand,
                "found_brand": brand.found_brand,
                "found": brand.found,
                "confidence": round(brand.confidence, 3),
            }
        return out

    shelf_count = len(shelf_results)

    def _count(status: str) -> int:
        """Count shelves currently in the given compliance status."""
        return sum(1 for r in shelf_results if r.compliance_status.value == status)

    report = {
        "metadata": {
            "analysis_timestamp": results['analysis_timestamp'].isoformat(),
            "report_version": "1.0",
            "total_shelves_analyzed": shelf_count,
        },
        "overall_compliance": {
            "compliant": results['overall_compliant'],
            "score": round(results['overall_compliance_score'], 3),
            "percentage": f"{results['overall_compliance_score']:.1%}",
        },
        "shelf_results": [_as_dict(r) for r in shelf_results],
        "summary": {
            "compliant_shelves": _count("compliant"),
            "non_compliant_shelves": _count("non_compliant"),
            "missing_shelves": _count("missing"),
            "average_shelf_score": round(sum(r.compliance_score for r in shelf_results) / shelf_count, 3) if shelf_results else 0.0,
        },
    }

    # Attach the overlay artifact only when one was actually produced.
    if results.get('overlay_path'):
        report["artifacts"] = {
            "overlay_image_path": str(results['overlay_path'])
        }

    return report
|
|
2837
|
+
|
|
2838
|
+
def generate_compliance_markdown(
    self,
    results: Dict[str, Any],
    brand_name: Optional[str] = None,
    additional_notes: Optional[str] = None
) -> str:
    """
    Generate comprehensive Markdown report from pipeline results.

    Args:
        results: Complete results object from pipeline.run()
        brand_name: Brand being analyzed (optional)
        additional_notes: Additional notes to include (optional)

    Returns:
        Formatted Markdown string
    """
    compliance_results = results['step3_compliance_results']
    overall_compliance_score = results['overall_compliance_score']
    overall_compliant = results['overall_compliant']
    analysis_timestamp = results['analysis_timestamp']
    overlay_path = results.get('overlay_path')

    def status_emoji(status: str) -> str:
        """Get emoji for compliance status."""
        status_map = {
            "compliant": "✅",
            "non_compliant": "❌",
            "missing": "⚠️",
            "misplaced": "🔄"
        }
        return status_map.get(status, "❓")

    def format_percentage(score: float) -> str:
        """Format score as percentage."""
        return f"{score:.1%}"

    # Build the markdown line by line; joined once at the end.
    lines = []

    # Header
    brand_title = f" - {brand_name}" if brand_name else ""
    lines.append(f"# Planogram Compliance Report{brand_title}")
    lines.append("")
    lines.append(
        f"**Analysis Date:** {analysis_timestamp.strftime('%Y-%m-%d %H:%M:%S')}"
    )
    lines.append("")

    # Overall Compliance Section
    overall_emoji = "✅" if overall_compliant else "❌"
    lines.append("## Overall Compliance")
    lines.append("")
    lines.append(f"**Status:** {overall_emoji} {'COMPLIANT' if overall_compliant else 'NON-COMPLIANT'}")
    lines.append(f"**Score:** {format_percentage(overall_compliance_score)}")
    lines.append("")

    # Summary Statistics
    compliant_count = sum(1 for r in compliance_results if r.compliance_status.value == "compliant")
    total_count = len(compliance_results)

    lines.append("## Summary")
    lines.append("")
    lines.append(f"- **Total Shelves:** {total_count}")
    lines.append(f"- **Compliant Shelves:** {compliant_count}/{total_count}")
    lines.append(f"- **Non-Compliant Shelves:** {total_count - compliant_count}/{total_count}")

    if compliance_results:
        avg_score = sum(r.compliance_score for r in compliance_results) / len(compliance_results)
        lines.append(f"- **Average Shelf Score:** {format_percentage(avg_score)}")
        lines.append("")

    # Detailed Shelf Results
    lines.append("## Detailed Results by Shelf")
    lines.append("")

    for result in compliance_results:
        shelf_emoji = status_emoji(result.compliance_status.value)
        lines.append(f"### {result.shelf_level.upper().replace('_', ' ')}")
        lines.append("")
        lines.append(f"**Status:** {shelf_emoji} {result.compliance_status.value.upper()}")
        lines.append(f"**Score:** {format_percentage(result.compliance_score)}")
        lines.append("")

        # Products
        lines.append("**Expected Products:**")
        for product in result.expected_products:
            lines.append(f"- {product}")
        lines.append("")

        lines.append("**Found Products:**")
        if result.found_products:
            for product in result.found_products:
                lines.append(f"- {product}")
        else:
            lines.append("- *(None)*")
        lines.append("")

        # Missing/Unexpected
        if result.missing_products:
            lines.append("**Missing Products:**")
            for product in result.missing_products:
                lines.append(f"- ❌ {product}")
            lines.append("")

        if result.unexpected_products:
            lines.append("**Unexpected Products:**")
            for product in result.unexpected_products:
                lines.append(f"- ⚠️ {product}")
            lines.append("")

        # Text Compliance
        if result.text_compliance_results:
            text_emoji = "✅" if result.overall_text_compliant else "❌"
            lines.append(f"**Text Compliance:** {text_emoji} {format_percentage(result.text_compliance_score)}")
            lines.append("")

            for text_result in result.text_compliance_results:
                req_emoji = "✅" if text_result.found else "❌"
                lines.append(f"- {req_emoji} '{text_result.required_text}' (confidence: {text_result.confidence:.2f})")
                if text_result.matched_features:
                    lines.append(f"  - Matched: {', '.join(text_result.matched_features)}")
            lines.append("")

        # Brand Compliance - only show on promotional graphic shelves
        if (hasattr(result, 'brand_compliance_result') and
                result.brand_compliance_result and
                'promotional_graphic' in str(result.expected_products).lower()):
            brand_emoji = "✅" if result.brand_compliance_result.found else "❌"
            lines.append(f"**Brand Compliance:** {brand_emoji}")
            lines.append(f"- Expected: {result.brand_compliance_result.expected_brand}")
            if result.brand_compliance_result.found_brand:
                lines.append(f"- Found: {result.brand_compliance_result.found_brand}")
                lines.append(f"- Confidence: {result.brand_compliance_result.confidence:.2f}")
            else:
                lines.append("- Found: *(None)*")
            lines.append("")

        lines.append("---")
        lines.append("")

    # Artifacts Section
    if overlay_path:
        lines.append("## Analysis Artifacts")
        lines.append("")
        lines.append(f"**Overlay Image:** `{overlay_path}`")
        lines.append("")

        # Embed the image when the path is web-accessible.
        # BUGFIX: this previously appended an empty f-string, so the
        # overlay image was never embedded; reconstructed as a markdown
        # image (the original embed text was presumably lost — confirm
        # the intended alt text against the report consumers).
        if str(overlay_path).startswith(('http://', 'https://')):
            lines.append(f"![Compliance Overlay]({overlay_path})")
            lines.append("")

    # Additional Notes
    if additional_notes:
        lines.append("## Additional Notes")
        lines.append("")
        lines.append(additional_notes)
        lines.append("")

    # Footer
    lines.append("---")
    lines.append("*Report generated by AI-Parrot Planogram Compliance Pipeline*")

    return '\n'.join(lines)
|