crca-1.4.0-py3-none-any.whl → crca-1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CRCA.py +172 -7
- MODEL_CARD.md +53 -0
- PKG-INFO +8 -2
- RELEASE_NOTES.md +17 -0
- STABILITY.md +19 -0
- architecture/hybrid/consistency_engine.py +362 -0
- architecture/hybrid/conversation_manager.py +421 -0
- architecture/hybrid/explanation_generator.py +452 -0
- architecture/hybrid/few_shot_learner.py +533 -0
- architecture/hybrid/graph_compressor.py +286 -0
- architecture/hybrid/hybrid_agent.py +4398 -0
- architecture/hybrid/language_compiler.py +623 -0
- architecture/hybrid/main,py +0 -0
- architecture/hybrid/reasoning_tracker.py +322 -0
- architecture/hybrid/self_verifier.py +524 -0
- architecture/hybrid/task_decomposer.py +567 -0
- architecture/hybrid/text_corrector.py +341 -0
- benchmark_results/crca_core_benchmarks.json +178 -0
- branches/crca_sd/crca_sd_realtime.py +6 -2
- branches/general_agent/__init__.py +102 -0
- branches/general_agent/general_agent.py +1400 -0
- branches/general_agent/personality.py +169 -0
- branches/general_agent/utils/__init__.py +19 -0
- branches/general_agent/utils/prompt_builder.py +170 -0
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/METADATA +8 -2
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/RECORD +303 -20
- crca_core/__init__.py +35 -0
- crca_core/benchmarks/__init__.py +14 -0
- crca_core/benchmarks/synthetic_scm.py +103 -0
- crca_core/core/__init__.py +23 -0
- crca_core/core/api.py +120 -0
- crca_core/core/estimate.py +208 -0
- crca_core/core/godclass.py +72 -0
- crca_core/core/intervention_design.py +174 -0
- crca_core/core/lifecycle.py +48 -0
- crca_core/discovery/__init__.py +9 -0
- crca_core/discovery/tabular.py +193 -0
- crca_core/identify/__init__.py +171 -0
- crca_core/identify/backdoor.py +39 -0
- crca_core/identify/frontdoor.py +48 -0
- crca_core/identify/graph.py +106 -0
- crca_core/identify/id_algorithm.py +43 -0
- crca_core/identify/iv.py +48 -0
- crca_core/models/__init__.py +67 -0
- crca_core/models/provenance.py +56 -0
- crca_core/models/refusal.py +39 -0
- crca_core/models/result.py +83 -0
- crca_core/models/spec.py +151 -0
- crca_core/models/validation.py +68 -0
- crca_core/scm/__init__.py +9 -0
- crca_core/scm/linear_gaussian.py +198 -0
- crca_core/timeseries/__init__.py +6 -0
- crca_core/timeseries/pcmci.py +181 -0
- crca_llm/__init__.py +12 -0
- crca_llm/client.py +85 -0
- crca_llm/coauthor.py +118 -0
- crca_llm/orchestrator.py +289 -0
- crca_llm/types.py +21 -0
- crca_reasoning/__init__.py +16 -0
- crca_reasoning/critique.py +54 -0
- crca_reasoning/godclass.py +206 -0
- crca_reasoning/memory.py +24 -0
- crca_reasoning/rationale.py +10 -0
- crca_reasoning/react_controller.py +81 -0
- crca_reasoning/tool_router.py +97 -0
- crca_reasoning/types.py +40 -0
- crca_sd/__init__.py +15 -0
- crca_sd/crca_sd_core.py +2 -0
- crca_sd/crca_sd_governance.py +2 -0
- crca_sd/crca_sd_mpc.py +2 -0
- crca_sd/crca_sd_realtime.py +2 -0
- crca_sd/crca_sd_tui.py +2 -0
- cuda-keyring_1.1-1_all.deb +0 -0
- cuda-keyring_1.1-1_all.deb.1 +0 -0
- docs/IMAGE_ANNOTATION_USAGE.md +539 -0
- docs/INSTALL_DEEPSPEED.md +125 -0
- docs/api/branches/crca-cg.md +19 -0
- docs/api/branches/crca-q.md +27 -0
- docs/api/branches/crca-sd.md +37 -0
- docs/api/branches/general-agent.md +24 -0
- docs/api/branches/overview.md +19 -0
- docs/api/crca/agent-methods.md +62 -0
- docs/api/crca/operations.md +79 -0
- docs/api/crca/overview.md +32 -0
- docs/api/image-annotation/engine.md +52 -0
- docs/api/image-annotation/overview.md +17 -0
- docs/api/schemas/annotation.md +34 -0
- docs/api/schemas/core-schemas.md +82 -0
- docs/api/schemas/overview.md +32 -0
- docs/api/schemas/policy.md +30 -0
- docs/api/utils/conversation.md +22 -0
- docs/api/utils/graph-reasoner.md +32 -0
- docs/api/utils/overview.md +21 -0
- docs/api/utils/router.md +19 -0
- docs/api/utils/utilities.md +97 -0
- docs/architecture/causal-graphs.md +41 -0
- docs/architecture/data-flow.md +29 -0
- docs/architecture/design-principles.md +33 -0
- docs/architecture/hybrid-agent/components.md +38 -0
- docs/architecture/hybrid-agent/consistency.md +26 -0
- docs/architecture/hybrid-agent/overview.md +44 -0
- docs/architecture/hybrid-agent/reasoning.md +22 -0
- docs/architecture/llm-integration.md +26 -0
- docs/architecture/modular-structure.md +37 -0
- docs/architecture/overview.md +69 -0
- docs/architecture/policy-engine-arch.md +29 -0
- docs/branches/crca-cg/corposwarm.md +39 -0
- docs/branches/crca-cg/esg-scoring.md +30 -0
- docs/branches/crca-cg/multi-agent.md +35 -0
- docs/branches/crca-cg/overview.md +40 -0
- docs/branches/crca-q/alternative-data.md +55 -0
- docs/branches/crca-q/architecture.md +71 -0
- docs/branches/crca-q/backtesting.md +45 -0
- docs/branches/crca-q/causal-engine.md +33 -0
- docs/branches/crca-q/execution.md +39 -0
- docs/branches/crca-q/market-data.md +60 -0
- docs/branches/crca-q/overview.md +58 -0
- docs/branches/crca-q/philosophy.md +60 -0
- docs/branches/crca-q/portfolio-optimization.md +66 -0
- docs/branches/crca-q/risk-management.md +102 -0
- docs/branches/crca-q/setup.md +65 -0
- docs/branches/crca-q/signal-generation.md +61 -0
- docs/branches/crca-q/signal-validation.md +43 -0
- docs/branches/crca-sd/core.md +84 -0
- docs/branches/crca-sd/governance.md +53 -0
- docs/branches/crca-sd/mpc-solver.md +65 -0
- docs/branches/crca-sd/overview.md +59 -0
- docs/branches/crca-sd/realtime.md +28 -0
- docs/branches/crca-sd/tui.md +20 -0
- docs/branches/general-agent/overview.md +37 -0
- docs/branches/general-agent/personality.md +36 -0
- docs/branches/general-agent/prompt-builder.md +30 -0
- docs/changelog/index.md +79 -0
- docs/contributing/code-style.md +69 -0
- docs/contributing/documentation.md +43 -0
- docs/contributing/overview.md +29 -0
- docs/contributing/testing.md +29 -0
- docs/core/crcagent/async-operations.md +65 -0
- docs/core/crcagent/automatic-extraction.md +107 -0
- docs/core/crcagent/batch-prediction.md +80 -0
- docs/core/crcagent/bayesian-inference.md +60 -0
- docs/core/crcagent/causal-graph.md +92 -0
- docs/core/crcagent/counterfactuals.md +96 -0
- docs/core/crcagent/deterministic-simulation.md +78 -0
- docs/core/crcagent/dual-mode-operation.md +82 -0
- docs/core/crcagent/initialization.md +88 -0
- docs/core/crcagent/optimization.md +65 -0
- docs/core/crcagent/overview.md +63 -0
- docs/core/crcagent/time-series.md +57 -0
- docs/core/schemas/annotation.md +30 -0
- docs/core/schemas/core-schemas.md +82 -0
- docs/core/schemas/overview.md +30 -0
- docs/core/schemas/policy.md +41 -0
- docs/core/templates/base-agent.md +31 -0
- docs/core/templates/feature-mixins.md +31 -0
- docs/core/templates/overview.md +29 -0
- docs/core/templates/templates-guide.md +75 -0
- docs/core/tools/mcp-client.md +34 -0
- docs/core/tools/overview.md +24 -0
- docs/core/utils/conversation.md +27 -0
- docs/core/utils/graph-reasoner.md +29 -0
- docs/core/utils/overview.md +27 -0
- docs/core/utils/router.md +27 -0
- docs/core/utils/utilities.md +97 -0
- docs/css/custom.css +84 -0
- docs/examples/basic-usage.md +57 -0
- docs/examples/general-agent/general-agent-examples.md +50 -0
- docs/examples/hybrid-agent/hybrid-agent-examples.md +56 -0
- docs/examples/image-annotation/image-annotation-examples.md +54 -0
- docs/examples/integration/integration-examples.md +58 -0
- docs/examples/overview.md +37 -0
- docs/examples/trading/trading-examples.md +46 -0
- docs/features/causal-reasoning/advanced-topics.md +101 -0
- docs/features/causal-reasoning/counterfactuals.md +43 -0
- docs/features/causal-reasoning/do-calculus.md +50 -0
- docs/features/causal-reasoning/overview.md +47 -0
- docs/features/causal-reasoning/structural-models.md +52 -0
- docs/features/hybrid-agent/advanced-components.md +55 -0
- docs/features/hybrid-agent/core-components.md +64 -0
- docs/features/hybrid-agent/overview.md +34 -0
- docs/features/image-annotation/engine.md +82 -0
- docs/features/image-annotation/features.md +113 -0
- docs/features/image-annotation/integration.md +75 -0
- docs/features/image-annotation/overview.md +53 -0
- docs/features/image-annotation/quickstart.md +73 -0
- docs/features/policy-engine/doctrine-ledger.md +105 -0
- docs/features/policy-engine/monitoring.md +44 -0
- docs/features/policy-engine/mpc-control.md +89 -0
- docs/features/policy-engine/overview.md +46 -0
- docs/getting-started/configuration.md +225 -0
- docs/getting-started/first-agent.md +164 -0
- docs/getting-started/installation.md +144 -0
- docs/getting-started/quickstart.md +137 -0
- docs/index.md +118 -0
- docs/js/mathjax.js +13 -0
- docs/lrm/discovery_proof_notes.md +25 -0
- docs/lrm/finetune_full.md +83 -0
- docs/lrm/math_appendix.md +120 -0
- docs/lrm/overview.md +32 -0
- docs/mkdocs.yml +238 -0
- docs/stylesheets/extra.css +21 -0
- docs_generated/crca_core/CounterfactualResult.md +12 -0
- docs_generated/crca_core/DiscoveryHypothesisResult.md +13 -0
- docs_generated/crca_core/DraftSpec.md +13 -0
- docs_generated/crca_core/EstimateResult.md +13 -0
- docs_generated/crca_core/IdentificationResult.md +17 -0
- docs_generated/crca_core/InterventionDesignResult.md +12 -0
- docs_generated/crca_core/LockedSpec.md +15 -0
- docs_generated/crca_core/RefusalResult.md +12 -0
- docs_generated/crca_core/ValidationReport.md +9 -0
- docs_generated/crca_core/index.md +13 -0
- examples/general_agent_example.py +277 -0
- examples/general_agent_quickstart.py +202 -0
- examples/general_agent_simple.py +92 -0
- examples/hybrid_agent_auto_extraction.py +84 -0
- examples/hybrid_agent_dictionary_demo.py +104 -0
- examples/hybrid_agent_enhanced.py +179 -0
- examples/hybrid_agent_general_knowledge.py +107 -0
- examples/image_annotation_quickstart.py +328 -0
- examples/test_hybrid_fixes.py +77 -0
- image_annotation/__init__.py +27 -0
- image_annotation/annotation_engine.py +2593 -0
- install_cuda_wsl2.sh +59 -0
- install_deepspeed.sh +56 -0
- install_deepspeed_simple.sh +87 -0
- mkdocs.yml +252 -0
- ollama/Modelfile +8 -0
- prompts/__init__.py +2 -1
- prompts/default_crca.py +9 -1
- prompts/general_agent.py +227 -0
- prompts/image_annotation.py +56 -0
- pyproject.toml +17 -2
- requirements-docs.txt +10 -0
- requirements.txt +21 -2
- schemas/__init__.py +26 -1
- schemas/annotation.py +222 -0
- schemas/conversation.py +193 -0
- schemas/hybrid.py +211 -0
- schemas/reasoning.py +276 -0
- schemas_export/crca_core/CounterfactualResult.schema.json +108 -0
- schemas_export/crca_core/DiscoveryHypothesisResult.schema.json +113 -0
- schemas_export/crca_core/DraftSpec.schema.json +635 -0
- schemas_export/crca_core/EstimateResult.schema.json +113 -0
- schemas_export/crca_core/IdentificationResult.schema.json +145 -0
- schemas_export/crca_core/InterventionDesignResult.schema.json +111 -0
- schemas_export/crca_core/LockedSpec.schema.json +646 -0
- schemas_export/crca_core/RefusalResult.schema.json +90 -0
- schemas_export/crca_core/ValidationReport.schema.json +62 -0
- scripts/build_lrm_dataset.py +80 -0
- scripts/export_crca_core_schemas.py +54 -0
- scripts/export_hf_lrm.py +37 -0
- scripts/export_ollama_gguf.py +45 -0
- scripts/generate_changelog.py +157 -0
- scripts/generate_crca_core_docs_from_schemas.py +86 -0
- scripts/run_crca_core_benchmarks.py +163 -0
- scripts/run_full_finetune.py +198 -0
- scripts/run_lrm_eval.py +31 -0
- templates/graph_management.py +29 -0
- tests/conftest.py +9 -0
- tests/test_core.py +2 -3
- tests/test_crca_core_discovery_tabular.py +15 -0
- tests/test_crca_core_estimate_dowhy.py +36 -0
- tests/test_crca_core_identify.py +18 -0
- tests/test_crca_core_intervention_design.py +36 -0
- tests/test_crca_core_linear_gaussian_scm.py +69 -0
- tests/test_crca_core_spec.py +25 -0
- tests/test_crca_core_timeseries_pcmci.py +15 -0
- tests/test_crca_llm_coauthor.py +12 -0
- tests/test_crca_llm_orchestrator.py +80 -0
- tests/test_hybrid_agent_llm_enhanced.py +556 -0
- tests/test_image_annotation_demo.py +376 -0
- tests/test_image_annotation_operational.py +408 -0
- tests/test_image_annotation_unit.py +551 -0
- tests/test_training_moe.py +13 -0
- training/__init__.py +42 -0
- training/datasets.py +140 -0
- training/deepspeed_zero2_0_5b.json +22 -0
- training/deepspeed_zero2_1_5b.json +22 -0
- training/deepspeed_zero3_0_5b.json +28 -0
- training/deepspeed_zero3_14b.json +28 -0
- training/deepspeed_zero3_h100_3gpu.json +20 -0
- training/deepspeed_zero3_offload.json +28 -0
- training/eval.py +92 -0
- training/finetune.py +516 -0
- training/public_datasets.py +89 -0
- training_data/react_train.jsonl +7473 -0
- utils/agent_discovery.py +311 -0
- utils/batch_processor.py +317 -0
- utils/conversation.py +78 -0
- utils/edit_distance.py +118 -0
- utils/formatter.py +33 -0
- utils/graph_reasoner.py +530 -0
- utils/rate_limiter.py +283 -0
- utils/router.py +2 -2
- utils/tool_discovery.py +307 -0
- webui/__init__.py +10 -0
- webui/app.py +229 -0
- webui/config.py +104 -0
- webui/static/css/style.css +332 -0
- webui/static/js/main.js +284 -0
- webui/templates/index.html +42 -0
- tests/test_crca_excel.py +0 -166
- tests/test_data_broker.py +0 -424
- tests/test_palantir.py +0 -349
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/WHEEL +0 -0
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/licenses/LICENSE +0 -0
image_annotation/annotation_engine.py (new file)

@@ -0,0 +1,2593 @@
"""Image Annotation Engine with GPT-4o-mini under adversarial constraints.

This module implements a comprehensive image annotation system designed to neutralize
GPT-4o-mini's failure modes through adversarial containment. The system pre-processes
images, extracts geometric primitives via OpenCV, restricts GPT-4o-mini to semantic
labeling only, compiles annotations into a typed graph with contradiction detection,
uses deterministic math, integrates with CR-CA for failure-aware reasoning, tracks
temporal coherence, and outputs overlay, formal report, and JSON.
"""

import base64
import hashlib
import io
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import cv2
import numpy as np
from loguru import logger
import rustworkx as rx
from PIL import Image, ImageDraw, ImageFont

# Try to import tqdm for progress bars
try:
    from tqdm import tqdm
    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False
    logger.warning("tqdm not available, progress bars will be disabled")

# Try to import yaml for config files
try:
    import yaml
    YAML_AVAILABLE = True
except ImportError:
    YAML_AVAILABLE = False

# Try to import requests for URL loading
try:
    import requests
    REQUESTS_AVAILABLE = True
except ImportError:
    REQUESTS_AVAILABLE = False
    logger.warning("requests not available, URL loading will be disabled")

# Optional dependencies with graceful fallbacks
try:
    from skimage import exposure
    SKIMAGE_AVAILABLE = True
except ImportError:
    SKIMAGE_AVAILABLE = False
    logger.warning("scikit-image not available, histogram equalization will be disabled")

try:
    import pywt
    PYWAVELETS_AVAILABLE = True
except ImportError:
    PYWAVELETS_AVAILABLE = False
    logger.warning("pywavelets not available, noise-aware downscaling will be disabled")

try:
    import sympy
    SYMPY_AVAILABLE = True
except ImportError:
    SYMPY_AVAILABLE = False
    logger.warning("sympy not available, symbolic math will be disabled")

try:
    import z3
    Z3_AVAILABLE = True
except ImportError:
    Z3_AVAILABLE = False
    logger.warning("z3-solver not available, constraint solving will be disabled")

try:
    from filterpy.kalman import KalmanFilter
    FILTERPY_AVAILABLE = True
except ImportError:
    FILTERPY_AVAILABLE = False
    logger.warning("filterpy not available, temporal tracking will be disabled")

# Local imports
from CRCA import CRCAAgent
from schemas.annotation import (
    PrimitiveEntity,
    Line,
    Circle,
    Contour,
    Intersection,
    SemanticLabel,
    Relation,
    Contradiction,
    Claim,
    AnnotationGraph,
    AnnotationResult
)
from prompts.image_annotation import RESTRICTED_LABELER_SYSTEM_PROMPT

@dataclass
class AnnotationConfig:
    """Configuration for image annotation engine.

    All parameters are optional - None means auto-tune/auto-detect.
    """
    # Model settings
    gpt_model: str = "gpt-4o-mini"
    use_crca_tools: bool = False

    # Feature toggles
    enable_temporal_tracking: bool = True
    cache_enabled: bool = True
    auto_retry: bool = True
    auto_tune_params: bool = True
    auto_detect_type: bool = True

    # Retry settings
    max_retries: int = 3
    retry_backoff: float = 1.5  # Exponential backoff multiplier

    # Output settings
    output_format: str = "overlay"  # "overlay", "json", "report", "all"

    # Auto-tuned parameters (None = auto-tune)
    opencv_params: Optional[Dict[str, Any]] = None
    preprocessing_params: Optional[Dict[str, Any]] = None

    # Batch processing
    parallel_workers: Optional[int] = None  # None = auto-detect
    show_progress: bool = True

    # Cache settings
    cache_dir: Optional[str] = None  # None = use default .cache directory

    @classmethod
    def from_env(cls) -> "AnnotationConfig":
        """Load configuration from environment variables."""
        config = cls()

        # Load from environment
        if os.getenv("ANNOTATION_GPT_MODEL"):
            config.gpt_model = os.getenv("ANNOTATION_GPT_MODEL")
        if os.getenv("ANNOTATION_CACHE_ENABLED"):
            config.cache_enabled = os.getenv("ANNOTATION_CACHE_ENABLED").lower() == "true"
        if os.getenv("ANNOTATION_AUTO_RETRY"):
            config.auto_retry = os.getenv("ANNOTATION_AUTO_RETRY").lower() == "true"
        if os.getenv("ANNOTATION_MAX_RETRIES"):
            config.max_retries = int(os.getenv("ANNOTATION_MAX_RETRIES"))
        if os.getenv("ANNOTATION_OUTPUT_FORMAT"):
            config.output_format = os.getenv("ANNOTATION_OUTPUT_FORMAT")
        if os.getenv("ANNOTATION_CACHE_DIR"):
            config.cache_dir = os.getenv("ANNOTATION_CACHE_DIR")

        return config

    @classmethod
    def from_file(cls, config_path: Union[str, Path]) -> "AnnotationConfig":
        """Load configuration from YAML or JSON file."""
        config_path = Path(config_path)

        if not config_path.exists():
            logger.warning(f"Config file not found: {config_path}, using defaults")
            return cls()

        try:
            if config_path.suffix in ['.yaml', '.yml']:
                if not YAML_AVAILABLE:
                    logger.warning("PyYAML not available, cannot load YAML config")
                    return cls()
                with open(config_path, 'r') as f:
                    data = yaml.safe_load(f)
            elif config_path.suffix == '.json':
                with open(config_path, 'r') as f:
                    data = json.load(f)
            else:
                logger.warning(f"Unsupported config file format: {config_path.suffix}")
                return cls()

            # Create config from dict
            return cls(**data)
        except Exception as e:
            logger.error(f"Error loading config file: {e}")
            return cls()

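Taken together, the two loaders give a simple precedence: environment variables override the dataclass defaults, and a config file found by `_load_config` (defined further down) overrides both. A minimal sketch of the environment path, with illustrative variable values that are not part of this diff:

```python
import os

# Illustrative values; any ANNOTATION_* variable may simply be absent.
os.environ["ANNOTATION_GPT_MODEL"] = "gpt-4o-mini"
os.environ["ANNOTATION_MAX_RETRIES"] = "5"

config = AnnotationConfig.from_env()
assert config.max_retries == 5       # picked up from the environment
assert config.retry_backoff == 1.5   # untouched dataclass default
```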
class ImageAnnotationEngine:
    """Main god-class for image annotation with adversarial constraints.

    This class orchestrates the entire annotation pipeline:
    1. Image preprocessing (reduce entropy)
    2. Primitive extraction (OpenCV)
    3. Semantic labeling (GPT-4o-mini, restricted)
    4. Graph compilation (rustworkx)
    5. Contradiction detection
    6. Deterministic math computations
    7. Temporal tracking (Kalman filters)
    8. CR-CA integration (failure-aware reasoning)
    9. Output generation (overlay, report, JSON)

    Attributes:
        gpt_model: GPT model name (default: "gpt-4o-mini")
        enable_temporal_tracking: Whether to enable temporal tracking
        _labeler: CRCAAgent instance for labeling
        _crca_agent: CRCAAgent instance for failure-aware reasoning
        _entity_trackers: Dict mapping entity IDs to Kalman filters
        _frame_history: List of previous frame annotations
        _claims: Dict mapping claim IDs to Claim objects
    """

    def __init__(
        self,
        config: Optional[AnnotationConfig] = None,
        gpt_model: Optional[str] = None,
        enable_temporal_tracking: Optional[bool] = None,
        use_crca_tools: Optional[bool] = None,
        cache_enabled: Optional[bool] = None,
        auto_retry: Optional[bool] = None,
        output_format: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the annotation engine.

        Args:
            config: Optional AnnotationConfig (if None, loads from env/file/defaults)
            gpt_model: GPT model name (overrides config)
            enable_temporal_tracking: Enable temporal tracking (overrides config)
            use_crca_tools: Whether to enable CRCA tools (overrides config)
            cache_enabled: Enable caching (overrides config)
            auto_retry: Enable auto-retry (overrides config)
            output_format: Output format (overrides config)
            **kwargs: Additional arguments passed to CRCAAgent
        """
        # Load configuration (priority: explicit params > config > env > defaults)
        if config is None:
            config = self._load_config()

        # Override with explicit parameters if provided
        if gpt_model is not None:
            config.gpt_model = gpt_model
        if enable_temporal_tracking is not None:
            config.enable_temporal_tracking = enable_temporal_tracking
        if use_crca_tools is not None:
            config.use_crca_tools = use_crca_tools
        if cache_enabled is not None:
            config.cache_enabled = cache_enabled
        if auto_retry is not None:
            config.auto_retry = auto_retry
        if output_format is not None:
            config.output_format = output_format

        self.config = config
        self.gpt_model = config.gpt_model
        self.enable_temporal_tracking = config.enable_temporal_tracking and FILTERPY_AVAILABLE

        # Initialize labeling agent (restricted GPT-4o-mini)
        self._labeler = CRCAAgent(
            model_name=config.gpt_model,
            system_prompt=RESTRICTED_LABELER_SYSTEM_PROMPT,
            agent_name="image-annotation-labeler",
            agent_description="Restricted semantic labeler for image annotations",
            use_crca_tools=config.use_crca_tools,
            agent_max_loops=1,  # Single pass for labeling
            **kwargs
        )

        # Initialize CR-CA agent for failure-aware reasoning
        self._crca_agent = CRCAAgent(
            model_name=config.gpt_model,
            agent_name="image-annotation-crca",
            agent_description="Failure-aware reasoning for image annotations",
            use_crca_tools=True,
            **kwargs
        )

        # Temporal tracking state
        self._entity_trackers: Dict[str, Any] = {}  # entity_id -> KalmanFilter
        self._frame_history: List[AnnotationGraph] = []
        self._entity_id_map: Dict[str, str] = {}  # Maps entity IDs across frames
        self._claims: Dict[str, Claim] = {}

        # Cache setup
        self._cache: Dict[str, Any] = {}  # cache_key -> cached_data
        if config.cache_dir:
            self._cache_dir = Path(config.cache_dir)
        else:
            self._cache_dir = Path.home() / ".cache" / "image_annotation"
        self._cache_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"ImageAnnotationEngine initialized with model={config.gpt_model}, temporal_tracking={self.enable_temporal_tracking}, cache={config.cache_enabled}")

    def _load_config(self) -> AnnotationConfig:
        """Load configuration from environment variables, config file, or defaults."""
        # Try environment variables first
        config = AnnotationConfig.from_env()

        # Try config file (config.yaml or config.json in current directory or home)
        config_paths = [
            Path("config.yaml"),
            Path("config.json"),
            Path.home() / ".image_annotation_config.yaml",
            Path.home() / ".image_annotation_config.json"
        ]

        for config_path in config_paths:
            if config_path.exists():
                file_config = AnnotationConfig.from_file(config_path)
                # Merge: file config overrides env config
                for key, value in file_config.__dict__.items():
                    if value is not None and value != AnnotationConfig().__dict__[key]:
                        setattr(config, key, value)
                break

        return config

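Given the precedence in `__init__` (explicit keyword arguments beat the config object, which beats environment and file values), a construction sketch might look like the following. It assumes only that `CRCAAgent` can be instantiated in the current environment, for example with an OpenAI-compatible API key configured:

```python
engine = ImageAnnotationEngine(
    config=AnnotationConfig.from_env(),
    gpt_model="gpt-4o-mini",         # explicit kwarg wins over config/env
    enable_temporal_tracking=False,  # also forced off whenever filterpy is missing
    cache_enabled=True,              # primitives cached under ~/.cache/image_annotation
    output_format="all",
)
```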
    # ==================== Smart Input Handling ====================

    def _detect_input_type(self, input: Any) -> str:
        """Detect the type of input."""
        if isinstance(input, str):
            if input.startswith(('http://', 'https://')):
                return 'url'
            return 'file_path'
        elif isinstance(input, (Path, os.PathLike)):
            return 'file_path'
        elif isinstance(input, np.ndarray):
            return 'numpy_array'
        elif isinstance(input, Image.Image):
            return 'pil_image'
        elif isinstance(input, list):
            return 'batch'
        else:
            return 'unknown'

    def _auto_load_input(self, input: Any) -> np.ndarray:
        """
        Automatically load and convert input to numpy array.

        Supports: file paths, URLs, numpy arrays, PIL Images
        """
        input_type = self._detect_input_type(input)

        if input_type == 'numpy_array':
            return input.copy()

        elif input_type == 'file_path':
            path = Path(input) if not isinstance(input, Path) else input
            if not path.exists():
                raise FileNotFoundError(f"Image file not found: {path}")
            image = cv2.imread(str(path))
            if image is None:
                raise ValueError(f"Could not load image from: {path}")
            return image

        elif input_type == 'url':
            if not REQUESTS_AVAILABLE:
                raise ImportError("requests library required for URL loading. Install with: pip install requests")
            response = requests.get(input, timeout=30)
            response.raise_for_status()
            image_array = np.frombuffer(response.content, np.uint8)
            image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError(f"Could not decode image from URL: {input}")
            return image

        elif input_type == 'pil_image':
            # Convert PIL to numpy
            if input.mode != 'RGB':
                input = input.convert('RGB')
            image_array = np.array(input)
            # PIL uses RGB, OpenCV uses BGR
            image = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
            return image

        else:
            raise ValueError(f"Unsupported input type: {type(input)}")

    def _convert_to_numpy(self, input: Any) -> np.ndarray:
        """Convert any input to numpy array (alias for _auto_load_input)."""
        return self._auto_load_input(input)

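The loader normalizes every supported single-image input form to an OpenCV-style BGR array. A sketch of the four paths (file names are placeholders, and `_auto_load_input` is a private helper, so this is illustrative rather than supported API):

```python
engine = ImageAnnotationEngine()

from_array = engine._auto_load_input(np.zeros((32, 32, 3), dtype=np.uint8))  # copied, not aliased
from_path = engine._auto_load_input("diagram.png")                # cv2.imread
from_url = engine._auto_load_input("https://example.com/d.png")   # requires requests
from_pil = engine._auto_load_input(Image.open("diagram.png"))     # RGB -> BGR
```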
    # ==================== Image Type Detection ====================

    def _detect_image_type(self, image: np.ndarray) -> str:
        """
        Automatically detect image type based on visual characteristics.

        Returns: "circuit", "architectural", "mathematical", "technical", "general"
        """
        # Convert to grayscale for analysis
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()

        h, w = gray.shape
        total_pixels = h * w

        # Analyze image characteristics
        edges = cv2.Canny(gray, 50, 150)
        edge_density = np.sum(edges > 0) / total_pixels

        # Detect lines
        lines = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=50, minLineLength=30, maxLineGap=10)
        line_count = len(lines) if lines is not None else 0
        line_density = line_count / (total_pixels / 10000)  # Normalize: detected lines per 10,000 pixels

        # Detect circles
        circles = cv2.HoughCircles(
            gray, cv2.HOUGH_GRADIENT, dp=1, minDist=30,
            param1=50, param2=30, minRadius=5, maxRadius=min(h, w) // 4
        )
        circle_count = len(circles[0]) if circles is not None else 0

        # Detect contours (closed shapes)
        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contour_count = len(contours)

        # Color analysis (for circuit diagrams - often have colored components)
        if len(image.shape) == 3:
            color_variance = np.std(image, axis=2).mean()
        else:
            color_variance = 0

        # Heuristic rules for image type detection
        # Circuit diagrams: high line density, many circles (components), moderate contours
        if line_density > 0.5 and circle_count > 5 and 10 < contour_count < 100:
            return "circuit"

        # Architectural drawings: very high line density, many parallel lines, few circles
        if line_density > 1.0 and line_count > 100 and circle_count < 3:
            return "architectural"

        # Mathematical diagrams: moderate lines, some circles, text-like patterns
        if 0.2 < line_density < 0.8 and circle_count > 2 and contour_count < 50:
            return "mathematical"

        # Technical drawings: high precision, many small details
        if line_density > 0.8 and contour_count > 50:
            return "technical"

        # Default to general
        return "general"

    def _get_type_specific_params(self, image_type: str) -> Dict[str, Any]:
        """Get type-specific parameters for detection and processing."""
        params = {
            "circuit": {
                "hough_line_threshold": 80,
                "hough_line_min_length": 40,
                "hough_circle_threshold": 30,
                "canny_low": 50,
                "canny_high": 150,
                "preprocessing_strength": 0.7,
                "expected_primitives": ["line", "circle", "contour"]
            },
            "architectural": {
                "hough_line_threshold": 100,
                "hough_line_min_length": 50,
                "hough_circle_threshold": 50,
                "canny_low": 30,
                "canny_high": 100,
                "preprocessing_strength": 0.8,
                "expected_primitives": ["line", "contour"]
            },
            "mathematical": {
                "hough_line_threshold": 70,
                "hough_line_min_length": 30,
                "hough_circle_threshold": 25,
                "canny_low": 40,
                "canny_high": 120,
                "preprocessing_strength": 0.6,
                "expected_primitives": ["line", "circle", "contour"]
            },
            "technical": {
                "hough_line_threshold": 90,
                "hough_line_min_length": 35,
                "hough_circle_threshold": 35,
                "canny_low": 45,
                "canny_high": 130,
                "preprocessing_strength": 0.75,
                "expected_primitives": ["line", "circle", "contour"]
            },
            "general": {
                "hough_line_threshold": 100,
                "hough_line_min_length": 50,
                "hough_circle_threshold": 30,
                "canny_low": 50,
                "canny_high": 150,
                "preprocessing_strength": 0.7,
                "expected_primitives": ["line", "circle", "contour"]
            }
        }

        return params.get(image_type, params["general"])

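Detection feeds directly into the parameter table, with unknown type strings falling back to "general". A sketch, assuming a real image exists at the placeholder path:

```python
engine = ImageAnnotationEngine()
img = engine._auto_load_input("circuit.png")   # placeholder path

kind = engine._detect_image_type(img)          # e.g. "circuit", else "general"
params = engine._get_type_specific_params(kind)
params["hough_line_threshold"]                 # 80 for "circuit", 100 for "general"
```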
    # ==================== Auto Parameter Tuning ====================

    def _auto_tune_params(self, image: np.ndarray, image_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Automatically tune parameters using hybrid strategy.

        Strategy: Heuristic first, then adaptive refinement if needed.
        """
        if not self.config.auto_tune_params:
            # Use defaults or config-specified params
            if self.config.opencv_params:
                return self.config.opencv_params
            return self._get_type_specific_params("general")

        # Detect image type if not provided
        if image_type is None and self.config.auto_detect_type:
            image_type = self._detect_image_type(image)

        # Start with heuristic-based parameters
        params = self._heuristic_tune(image, image_type or "general")

        return params

    def _heuristic_tune(self, image: np.ndarray, image_type: str) -> Dict[str, Any]:
        """Heuristic-based parameter tuning based on image statistics."""
        # Get base parameters for image type
        params = self._get_type_specific_params(image_type)

        # Analyze image to adjust parameters
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()

        h, w = gray.shape
        total_pixels = h * w

        # Adjust based on image size
        if total_pixels < 100000:  # Small image
            params["hough_line_threshold"] = int(params["hough_line_threshold"] * 0.7)
            params["hough_line_min_length"] = int(params["hough_line_min_length"] * 0.7)
        elif total_pixels > 1000000:  # Large image
            params["hough_line_threshold"] = int(params["hough_line_threshold"] * 1.3)
            params["hough_line_min_length"] = int(params["hough_line_min_length"] * 1.3)

        # Adjust based on edge density
        edges = cv2.Canny(gray, params["canny_low"], params["canny_high"])
        edge_density = np.sum(edges > 0) / total_pixels

        if edge_density < 0.1:  # Low edge density - lower thresholds
            params["hough_line_threshold"] = int(params["hough_line_threshold"] * 0.8)
            params["canny_low"] = max(20, int(params["canny_low"] * 0.8))
        elif edge_density > 0.3:  # High edge density - raise thresholds
            params["hough_line_threshold"] = int(params["hough_line_threshold"] * 1.2)
            params["canny_high"] = min(255, int(params["canny_high"] * 1.2))

        return params

    def _adaptive_refine(self, image: np.ndarray, initial_result: AnnotationResult, params: Dict[str, Any]) -> Dict[str, Any]:
        """Adaptively refine parameters based on initial results."""
        # If no primitives found, relax thresholds
        if len(initial_result.annotation_graph.entities) == 0:
            logger.info("No primitives found, relaxing thresholds")
            params["hough_line_threshold"] = max(30, int(params.get("hough_line_threshold", 100) * 0.7))
            params["hough_circle_threshold"] = max(20, int(params.get("hough_circle_threshold", 30) * 0.7))
            params["canny_low"] = max(20, int(params.get("canny_low", 50) * 0.8))

        # If too many primitives, raise thresholds
        elif len(initial_result.annotation_graph.entities) > 500:
            logger.info("Too many primitives, raising thresholds")
            params["hough_line_threshold"] = int(params.get("hough_line_threshold", 100) * 1.3)
            params["hough_circle_threshold"] = int(params.get("hough_circle_threshold", 30) * 1.3)
            params["canny_high"] = min(255, int(params.get("canny_high", 150) * 1.2))

        # If low confidence labels, adjust preprocessing
        if initial_result.annotation_graph.labels:
            avg_uncertainty = np.mean([l.uncertainty for l in initial_result.annotation_graph.labels])
            if avg_uncertainty > 0.7:
                logger.info("Low confidence labels, adjusting preprocessing")
                params["preprocessing_strength"] = min(1.0, params.get("preprocessing_strength", 0.7) * 1.2)

        return params

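For concreteness, the size and edge-density rules compose. A hypothetical blank 2000x2000 grayscale frame treated as "general" starts at hough_line_threshold 100; the large-image rule (4,000,000 pixels > 1,000,000) lifts it to int(100 * 1.3) = 130, and the near-zero edge density then lowers it to int(130 * 0.8) = 104:

```python
engine = ImageAnnotationEngine()

blank = np.zeros((2000, 2000), dtype=np.uint8)   # 4,000,000 pixels, no edges
params = engine._heuristic_tune(blank, "general")
assert params["hough_line_threshold"] == 104
assert params["canny_low"] == 40                 # max(20, int(50 * 0.8))
```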
    # ==================== Retry Logic ====================

    def _should_retry(self, result: AnnotationResult, attempt: int) -> bool:
        """Determine if annotation should be retried."""
        if not self.config.auto_retry:
            return False

        if attempt >= self.config.max_retries:
            return False

        # Retry if no primitives detected
        if len(result.annotation_graph.entities) == 0:
            logger.info(f"Retry {attempt + 1}: No primitives detected")
            return True

        # Retry if low confidence
        if result.annotation_graph.labels:
            avg_uncertainty = np.mean([l.uncertainty for l in result.annotation_graph.labels])
            if avg_uncertainty > 0.7:
                logger.info(f"Retry {attempt + 1}: Low confidence (avg uncertainty: {avg_uncertainty:.2f})")
                return True

        # Retry if many contradictions
        if len(result.annotation_graph.contradictions) > 5:
            logger.info(f"Retry {attempt + 1}: Many contradictions ({len(result.annotation_graph.contradictions)})")
            return True

        return False

    def _get_retry_params(self, attempt: int, previous_result: Optional[AnnotationResult], base_params: Dict[str, Any]) -> Dict[str, Any]:
        """Get parameters for retry attempt."""
        params = base_params.copy()

        # Exponential backoff: relax thresholds more with each retry
        backoff_factor = self.config.retry_backoff ** attempt

        # Relax thresholds (use defaults if not present)
        params["hough_line_threshold"] = max(30, int(params.get("hough_line_threshold", 100) * (1.0 / backoff_factor)))
        params["hough_circle_threshold"] = max(20, int(params.get("hough_circle_threshold", 30) * (1.0 / backoff_factor)))
        params["canny_low"] = max(20, int(params.get("canny_low", 50) * (1.0 / backoff_factor)))

        # Adjust preprocessing
        params["preprocessing_strength"] = min(1.0, params.get("preprocessing_strength", 0.7) * (1.0 + 0.1 * attempt))

        return params

    def _annotate_with_retry(self, image: np.ndarray, frame_id: Optional[int] = None, params: Optional[Dict[str, Any]] = None) -> AnnotationResult:
        """Annotate image with automatic retry logic."""
        if params is None:
            params = self._auto_tune_params(image)

        best_result = None
        best_score = -1

        for attempt in range(self.config.max_retries):
            try:
                # Use retry parameters if not first attempt
                if attempt > 0:
                    params = self._get_retry_params(attempt, best_result, params)
                    logger.info(f"Retry attempt {attempt + 1}/{self.config.max_retries} with adjusted parameters")

                # Perform annotation with current parameters
                result = self._annotate_core(image, frame_id, params)

                # Score result (higher is better)
                score = self._score_result(result)

                # Keep best result
                if score > best_score:
                    best_result = result
                    best_score = score

                # Check if we should retry
                if not self._should_retry(result, attempt):
                    return result

                # Wait before retry (exponential backoff)
                if attempt < self.config.max_retries - 1:
                    wait_time = (self.config.retry_backoff ** attempt) * 0.5
                    time.sleep(wait_time)

            except Exception as e:
                logger.warning(f"Annotation attempt {attempt + 1} failed: {e}")
                if attempt == self.config.max_retries - 1:
                    # Last attempt failed, return best result or error
                    if best_result:
                        return best_result
                    raise

        # Return best result after all retries
        return best_result if best_result else AnnotationResult(
            annotation_graph=AnnotationGraph(),
            overlay_image=None,
            formal_report="Error: All retry attempts failed",
            json_output={"error": "All retry attempts failed"},
            processing_time=0.0
        )

    def _score_result(self, result: AnnotationResult) -> float:
        """Score annotation result (higher is better)."""
        score = 0.0

        # More entities is better
        score += len(result.annotation_graph.entities) * 0.1

        # More labels is better
        score += len(result.annotation_graph.labels) * 0.2

        # Lower average uncertainty is better
        if result.annotation_graph.labels:
            avg_uncertainty = np.mean([l.uncertainty for l in result.annotation_graph.labels])
            score += (1.0 - avg_uncertainty) * 0.5

        # Fewer contradictions is better
        score -= len(result.annotation_graph.contradictions) * 0.3

        return max(0.0, score)

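Under the defaults (max_retries=3, retry_backoff=1.5), the relaxation schedule is easy to trace by hand. A standalone sketch of the threshold arithmetic from `_get_retry_params`:

```python
threshold = 100                            # hough_line_threshold going in
for attempt in (1, 2):                     # attempt 0 uses the original params
    threshold = max(30, int(threshold * (1.0 / 1.5 ** attempt)))
    print(attempt, threshold)              # 1 -> 66, then 2 -> 30 (floor clamp)

# The sleep between attempts follows the same base: (1.5 ** attempt) * 0.5 seconds.
```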
    # ==================== Smart Caching ====================

    def _get_cache_key(self, image: np.ndarray, params: Dict[str, Any]) -> str:
        """Generate cache key from image content and parameters."""
        # Hash image content
        image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
        image_hash = hashlib.sha256(image_bytes).hexdigest()[:16]

        # Hash parameters
        params_str = json.dumps(params, sort_keys=True)
        params_hash = hashlib.md5(params_str.encode()).hexdigest()[:8]

        return f"{image_hash}_{params_hash}"

    def _get_cached_primitives(self, cache_key: str) -> Optional[List[PrimitiveEntity]]:
        """Get cached primitives if available."""
        if not self.config.cache_enabled:
            return None

        cache_file = self._cache_dir / f"{cache_key}_primitives.json"
        if cache_file.exists():
            try:
                with open(cache_file, 'r') as f:
                    data = json.load(f)
                primitives = [PrimitiveEntity(**item) for item in data]
                logger.debug(f"Loaded {len(primitives)} primitives from cache")
                return primitives
            except Exception as e:
                logger.warning(f"Error loading cache: {e}")

        return None

    def _cache_primitives(self, cache_key: str, primitives: List[PrimitiveEntity]) -> None:
        """Cache extracted primitives."""
        if not self.config.cache_enabled:
            return

        try:
            cache_file = self._cache_dir / f"{cache_key}_primitives.json"
            data = [p.model_dump() for p in primitives]
            with open(cache_file, 'w') as f:
                json.dump(data, f)
            logger.debug(f"Cached {len(primitives)} primitives")
        except Exception as e:
            logger.warning(f"Error caching primitives: {e}")

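The cache is content-addressed: identical pixels re-encoded to JPEG plus an identical parameter dict always map to the same `{key}_primitives.json` file in the cache directory. A sketch of the key shape:

```python
engine = ImageAnnotationEngine()

key = engine._get_cache_key(np.zeros((64, 64, 3), dtype=np.uint8), {"canny_low": 50})
# "<16 hex chars of sha256(jpeg bytes)>_<8 hex chars of md5(sorted-params json)>"
assert len(key) == 16 + 1 + 8
```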
    # ==================== Image Preprocessing ====================

    def _preprocess_image(self, image: np.ndarray, params: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """
        Main preprocessing pipeline to reduce input entropy.

        Args:
            image: Input image as numpy array
            params: Optional parameters dict (uses defaults if None)

        Returns:
            Preprocessed image
        """
        if params is None:
            params = {}

        processed = image.copy()

        # Get preprocessing strength
        preprocessing_strength = params.get("preprocessing_strength", 0.7)

        # Convert to grayscale if needed
        if len(processed.shape) == 3:
            processed = cv2.cvtColor(processed, cv2.COLOR_BGR2GRAY)

        # Adaptive histogram equalization (scaled by strength)
        if preprocessing_strength > 0.3:
            processed = self._adaptive_histogram_equalization(processed)

        # Edge amplification (scaled by strength)
        if preprocessing_strength > 0.5:
            processed = self._edge_amplification(processed, strength=preprocessing_strength)

        # Noise-aware downscaling (if image is too large)
        max_dimension = 2048
        h, w = processed.shape[:2]
        if max(h, w) > max_dimension:
            target_size = (int(w * max_dimension / max(h, w)), int(h * max_dimension / max(h, w)))
            processed = self._noise_aware_downscale(processed, target_size)

        return processed

    def _adaptive_histogram_equalization(self, image: np.ndarray) -> np.ndarray:
        """Apply adaptive histogram equalization."""
        if not SKIMAGE_AVAILABLE:
            # Fallback to OpenCV CLAHE
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            return clahe.apply(image)

        # Use skimage for better quality; equalize_adapthist returns floats in
        # [0, 1], so convert back to uint8 for the downstream OpenCV calls
        equalized = exposure.equalize_adapthist(image, clip_limit=0.03)
        return (equalized * 255).astype(np.uint8)

    def _edge_amplification(self, image: np.ndarray, strength: float = 0.7) -> np.ndarray:
        """Amplify edges using Laplacian operator."""
        # Ensure image is in correct format (uint8)
        if image.dtype != np.uint8:
            # Normalize to 0-255 range if float
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = np.clip(image, 0, 255).astype(np.uint8)

        # Apply Laplacian filter
        laplacian = cv2.Laplacian(image, cv2.CV_64F)
        laplacian = np.absolute(laplacian)
        laplacian = np.clip(laplacian, 0, 255).astype(np.uint8)

        # Combine with original (strength controls amplification)
        edge_weight = min(0.5, strength * 0.5)
        original_weight = 1.0 - edge_weight
        amplified = cv2.addWeighted(image, original_weight, laplacian, edge_weight, 0)
        return amplified

    def _noise_aware_downscale(self, image: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
        """Downscale image using wavelets to preserve important features."""
        if not PYWAVELETS_AVAILABLE:
            # Fallback to standard resize
            return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

        # Use wavelets for better downscaling
        # Simple approach: resize with area interpolation (wavelets would be more complex)
        return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

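End to end, preprocessing takes any loaded frame to a grayscale uint8 image with boosted contrast and edges, capped at 2048 pixels on the long side. A sketch (placeholder path again):

```python
engine = ImageAnnotationEngine()

raw = engine._auto_load_input("diagram.png")   # placeholder path
prep = engine._preprocess_image(raw, {"preprocessing_strength": 0.7})
assert prep.ndim == 2 and max(prep.shape) <= 2048
```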
# ==================== Vision Primitives Extraction ====================
|
|
834
|
+
|
|
835
|
+
def _extract_lines(self, image: np.ndarray, params: Optional[Dict[str, Any]] = None) -> List[Line]:
|
|
836
|
+
"""
|
|
837
|
+
Extract lines using Hough line detection.
|
|
838
|
+
|
|
839
|
+
Args:
|
|
840
|
+
image: Preprocessed image
|
|
841
|
+
params: Optional parameters dict (uses defaults if None)
|
|
842
|
+
|
|
843
|
+
Returns:
|
|
844
|
+
List of Line objects with pixel coordinates
|
|
845
|
+
"""
|
|
846
|
+
if params is None:
|
|
847
|
+
params = {}
|
|
848
|
+
|
|
849
|
+
lines = []
|
|
850
|
+
|
|
851
|
+
# Get parameters with defaults
|
|
852
|
+
canny_low = params.get("canny_low", 50)
|
|
853
|
+
canny_high = params.get("canny_high", 150)
|
|
854
|
+
hough_threshold = params.get("hough_line_threshold", 100)
|
|
855
|
+
min_line_length = params.get("hough_line_min_length", 50)
|
|
856
|
+
max_line_gap = params.get("hough_line_max_gap", 10)
|
|
857
|
+
|
|
858
|
+
# Edge detection
|
|
859
|
+
edges = cv2.Canny(image, canny_low, canny_high, apertureSize=3)
|
|
860
|
+
|
|
861
|
+
# Hough line detection
|
|
862
|
+
hough_lines = cv2.HoughLinesP(
|
|
863
|
+
edges,
|
|
864
|
+
rho=1,
|
|
865
|
+
theta=np.pi / 180,
|
|
866
|
+
threshold=hough_threshold,
|
|
867
|
+
minLineLength=min_line_length,
|
|
868
|
+
maxLineGap=max_line_gap
|
|
869
|
+
)
|
|
870
|
+
|
|
871
|
+
if hough_lines is not None:
|
|
872
|
+
for line in hough_lines:
|
|
873
|
+
x1, y1, x2, y2 = line[0]
|
|
874
|
+
lines.append(Line(
|
|
875
|
+
start_point=(int(x1), int(y1)),
|
|
876
|
+
end_point=(int(x2), int(y2))
|
|
877
|
+
))
|
|
878
|
+
|
|
879
|
+
logger.debug(f"Extracted {len(lines)} lines")
|
|
880
|
+
return lines
|
|
881
|
+
|
|
882
|
+
def _extract_circles(self, image: np.ndarray, params: Optional[Dict[str, Any]] = None) -> List[Circle]:
|
|
883
|
+
"""
|
|
884
|
+
Extract circles using Hough circle detection.
|
|
885
|
+
|
|
886
|
+
Args:
|
|
887
|
+
image: Preprocessed image
|
|
888
|
+
params: Optional parameters dict (uses defaults if None)
|
|
889
|
+
|
|
890
|
+
Returns:
|
|
891
|
+
List of Circle objects with center and radius
|
|
892
|
+
"""
|
|
893
|
+
if params is None:
|
|
894
|
+
params = {}
|
|
895
|
+
|
|
896
|
+
circles = []
|
|
897
|
+
|
|
898
|
+
# Get parameters with defaults
|
|
899
|
+
hough_threshold = params.get("hough_circle_threshold", 30)
|
|
900
|
+
min_dist = params.get("hough_circle_min_dist", 30)
|
|
901
|
+
min_radius = params.get("hough_circle_min_radius", 10)
|
|
902
|
+
max_radius = params.get("hough_circle_max_radius", 0) # 0 means no maximum
|
|
903
|
+
|
|
904
|
+
h, w = image.shape[:2]
|
|
905
|
+
if max_radius == 0:
|
|
906
|
+
max_radius = min(h, w) // 2
|
|
907
|
+
|
|
908
|
+
# Hough circle detection
|
|
909
|
+
detected_circles = cv2.HoughCircles(
|
|
910
|
+
image,
|
|
911
|
+
cv2.HOUGH_GRADIENT,
|
|
912
|
+
dp=1,
|
|
913
|
+
minDist=min_dist,
|
|
914
|
+
param1=50,
|
|
915
|
+
param2=hough_threshold,
|
|
916
|
+
minRadius=min_radius,
|
|
917
|
+
maxRadius=max_radius
|
|
918
|
+
)
|
|
919
|
+
|
|
920
|
+
if detected_circles is not None:
|
|
921
|
+
detected_circles = np.uint16(np.around(detected_circles))
|
|
922
|
+
for circle in detected_circles[0, :]:
|
|
923
|
+
center_x, center_y, radius = circle
|
|
924
|
+
circles.append(Circle(
|
|
925
|
+
center=(int(center_x), int(center_y)),
|
|
926
|
+
                    radius=float(radius)
                ))

        logger.debug(f"Extracted {len(circles)} circles")
        return circles

    def _extract_contours(self, image: np.ndarray, params: Optional[Dict[str, Any]] = None) -> List[Contour]:
        """
        Extract contours from image.

        Args:
            image: Preprocessed image
            params: Optional parameters dict (uses defaults if None)

        Returns:
            List of Contour objects
        """
        if params is None:
            params = {}

        contours_list = []

        # Get parameters with defaults
        canny_low = params.get("canny_low", 50)
        canny_high = params.get("canny_high", 150)

        # Edge detection
        edges = cv2.Canny(image, canny_low, canny_high)

        # Find contours
        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            if len(contour) >= 3:  # Need at least 3 points for a polygon
                points = [(int(pt[0][0]), int(pt[0][1])) for pt in contour]
                contours_list.append(Contour(points=points))

        logger.debug(f"Extracted {len(contours_list)} contours")
        return contours_list

    def _compute_intersections(self, primitives: List[PrimitiveEntity]) -> List[Intersection]:
        """
        Compute intersections between primitives.

        Args:
            primitives: List of primitive entities

        Returns:
            List of Intersection objects
        """
        intersections = []

        # Group primitives by type
        lines = [p for p in primitives if p.primitive_type == "line"]
        circles = [p for p in primitives if p.primitive_type == "circle"]

        # Line-line intersections
        for i, line1 in enumerate(lines):
            for line2 in lines[i + 1:]:
                intersection_point = self._line_line_intersection(
                    line1.pixel_coords[0], line1.pixel_coords[1],
                    line2.pixel_coords[0], line2.pixel_coords[1]
                )
                if intersection_point:
                    intersections.append(Intersection(
                        point=intersection_point,
                        primitive_ids=[line1.id, line2.id]
                    ))

        # Line-circle intersections
        for line in lines:
            for circle in circles:
                if len(line.pixel_coords) >= 2 and len(circle.pixel_coords) >= 1:
                    intersection_points = self._line_circle_intersection(
                        line.pixel_coords[0], line.pixel_coords[1],
                        circle.pixel_coords[0], circle.metadata.get("radius", 0)
                    )
                    for point in intersection_points:
                        intersections.append(Intersection(
                            point=point,
                            primitive_ids=[line.id, circle.id]
                        ))

        logger.debug(f"Computed {len(intersections)} intersections")
        return intersections

    def _line_line_intersection(
        self,
        p1: Tuple[int, int],
        p2: Tuple[int, int],
        p3: Tuple[int, int],
        p4: Tuple[int, int]
    ) -> Optional[Tuple[int, int]]:
        """Compute intersection of two line segments."""
        x1, y1 = p1
        x2, y2 = p2
        x3, y3 = p3
        x4, y4 = p4

        denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
        if abs(denom) < 1e-10:
            return None  # Lines are parallel

        t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom
        u = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)) / denom

        if 0 <= t <= 1 and 0 <= u <= 1:
            x = int(x1 + t * (x2 - x1))
            y = int(y1 + t * (y2 - y1))
            return (x, y)

        return None

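    # Illustrative sanity check (not part of the original module): the parametric
    # form above solves p1 + t*(p2 - p1) = p3 + u*(p4 - p3) via Cramer's rule, and
    # the 0 <= t, u <= 1 guard restricts the hit to the segments themselves.
    # For the crossing diagonals of a 4x4 box, with `engine` standing in for any
    # instance of this class:
    #
    #     >>> engine._line_line_intersection((0, 0), (4, 4), (0, 4), (4, 0))
    #     (2, 2)
    #
    # Here t = u = 0.5, so the intersection lies at the midpoint of both segments.
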
    def _line_circle_intersection(
        self,
        p1: Tuple[int, int],
        p2: Tuple[int, int],
        center: Tuple[int, int],
        radius: float
    ) -> List[Tuple[int, int]]:
        """Compute intersection of line and circle."""
        if radius <= 0:
            return []

        # Convert to float for computation
        x1, y1 = float(p1[0]), float(p1[1])
        x2, y2 = float(p2[0]), float(p2[1])
        cx, cy = float(center[0]), float(center[1])

        # Line equation: ax + by + c = 0
        dx = x2 - x1
        dy = y2 - y1

        if abs(dx) < 1e-10 and abs(dy) < 1e-10:
            return []

        # Distance from center to line
        a = dy
        b = -dx
        c = dx * y1 - dy * x1

        dist = abs(a * cx + b * cy + c) / np.sqrt(a * a + b * b)

        if dist > radius:
            return []

        # Compute intersection points (simplified)
        # This is a simplified version - full implementation would solve quadratic
        intersections = []

        # Project center onto line
        if abs(dx) > abs(dy):
            t = (cx - x1) / dx if abs(dx) > 1e-10 else 0
        else:
            t = (cy - y1) / dy if abs(dy) > 1e-10 else 0

        proj_x = x1 + t * dx
        proj_y = y1 + t * dy

        # Check if projection is on line segment
        if 0 <= t <= 1:
            # Distance from projection to intersection
            h = np.sqrt(radius * radius - dist * dist)

            # Direction vector
            length = np.sqrt(dx * dx + dy * dy)
            if length > 1e-10:
                dir_x = dx / length
                dir_y = dy / length

                # Two intersection points
                for sign in [-1, 1]:
                    ix = int(proj_x + sign * h * dir_x)
                    iy = int(proj_y + sign * h * dir_y)
                    intersections.append((ix, iy))

        return intersections

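    # A hedged sketch of the exact variant the "would solve quadratic" comment
    # above alludes to. The single-axis parameter t used above is only a true
    # orthogonal projection for axis-aligned segments; substituting the parametric
    # line into the circle equation removes that approximation. This helper is my
    # assumption about the intended full version, not part of the original module:
    #
    #     def _line_circle_intersection_exact(p1, p2, center, radius):
    #         x1, y1 = float(p1[0]), float(p1[1])
    #         dx, dy = p2[0] - x1, p2[1] - y1
    #         fx, fy = x1 - center[0], y1 - center[1]
    #         # |p1 + t*d - c|^2 = r^2  =>  A*t^2 + B*t + C = 0
    #         A = dx * dx + dy * dy
    #         B = 2 * (fx * dx + fy * dy)
    #         C = fx * fx + fy * fy - radius * radius
    #         disc = B * B - 4 * A * C
    #         if A < 1e-10 or disc < 0:
    #             return []
    #         root = disc ** 0.5
    #         return [(int(x1 + t * dx), int(y1 + t * dy))
    #                 for t in ((-B - root) / (2 * A), (-B + root) / (2 * A))
    #                 if 0 <= t <= 1]
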
    def _validate_primitive(self, primitive: PrimitiveEntity, image: np.ndarray) -> bool:
        """Verify that a primitive's coordinates lie within the image bounds."""
        if not primitive.pixel_coords:
            return False

        h, w = image.shape[:2]
        for x, y in primitive.pixel_coords:
            if not (0 <= x < w and 0 <= y < h):
                return False

        return True

    # ==================== GPT-4o-mini Labeling ====================

    def _label_primitives(
        self,
        primitives: List[PrimitiveEntity],
        image: np.ndarray
    ) -> List[SemanticLabel]:
        """
        Label primitives using restricted GPT-4o-mini.

        Args:
            primitives: List of primitive entities to label
            image: Original image (for context)

        Returns:
            List of semantic labels
        """
        if not primitives:
            return []

        # Prepare prompt with primitive information
        primitive_info = []
        for prim in primitives:
            prim_type = prim.primitive_type
            coords = prim.pixel_coords
            if prim_type == "line" and len(coords) >= 2:
                info = f"Line {prim.id}: from {coords[0]} to {coords[1]}"
            elif prim_type == "circle" and len(coords) >= 1:
                radius = prim.metadata.get("radius", "unknown")
                info = f"Circle {prim.id}: center {coords[0]}, radius {radius}"
            elif prim_type == "contour":
                info = f"Contour {prim.id}: {len(coords)} points"
            else:
                info = f"{prim_type} {prim.id}: {len(coords)} coordinates"
            primitive_info.append(info)

        prompt = f"""Label the following geometric primitives extracted from an image.

Primitives:
{chr(10).join(primitive_info)}

For each primitive, provide:
- entity_id: The ID from the list above
- label: A semantic description (e.g., "resistor", "wire", "boundary", "component")
- uncertainty: A number between 0.0 (certain) and 1.0 (uncertain)
- tentative: false for labels
- reasoning: Brief explanation (optional)

Return your response as a JSON array of labels, where each label has:
{{"entity_id": "...", "label": "...", "uncertainty": 0.0-1.0, "tentative": false, "reasoning": "..."}}

Only label primitives that exist in the provided list. Do not invent new primitives."""

        try:
            # Convert image to base64 for vision model
            _, buffer = cv2.imencode('.jpg', image)
            img_base64 = base64.b64encode(buffer).decode('utf-8')
            img_data_url = f"data:image/jpeg;base64,{img_base64}"

            # Call GPT-4o-mini
            response = self._labeler.run(task=prompt, img=img_data_url)

            # Parse response
            labels = self._parse_label_response(response, primitives)

            # Validate labels (only accept primitives that exist)
            valid_labels = []
            primitive_ids = {p.id for p in primitives}
            for label in labels:
                if label.entity_id in primitive_ids:
                    valid_labels.append(label)
                else:
                    logger.warning(f"Label references non-existent entity: {label.entity_id}")

            logger.info(f"Labeled {len(valid_labels)}/{len(primitives)} primitives")
            return valid_labels

        except Exception as e:
            logger.error(f"Error labeling primitives: {e}")
            return []

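    # For reference, a well-formed labeler reply to the prompt above would look
    # like the following (illustrative values only; entity IDs must echo the ones
    # listed in the prompt):
    #
    #     [
    #         {"entity_id": "0", "label": "wire", "uncertainty": 0.2,
    #          "tentative": false, "reasoning": "long thin segment joining components"},
    #         {"entity_id": "1", "label": "component", "uncertainty": 0.6,
    #          "tentative": false, "reasoning": "closed circular outline"}
    #     ]
    #
    # _parse_label_response() below extracts the first [...] block it finds, so
    # the model may wrap this array in prose without breaking parsing.
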
    def _parse_label_response(
        self,
        response: Union[str, Dict, Any],
        primitives: List[PrimitiveEntity]
    ) -> List[SemanticLabel]:
        """Parse GPT response into SemanticLabel objects."""
        labels = []

        # Extract JSON from response
        response_str = str(response)

        # Try to find JSON array in response
        import re
        json_match = re.search(r'\[.*\]', response_str, re.DOTALL)
        if json_match:
            try:
                label_data = json.loads(json_match.group())
                for item in label_data:
                    try:
                        label = SemanticLabel(
                            entity_id=item.get("entity_id", ""),
                            label=item.get("label", "unknown"),
                            uncertainty=float(item.get("uncertainty", 0.5)),
                            tentative=bool(item.get("tentative", False)),
                            reasoning=item.get("reasoning")
                        )
                        labels.append(label)
                    except Exception as e:
                        logger.warning(f"Failed to parse label: {e}")
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON from response: {e}")

        return labels

    def _suggest_relations(self, entities: List[PrimitiveEntity]) -> List[Relation]:
        """
        Suggest relations between entities (flagged as tentative).

        Args:
            entities: List of entities

        Returns:
            List of tentative relations
        """
        # This would use GPT-4o-mini to suggest relations
        # For now, return empty list (can be implemented later)
        return []

    # ==================== Graph Compilation ====================

    def _compile_graph(self, annotations: AnnotationGraph) -> rx.PyDiGraph:
        """
        Compile annotations into typed graph using rustworkx.

        Args:
            annotations: Annotation graph

        Returns:
            rustworkx directed graph
        """
        graph = rx.PyDiGraph()

        # Add nodes (entities)
        node_map = {}  # entity_id -> node index
        for entity in annotations.entities:
            node_idx = graph.add_node(entity)
            node_map[entity.id] = node_idx

        # Add edges (relations)
        for relation in annotations.relations:
            if relation.source_id in node_map and relation.target_id in node_map:
                source_idx = node_map[relation.source_id]
                target_idx = node_map[relation.target_id]
                graph.add_edge(source_idx, target_idx, relation)

        return graph

    def _detect_cycles(self, graph: rx.PyDiGraph) -> List[List[str]]:
        """Detect cycles in the graph."""
        cycles = []

        # Simple cycle detection using DFS
        # This is a simplified version - rustworkx has cycle detection methods
        try:
            # Check if graph is acyclic
            if not rx.is_directed_acyclic_graph(graph):
                # Find cycles (simplified - would need proper cycle enumeration)
                cycles.append(["cycle_detected"])
        except Exception as e:
            logger.warning(f"Cycle detection error: {e}")

        return cycles

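    # Sketch of the proper enumeration the comment above defers, assuming a
    # rustworkx version that ships simple_cycles (0.14+); this replacement body
    # is my assumption, not the shipped implementation:
    #
    #     if not rx.is_directed_acyclic_graph(graph):
    #         for cycle in rx.simple_cycles(graph):
    #             # Map node indices back to entity IDs for reporting
    #             cycles.append([graph[idx].id for idx in cycle])
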
    def _detect_mutually_exclusive(self, graph: rx.PyDiGraph) -> List[Tuple[str, str]]:
        """Detect mutually exclusive relations."""
        # This would check for conflicting relation types
        # For now, return empty list
        return []

    def _detect_unsupported_relations(
        self,
        graph: rx.PyDiGraph,
        primitives: List[PrimitiveEntity]
    ) -> List[str]:
        """Detect relations that reference non-existent primitives."""
        unsupported = []
        primitive_ids = {p.id for p in primitives}

        # Check all edges
        for edge in graph.edge_list():
            source_idx, target_idx = edge
            source_node = graph[source_idx]
            target_node = graph[target_idx]

            if isinstance(source_node, PrimitiveEntity) and isinstance(target_node, PrimitiveEntity):
                if source_node.id not in primitive_ids or target_node.id not in primitive_ids:
                    unsupported.append(f"relation_{source_node.id}_{target_node.id}")

        return unsupported

    def _validate_graph(self, graph: rx.PyDiGraph) -> Dict[str, Any]:
        """Comprehensive graph validation."""
        validation_result = {
            "is_valid": True,
            "cycles": [],
            "mutually_exclusive": [],
            "unsupported_relations": [],
            "warnings": []
        }

        # Check for cycles
        cycles = self._detect_cycles(graph)
        if cycles:
            validation_result["is_valid"] = False
            validation_result["cycles"] = cycles

        # Check for mutually exclusive relations
        mutually_exclusive = self._detect_mutually_exclusive(graph)
        if mutually_exclusive:
            validation_result["mutually_exclusive"] = mutually_exclusive

        return validation_result

    # ==================== Deterministic Math Layer ====================

    def _compute_distance(
        self,
        point1: Tuple[float, float],
        point2: Tuple[float, float]
    ) -> float:
        """Compute Euclidean distance between two points."""
        return np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)

    def _compute_angle(self, line1: Line, line2: Line) -> float:
        """Compute angle between two lines in degrees."""
        # Get direction vectors
        dx1 = line1.end_point[0] - line1.start_point[0]
        dy1 = line1.end_point[1] - line1.start_point[1]
        dx2 = line2.end_point[0] - line2.start_point[0]
        dy2 = line2.end_point[1] - line2.start_point[1]

        # Compute angle using dot product
        dot = dx1 * dx2 + dy1 * dy2
        mag1 = np.sqrt(dx1**2 + dy1**2)
        mag2 = np.sqrt(dx2**2 + dy2**2)

        if mag1 < 1e-10 or mag2 < 1e-10:
            return 0.0

        cos_angle = dot / (mag1 * mag2)
        cos_angle = np.clip(cos_angle, -1.0, 1.0)
        angle_rad = np.arccos(cos_angle)
        angle_deg = np.degrees(angle_rad)

        return angle_deg

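    # Worked example (illustrative): for a horizontal direction (4, 0) and a
    # vertical direction (0, 3), dot = 4*0 + 0*3 = 0, so cos_angle = 0 and
    # arccos(0) = pi/2, i.e. the method returns 90 degrees. Because the raw dot
    # product is used, results fall in [0, 180] and anti-parallel lines report
    # 180 rather than 0.
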
    def _solve_constraints(self, constraints: List[Any]) -> Dict[str, float]:
        """Solve constraints using z3-solver."""
        if not Z3_AVAILABLE:
            logger.warning("z3-solver not available, constraint solving disabled")
            return {}

        # This would use z3 to solve constraints
        # Simplified implementation
        return {}

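    # Minimal sketch of what the z3 path could look like, assuming constraints
    # arrive as z3 boolean expressions over Real variables; this is my assumption
    # about the intended shape, not the module's implementation:
    #
    #     from z3 import Solver, sat
    #
    #     solver = Solver()
    #     for constraint in constraints:
    #         solver.add(constraint)
    #     if solver.check() == sat:
    #         model = solver.model()
    #         # Convert each rational assignment to a Python float
    #         return {str(d): float(model[d].as_fraction()) for d in model.decls()}
    #     return {}
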
    def _verify_geometry(self, primitives: List[PrimitiveEntity]) -> bool:
        """Check geometric consistency."""
        # Basic validation: check that coordinates are valid
        for prim in primitives:
            if not self._validate_primitive(prim, np.zeros((100, 100), dtype=np.uint8)):
                return False
        return True

    # ==================== CR-CA Integration ====================

    def _create_claim(
        self,
        annotation: SemanticLabel,
        dependencies: List[str]
    ) -> Claim:
        """Create a claim for failure-aware reasoning."""
        claim = Claim(
            annotation=annotation,
            dependencies=dependencies,
            robustness_score=1.0
        )
        self._claims[claim.claim_id] = claim
        return claim

    def _trace_dependencies(self, claim_id: str) -> List[str]:
        """Trace dependency chain for a claim."""
        if claim_id not in self._claims:
            return []

        visited = set()
        dependencies = []

        def _collect_deps(cid: str):
            if cid in visited or cid not in self._claims:
                return
            visited.add(cid)
            claim = self._claims[cid]
            for dep_id in claim.dependencies:
                if dep_id not in visited:
                    _collect_deps(dep_id)
                    dependencies.append(dep_id)

        _collect_deps(claim_id)
        return dependencies

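    # Illustrative trace (my example, not from the module): with claims
    # A -> [B], B -> [C], C -> [], calling _trace_dependencies("A") recurses
    # depth-first and appends after recursion, yielding ["C", "B"]: every claim
    # A transitively rests on, deepest first. This is the list the
    # counterfactual and robustness helpers below walk.
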
    def _counterfactual_explanation(
        self,
        claim_id: str,
        removed_dependency: str
    ) -> str:
        """Generate counterfactual explanation."""
        if claim_id not in self._claims:
            return "Claim not found"

        claim = self._claims[claim_id]
        if removed_dependency not in claim.dependencies:
            return "Dependency not found in claim"

        return f"If dependency {removed_dependency} were removed, claim {claim_id} would be invalidated."

    def _robustness_analysis(self, claims: List[Claim]) -> Dict[str, float]:
        """Analyze robustness of claims."""
        robustness_scores = {}

        for claim in claims:
            # Simple robustness: inverse of uncertainty and dependency count
            uncertainty_penalty = claim.annotation.uncertainty
            dependency_penalty = len(claim.dependencies) * 0.1

            robustness = max(0.0, 1.0 - uncertainty_penalty - dependency_penalty)
            robustness_scores[claim.claim_id] = robustness
            claim.robustness_score = robustness

        return robustness_scores

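    # Worked example (illustrative): a claim whose label carries uncertainty 0.3
    # and which rests on two dependencies scores
    # max(0.0, 1.0 - 0.3 - 2 * 0.1) = 0.5, so heavily-derived or uncertain
    # claims decay toward 0.0 while standalone confident claims stay near 1.0.
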
    # ==================== Temporal Coherence ====================

    def _track_entity(
        self,
        entity: PrimitiveEntity,
        frame_id: int
    ) -> str:
        """Track entity across frames using Kalman filter."""
        if not self.enable_temporal_tracking:
            return entity.id

        # Get or create Kalman filter for entity
        if entity.id not in self._entity_trackers:
            if FILTERPY_AVAILABLE:
                kf = KalmanFilter(dim_x=4, dim_z=2)  # x, y, vx, vy
                # Constant-velocity model (dt = 1 frame). filterpy leaves F as
                # the identity and H as zeros, under which update() would be a
                # no-op, so both must be set explicitly.
                kf.F = np.array([[1.0, 0.0, 1.0, 0.0],
                                 [0.0, 1.0, 0.0, 1.0],
                                 [0.0, 0.0, 1.0, 0.0],
                                 [0.0, 0.0, 0.0, 1.0]])
                kf.H = np.array([[1.0, 0.0, 0.0, 0.0],
                                 [0.0, 1.0, 0.0, 0.0]])
                # Initialize with entity position
                if entity.pixel_coords:
                    center = self._get_entity_center(entity)
                    kf.x = np.array([center[0], center[1], 0.0, 0.0])
                kf.P *= 1000.0  # Initial uncertainty
                kf.R = np.eye(2) * 5  # Measurement noise
                kf.Q = np.eye(4) * 0.1  # Process noise
                self._entity_trackers[entity.id] = kf
            else:
                # Fallback: simple ID mapping
                self._entity_trackers[entity.id] = {"center": self._get_entity_center(entity)}

        # Update tracker
        if FILTERPY_AVAILABLE and entity.id in self._entity_trackers:
            kf = self._entity_trackers[entity.id]
            if isinstance(kf, KalmanFilter) and entity.pixel_coords:
                center = self._get_entity_center(entity)
                kf.predict()
                kf.update(np.array([center[0], center[1]]))

        return entity.id

    def _get_entity_center(self, entity: PrimitiveEntity) -> Tuple[float, float]:
        """Get center point of entity."""
        if not entity.pixel_coords:
            return (0.0, 0.0)

        if entity.primitive_type == "circle" and len(entity.pixel_coords) >= 1:
            return (float(entity.pixel_coords[0][0]), float(entity.pixel_coords[0][1]))
        elif entity.primitive_type == "line" and len(entity.pixel_coords) >= 2:
            x1, y1 = entity.pixel_coords[0]
            x2, y2 = entity.pixel_coords[1]
            return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
        else:
            # Use centroid of all points
            xs = [p[0] for p in entity.pixel_coords]
            ys = [p[1] for p in entity.pixel_coords]
            return (float(np.mean(xs)), float(np.mean(ys)))

    def _predict_next_position(self, entity_id: str) -> Tuple[float, float]:
        """Predict next position using Kalman filter."""
        if not self.enable_temporal_tracking or entity_id not in self._entity_trackers:
            return (0.0, 0.0)

        tracker = self._entity_trackers[entity_id]
        if FILTERPY_AVAILABLE and isinstance(tracker, KalmanFilter):
            tracker.predict()
            return (float(tracker.x[0]), float(tracker.x[1]))
        elif isinstance(tracker, dict) and "center" in tracker:
            return tracker["center"]

        return (0.0, 0.0)

    def _check_continuity(
        self,
        entity_id: str,
        new_annotation: SemanticLabel
    ) -> bool:
        """Verify consistency of annotation across frames."""
        # Check if label changed significantly
        if entity_id in self._claims:
            old_claim = self._claims[entity_id]
            old_label = old_claim.annotation.label
            new_label = new_annotation.label

            # Simple check: if labels are very different, flag as inconsistent
            if old_label.lower() != new_label.lower():
                return False

        return True

    def _detect_instability(self, entity_id: str) -> Dict[str, Any]:
        """Detect temporal instability (angle jumps, position shifts)."""
        instability = {
            "detected": False,
            "reason": None,
            "metrics": {}
        }

        if not self.enable_temporal_tracking or entity_id not in self._entity_trackers:
            return instability

        # Check position variance
        if len(self._frame_history) >= 2:
            # Compare positions across frames
            positions = []
            for graph in self._frame_history[-5:]:  # Last 5 frames
                entity = graph.get_entity_by_id(entity_id)
                if entity:
                    center = self._get_entity_center(entity)
                    positions.append(center)

            if len(positions) >= 2:
                # Compute position variance
                positions_array = np.array(positions)
                variance = np.var(positions_array, axis=0)
                max_variance = np.max(variance)

                if max_variance > 100.0:  # Threshold
                    instability["detected"] = True
                    instability["reason"] = f"High position variance: {max_variance:.2f}"
                    instability["metrics"]["position_variance"] = float(max_variance)

        return instability

    # ==================== Output Generation ====================

    def _generate_overlay(
        self,
        image: np.ndarray,
        annotations: AnnotationGraph
    ) -> np.ndarray:
        """Generate overlay image with annotations drawn."""
        overlay = image.copy()

        # Draw entities
        for entity in annotations.entities:
            if entity.primitive_type == "line" and len(entity.pixel_coords) >= 2:
                cv2.line(
                    overlay,
                    entity.pixel_coords[0],
                    entity.pixel_coords[1],
                    (0, 255, 0),
                    2
                )
            elif entity.primitive_type == "circle" and len(entity.pixel_coords) >= 1:
                center = entity.pixel_coords[0]
                radius = int(entity.metadata.get("radius", 10))
                cv2.circle(overlay, center, radius, (255, 0, 0), 2)
            elif entity.primitive_type == "contour":
                points = np.array(entity.pixel_coords, dtype=np.int32)
                cv2.polylines(overlay, [points], True, (0, 0, 255), 2)

        # Draw labels
        for label in annotations.labels:
            entity = annotations.get_entity_by_id(label.entity_id)
            if entity and entity.pixel_coords:
                center = self._get_entity_center(entity)
                center_int = (int(center[0]), int(center[1]))
                cv2.putText(
                    overlay,
                    label.label,
                    center_int,
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (255, 255, 255),
                    1
                )

        return overlay

    def _generate_formal_report(
        self,
        annotations: AnnotationGraph,
        contradictions: List[Contradiction]
    ) -> str:
        """Generate structured formal report."""
        report_lines = []

        report_lines.append("=" * 80)
        report_lines.append("FORMAL ANNOTATION REPORT")
        report_lines.append("=" * 80)
        report_lines.append("")

        # KNOWN section
        report_lines.append("KNOWN:")
        report_lines.append("-" * 80)
        for entity in annotations.entities:
            coords_str = ", ".join([f"({x},{y})" for x, y in entity.pixel_coords[:5]])
            if len(entity.pixel_coords) > 5:
                coords_str += f" ... ({len(entity.pixel_coords)} total points)"
            report_lines.append(f"  {entity.primitive_type.upper()} {entity.id}: {coords_str}")

        # Measured quantities
        report_lines.append("")
        report_lines.append("  Measured quantities:")
        for i, entity1 in enumerate(annotations.entities):
            for entity2 in annotations.entities[i+1:]:
                if len(entity1.pixel_coords) >= 1 and len(entity2.pixel_coords) >= 1:
                    center1 = self._get_entity_center(entity1)
                    center2 = self._get_entity_center(entity2)
                    distance = self._compute_distance(center1, center2)
                    report_lines.append(f"    Distance {entity1.id} <-> {entity2.id}: {distance:.2f} pixels")

        # ASSUMED section
        report_lines.append("")
        report_lines.append("ASSUMED:")
        report_lines.append("-" * 80)
        tentative_labels = [l for l in annotations.labels if l.tentative]
        tentative_relations = [r for r in annotations.relations if r.tentative]

        for label in tentative_labels:
            report_lines.append(f"  Label '{label.label}' for entity {label.entity_id} (uncertainty: {label.uncertainty:.2f})")

        for relation in tentative_relations:
            report_lines.append(f"  Relation {relation.relation_type} between {relation.source_id} and {relation.target_id}")

        # AMBIGUOUS section
        report_lines.append("")
        report_lines.append("AMBIGUOUS:")
        report_lines.append("-" * 80)
        uncertain_labels = [l for l in annotations.labels if l.uncertainty > 0.5]
        for label in uncertain_labels:
            report_lines.append(f"  {label.entity_id}: {label.label} (uncertainty: {label.uncertainty:.2f})")

        # CONTRADICTS section
        if contradictions:
            report_lines.append("")
            report_lines.append("CONTRADICTS:")
            report_lines.append("-" * 80)
            for contradiction in contradictions:
                report_lines.append(f"  {contradiction.contradiction_type}: {contradiction.description}")
                report_lines.append(f"    Entities involved: {', '.join(contradiction.entities_involved)}")

        report_lines.append("")
        report_lines.append("=" * 80)

        return "\n".join(report_lines)

    def _generate_json(self, annotations: AnnotationGraph) -> Dict[str, Any]:
        """Generate JSON representation of annotations."""
        return {
            "entities": [
                {
                    "id": e.id,
                    "type": e.primitive_type,
                    "pixel_coords": e.pixel_coords,
                    "metadata": e.metadata
                }
                for e in annotations.entities
            ],
            "labels": [
                {
                    "entity_id": l.entity_id,
                    "label": l.label,
                    "uncertainty": l.uncertainty,
                    "tentative": l.tentative,
                    "reasoning": l.reasoning
                }
                for l in annotations.labels
            ],
            "relations": [
                {
                    "source_id": r.source_id,
                    "target_id": r.target_id,
                    "relation_type": r.relation_type,
                    "uncertainty": r.uncertainty,
                    "tentative": r.tentative
                }
                for r in annotations.relations
            ],
            "contradictions": [
                {
                    "type": c.contradiction_type,
                    "entities_involved": c.entities_involved,
                    "description": c.description,
                    "severity": c.severity
                }
                for c in annotations.contradictions
            ]
        }

    # ==================== Core Annotation Logic ====================

    def _annotate_core(
        self,
        image: np.ndarray,
        frame_id: Optional[int] = None,
        params: Optional[Dict[str, Any]] = None
    ) -> AnnotationResult:
        """
        Core annotation logic (internal method).

        Args:
            image: Input image as numpy array (BGR format from OpenCV)
            frame_id: Optional frame ID for temporal tracking
            params: Optional parameters dict for tuning

        Returns:
            AnnotationResult with overlay, report, and JSON
        """
        start_time = time.time()

        try:
            # Get or generate cache key
            if params is None:
                params = self._auto_tune_params(image)
            cache_key = self._get_cache_key(image, params) if self.config.cache_enabled else None

            # Try to load cached primitives
            cached_primitives = None
            if cache_key:
                cached_primitives = self._get_cached_primitives(cache_key)

            # 1. Preprocess image
            processed_image = self._preprocess_image(image, params)

            # 2. Extract primitives (use cache if available)
            if cached_primitives:
                entities = cached_primitives
                logger.debug(f"Using {len(entities)} cached primitives")
            else:
                lines = self._extract_lines(processed_image, params)
                circles = self._extract_circles(processed_image, params)
                contours = self._extract_contours(processed_image, params)

                # Convert to PrimitiveEntity objects
                entities = []
                for line in lines:
                    entity = PrimitiveEntity(
                        id=line.entity_id or str(len(entities)),
                        pixel_coords=[line.start_point, line.end_point],
                        primitive_type="line"
                    )
                    entities.append(entity)
                    line.entity_id = entity.id

                for circle in circles:
                    entity = PrimitiveEntity(
                        id=circle.entity_id or str(len(entities)),
                        pixel_coords=[circle.center],
                        primitive_type="circle",
                        metadata={"radius": circle.radius}
                    )
                    entities.append(entity)
                    circle.entity_id = entity.id

                for contour in contours:
                    entity = PrimitiveEntity(
                        id=contour.entity_id or str(len(entities)),
                        pixel_coords=contour.points,
                        primitive_type="contour"
                    )
                    entities.append(entity)
                    contour.entity_id = entity.id

                # Compute intersections
                intersections = self._compute_intersections(entities)
                for intersection in intersections:
                    entity = PrimitiveEntity(
                        id=intersection.entity_id or str(len(entities)),
                        pixel_coords=[intersection.point],
                        primitive_type="intersection",
                        metadata={"primitive_ids": intersection.primitive_ids}
                    )
                    entities.append(entity)
                    intersection.entity_id = entity.id

                # Cache primitives
                if cache_key:
                    self._cache_primitives(cache_key, entities)

            # 3. Label primitives (GPT-4o-mini, restricted) - never cached
            labels = self._label_primitives(entities, image)

            # 4. Compile graph
            annotation_graph = AnnotationGraph(
                entities=entities,
                labels=labels,
                relations=[],
                contradictions=[],
                metadata={"image_shape": image.shape, "image_type": params.get("detected_type", "unknown")}
            )

            graph = self._compile_graph(annotation_graph)

            # 5. Detect contradictions
            validation = self._validate_graph(graph)
            contradictions = []
            if validation.get("cycles"):
                contradictions.append(Contradiction(
                    contradiction_type="cycle",
                    entities_involved=[],
                    description="Cyclic dependencies detected in graph",
                    severity="high"
                ))
            if validation.get("unsupported_relations"):
                contradictions.append(Contradiction(
                    contradiction_type="unsupported",
                    entities_involved=[],
                    description=f"Unsupported relations: {validation['unsupported_relations']}",
                    severity="medium"
                ))

            annotation_graph.contradictions = contradictions

            # 6. Compute measurements (deterministic math)
            # Already done in _generate_formal_report

            # 7. Track temporally (Kalman filters)
            instability_detected = False
            instability_reason = None
            if self.enable_temporal_tracking and frame_id is not None:
                for entity in entities:
                    self._track_entity(entity, frame_id)
                    instability = self._detect_instability(entity.id)
                    if instability["detected"]:
                        instability_detected = True
                        instability_reason = instability["reason"]

                self._frame_history.append(annotation_graph)
                if len(self._frame_history) > 10:  # Keep last 10 frames
                    self._frame_history.pop(0)

            # 8. Generate outputs
            overlay_image = self._generate_overlay(image, annotation_graph)
            formal_report = self._generate_formal_report(annotation_graph, contradictions)
            json_output = self._generate_json(annotation_graph)

            # Serialize overlay image
            _, buffer = cv2.imencode('.jpg', overlay_image)
            overlay_bytes = buffer.tobytes()

            processing_time = time.time() - start_time

            result = AnnotationResult(
                annotation_graph=annotation_graph,
                overlay_image=overlay_bytes,
                formal_report=formal_report,
                json_output=json_output,
                instability_detected=instability_detected,
                instability_reason=instability_reason,
                processing_time=processing_time,
                frame_id=frame_id
            )

            logger.info(f"Annotation complete: {len(entities)} entities, {len(labels)} labels, {processing_time:.2f}s")
            return result

        except Exception as e:
            logger.error(f"Error in annotation: {e}", exc_info=True)
            # Return empty result on error
            return AnnotationResult(
                annotation_graph=AnnotationGraph(),
                overlay_image=None,
                formal_report=f"Error: {str(e)}",
                json_output={"error": str(e)},
                processing_time=time.time() - start_time
            )

    # ==================== Main Public API ====================

    def annotate(
        self,
        input: Union[str, np.ndarray, Image.Image, List, Path],
        frame_id: Optional[int] = None,
        output: Optional[str] = None
    ) -> Union[AnnotationResult, np.ndarray, Dict[str, Any], str, List]:
        """
        Annotate image(s) with full automation.

        Automatically handles:
        - Input type detection (file path, URL, numpy array, PIL Image, batch)
        - Image type detection (circuit, architectural, mathematical, etc.)
        - Parameter tuning
        - Retry logic
        - Caching
        - Batch processing with parallelization

        Args:
            input: Image input - can be:
                - File path (str or Path): "image.png"
                - URL (str): "https://example.com/image.png"
                - NumPy array: np.ndarray
                - PIL Image: PIL.Image.Image
                - List: [path1, path2, ...] for batch processing
            frame_id: Optional frame ID for temporal tracking
            output: Output format - "overlay" (default), "json", "report", "all"
                If None, uses config.output_format

        Returns:
            - If output="overlay" (default): numpy array (overlay image)
            - If output="json": dict (JSON data)
            - If output="report": str (formal report)
            - If output="all": AnnotationResult object
            - If input is list: List of above (based on output format)
        """
        # Determine output format
        # For backward compatibility: if input is np.ndarray and output is None, return AnnotationResult
        is_old_api = isinstance(input, np.ndarray) and output is None

        if is_old_api:
            output_format = "all"  # Return full AnnotationResult for backward compatibility
        else:
            output_format = output or self.config.output_format

        # Auto-detect if input is batch
        input_type = self._detect_input_type(input)
        if input_type == 'batch':
            return self._annotate_batch_auto(input, frame_id, output_format)

        # Single image processing
        # Auto-load input
        try:
            image = self._auto_load_input(input)
        except Exception as e:
            logger.error(f"Error loading input: {e}")
            if output_format == "overlay":
                return np.zeros((100, 100, 3), dtype=np.uint8)  # Empty image
            elif output_format == "json":
                return {"error": str(e)}
            elif output_format == "report":
                return f"Error: {str(e)}"
            else:
                return AnnotationResult(
                    annotation_graph=AnnotationGraph(),
                    overlay_image=None,
                    formal_report=f"Error: {str(e)}",
                    json_output={"error": str(e)},
                    processing_time=0.0
                )

        # Auto-detect image type and tune parameters
        image_type = None
        if self.config.auto_detect_type:
            image_type = self._detect_image_type(image)
            logger.info(f"Detected image type: {image_type}")

        # Auto-tune parameters
        params = self._auto_tune_params(image, image_type)
        if image_type:
            params["detected_type"] = image_type

        # Perform annotation with retry logic
        result = self._annotate_with_retry(image, frame_id, params)

        # Return in requested format
        return self._format_output(result, output_format)

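    # Quick usage sketch (illustrative; assumes the engine class is importable as
    # shown in the query() docstring further below):
    #
    #     from image_annotation import ImageAnnotationEngine
    #
    #     engine = ImageAnnotationEngine()
    #     overlay = engine.annotate("circuit.png")              # numpy overlay image
    #     data = engine.annotate("circuit.png", output="json")  # dict of entities/labels
    #     full = engine.annotate("circuit.png", output="all")   # AnnotationResult
    #
    # Passing a list, e.g. engine.annotate(["a.png", "b.png"]), routes through
    # the batch path automatically.
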
    def _format_output(self, result: AnnotationResult, output_format: str) -> Union[AnnotationResult, np.ndarray, Dict[str, Any], str]:
        """Format annotation result according to output format."""
        if output_format == "overlay":
            # Return numpy array
            if result.overlay_image:
                return cv2.imdecode(
                    np.frombuffer(result.overlay_image, np.uint8),
                    cv2.IMREAD_COLOR
                )
            else:
                return np.zeros((100, 100, 3), dtype=np.uint8)

        elif output_format == "json":
            return result.json_output

        elif output_format == "report":
            return result.formal_report

        else:  # "all" or default
            return result

    def _annotate_batch_auto(
        self,
        inputs: List,
        frame_id: Optional[int] = None,
        output_format: str = "overlay"
    ) -> List:
        """
        Automatically batch process multiple inputs.

        Features:
        - Auto-detects input types
        - Parallel processing
        - Progress reporting (tqdm)
        - Error recovery (continues on individual failures)
        - Smart caching
        """
        # Determine number of workers
        if self.config.parallel_workers is None:
            import os
            workers = min(len(inputs), os.cpu_count() or 4)
        else:
            workers = self.config.parallel_workers

        results = []

        # Use progress bar if available
        if TQDM_AVAILABLE and self.config.show_progress:
            iterator = tqdm(inputs, desc="Annotating images")
        else:
            iterator = inputs

        # Process in parallel if multiple workers
        if workers > 1 and len(inputs) > 1:
            with ThreadPoolExecutor(max_workers=workers) as executor:
                # Submit all tasks
                future_to_input = {
                    executor.submit(self._annotate_single_with_error_handling, inp, i, output_format): (i, inp)
                    for i, inp in enumerate(inputs)
                }

                # Collect results in order
                results_dict = {}
                for future in as_completed(future_to_input):
                    idx, inp = future_to_input[future]
                    try:
                        result = future.result()
                        results_dict[idx] = result
                    except Exception as e:
                        logger.error(f"Error processing input {idx}: {e}")
                        # Return empty result based on output format
                        if output_format == "overlay":
                            results_dict[idx] = np.zeros((100, 100, 3), dtype=np.uint8)
                        elif output_format == "json":
                            results_dict[idx] = {"error": str(e)}
                        elif output_format == "report":
                            results_dict[idx] = f"Error: {str(e)}"
                        else:
                            results_dict[idx] = AnnotationResult(
                                annotation_graph=AnnotationGraph(),
                                overlay_image=None,
                                formal_report=f"Error: {str(e)}",
                                json_output={"error": str(e)},
                                processing_time=0.0
                            )

            # Sort by index to maintain order
            results = [results_dict[i] for i in sorted(results_dict.keys())]
        else:
            # Sequential processing
            for i, inp in enumerate(iterator):
                result = self._annotate_single_with_error_handling(inp, i, output_format)
                results.append(result)

        return results

    def _annotate_single_with_error_handling(
        self,
        input: Any,
        index: int,
        output_format: str
    ) -> Union[AnnotationResult, np.ndarray, Dict[str, Any], str]:
        """Annotate single input with error handling."""
        try:
            # Auto-load input
            image = self._auto_load_input(input)

            # Auto-detect type and tune
            image_type = None
            if self.config.auto_detect_type:
                image_type = self._detect_image_type(image)

            params = self._auto_tune_params(image, image_type)
            if image_type:
                params["detected_type"] = image_type

            # Annotate with retry
            result = self._annotate_with_retry(image, frame_id=index, params=params)

            return self._format_output(result, output_format)

        except Exception as e:
            logger.error(f"Error processing input {index}: {e}")
            # Return appropriate error result
            if output_format == "overlay":
                return np.zeros((100, 100, 3), dtype=np.uint8)
            elif output_format == "json":
                return {"error": str(e), "index": index}
            elif output_format == "report":
                return f"Error processing input {index}: {str(e)}"
            else:
                return AnnotationResult(
                    annotation_graph=AnnotationGraph(),
                    overlay_image=None,
                    formal_report=f"Error: {str(e)}",
                    json_output={"error": str(e), "index": index},
                    processing_time=0.0
                )

    def annotate_batch(
        self,
        images: List[Union[str, np.ndarray, Image.Image, Path]],
        frame_id: Optional[int] = None,
        output: Optional[str] = None
    ) -> List:
        """
        Annotate a batch of images (explicit batch method).

        Note: The main annotate() method auto-detects batch processing,
        so this method is mainly for explicit batch calls.

        Args:
            images: List of input images (any supported type)
            frame_id: Optional starting frame ID
            output: Output format (overrides config)

        Returns:
            List of results (format depends on output parameter)
        """
        return self._annotate_batch_auto(images, frame_id, output or self.config.output_format)

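    # Usage sketch (illustrative): explicit batch processing with JSON output,
    # assuming an `engine` instance of this class:
    #
    #     reports = engine.annotate_batch(
    #         ["frame_000.png", "frame_001.png", "frame_002.png"],
    #         output="json",
    #     )
    #     for i, report in enumerate(reports):
    #         print(i, report.get("error") or len(report["entities"]))
    #
    # Failed items come back as {"error": ..., "index": ...} rather than
    # raising, so one bad frame does not abort the batch.
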
def reset_temporal_state(self) -> None:
|
|
2164
|
+
"""Reset temporal tracking state."""
|
|
2165
|
+
self._entity_trackers.clear()
|
|
2166
|
+
self._frame_history.clear()
|
|
2167
|
+
self._entity_id_map.clear()
|
|
2168
|
+
logger.info("Temporal state reset")
|
|
2169
|
+
|
|
2170
|
+
# ==================== Query/Task-Based Interface ====================
|
|
2171
|
+
|
|
2172
|
+
def query(
|
|
2173
|
+
self,
|
|
2174
|
+
input: Union[str, np.ndarray, Image.Image, Path],
|
|
2175
|
+
query: str,
|
|
2176
|
+
frame_id: Optional[int] = None
|
|
2177
|
+
) -> Dict[str, Any]:
|
|
2178
|
+
"""
|
|
2179
|
+
Answer a specific query about an image using natural language.
|
|
2180
|
+
|
|
2181
|
+
This method performs annotation first, then uses the CR-CA agent to interpret
|
|
2182
|
+
the query and analyze the annotation graph to provide a structured answer.
|
|
2183
|
+
|
|
2184
|
+
Examples:
|
|
2185
|
+
- "find the largest building in this image"
|
|
2186
|
+
- "measure the height of the tallest structure"
|
|
2187
|
+
- "identify all circles and calculate their total area"
|
|
2188
|
+
- "find the longest line and measure its length"
|
|
2189
|
+
- "count how many buildings are in this cityscape"
|
|
2190
|
+
- "find the largest building in this city and measure its dimensions"
|
|
2191
|
+
|
|
2192
|
+
Usage Example:
|
|
2193
|
+
```python
|
|
2194
|
+
from image_annotation import ImageAnnotationEngine
|
|
2195
|
+
|
|
2196
|
+
engine = ImageAnnotationEngine()
|
|
2197
|
+
|
|
2198
|
+
# Query: Find largest building and measure it
|
|
2199
|
+
result = engine.query(
|
|
2200
|
+
"cityscape.jpg",
|
|
2201
|
+
"find the largest building in this image and measure its dimensions"
|
|
2202
|
+
)
|
|
2203
|
+
|
|
2204
|
+
print(result["answer"])
|
|
2205
|
+
# Output: "Found 1 matching entities. The largest entity is abc123 with size 12500.00 pixels².
|
|
2206
|
+
# Measurements:
|
|
2207
|
+
# Entity abc123: area=12500.00px², width=150.00px, height=83.33px"
|
|
2208
|
+
|
|
2209
|
+
# Access entities
|
|
2210
|
+
for entity in result["entities"]:
|
|
2211
|
+
print(f"Entity {entity['id']}: {entity['label']}")
|
|
2212
|
+
|
|
2213
|
+
# Access measurements
|
|
2214
|
+
for entity_id, measurements in result["measurements"].items():
|
|
2215
|
+
print(f"{entity_id}: {measurements}")
|
|
2216
|
+
|
|
2217
|
+
# Save visualization
|
|
2218
|
+
if result["visualization"] is not None:
|
|
2219
|
+
cv2.imwrite("query_result.png", result["visualization"])
|
|
2220
|
+
```
|
|
2221
|
+
|
|
2222
|
+
Args:
|
|
2223
|
+
input: Image input (file path, URL, numpy array, PIL Image)
|
|
2224
|
+
query: Natural language query about the image
|
|
2225
|
+
frame_id: Optional frame ID for temporal tracking
|
|
2226
|
+
|
|
2227
|
+
Returns:
|
|
2228
|
+
Dictionary with:
|
|
2229
|
+
- "answer": Natural language answer to the query
|
|
2230
|
+
- "entities": List of relevant entities found (with id, type, label, pixel_coords, metadata)
|
|
2231
|
+
- "measurements": Dict of measurements performed (area, length, width, height, etc.)
|
|
2232
|
+
- "confidence": Confidence score (0.0-1.0)
|
|
2233
|
+
- "reasoning": Step-by-step reasoning from CR-CA agent
|
|
2234
|
+
- "visualization": Optional overlay image (numpy array) highlighting relevant entities
|
|
2235
|
+
- "annotation_graph": Full AnnotationGraph object for advanced analysis
|
|
2236
|
+
"""
|
|
2237
|
+
logger.info(f"Processing query: {query}")
|
|
2238
|
+
|
|
2239
|
+
# Step 1: Annotate the image first
|
|
2240
|
+
annotation_result = self.annotate(input, frame_id=frame_id, output="all")
|
|
2241
|
+
|
|
2242
|
+
if not isinstance(annotation_result, AnnotationResult):
|
|
2243
|
+
# Fallback if annotate returned something else
|
|
2244
|
+
logger.warning("Annotation returned unexpected format, re-annotating...")
|
|
2245
|
+
image = self._auto_load_input(input)
|
|
2246
|
+
annotation_result = self._annotate_with_retry(image, frame_id, self._auto_tune_params(image))
|
|
2247
|
+
|
|
2248
|
+
# Step 2: Use CR-CA agent to interpret query and analyze annotation graph
|
|
2249
|
+
query_result = self._process_query(query, annotation_result)
|
|
2250
|
+
|
|
2251
|
+
# Step 3: Extract relevant entities based on query
|
|
2252
|
+
relevant_entities = self._extract_relevant_entities(query, annotation_result.annotation_graph)
|
|
2253
|
+
|
|
2254
|
+
# Step 4: Perform measurements if requested
|
|
2255
|
+
measurements = self._perform_query_measurements(query, relevant_entities, annotation_result.annotation_graph)
|
|
2256
|
+
|
|
2257
|
+
# Step 5: Generate visualization highlighting relevant entities
|
|
2258
|
+
visualization = self._generate_query_visualization(
|
|
2259
|
+
self._auto_load_input(input) if not isinstance(input, np.ndarray) else input,
|
|
2260
|
+
annotation_result.annotation_graph,
|
|
2261
|
+
relevant_entities
|
|
2262
|
+
)
|
|
2263
|
+
|
|
2264
|
+
# Step 6: Compile final answer
|
|
2265
|
+
answer = self._generate_query_answer(query, query_result, relevant_entities, measurements)
|
|
2266
|
+
|
|
2267
|
+
return {
|
|
2268
|
+
"answer": answer,
|
|
2269
|
+
"entities": [
|
|
2270
|
+
{
|
|
2271
|
+
"id": e.id,
|
|
2272
|
+
"type": e.primitive_type,
|
|
2273
|
+
"label": self._get_entity_label(e.id, annotation_result.annotation_graph),
|
|
2274
|
+
"pixel_coords": e.pixel_coords,
|
|
2275
|
+
"metadata": e.metadata
|
|
2276
|
+
}
|
|
2277
|
+
for e in relevant_entities
|
|
2278
|
+
],
|
|
2279
|
+
"measurements": measurements,
|
|
2280
|
+
"confidence": query_result.get("confidence", 0.5),
|
|
2281
|
+
"reasoning": query_result.get("reasoning", ""),
|
|
2282
|
+
"visualization": visualization,
|
|
2283
|
+
"annotation_graph": annotation_result.annotation_graph
|
|
2284
|
+
}
|
|
2285
|
+
|
|
2286
|
+
def _process_query(self, query: str, annotation_result: AnnotationResult) -> Dict[str, Any]:
|
|
2287
|
+
"""Use CR-CA agent to interpret query and analyze annotation graph."""
|
|
2288
|
+
# Prepare context for the agent
|
|
2289
|
+
graph_summary = self._summarize_annotation_graph(annotation_result.annotation_graph)
|
|
2290
|
+
|
|
2291
|
+
query_prompt = f"""You are analyzing an annotated image to answer a specific query.
|
|
2292
|
+
|
|
2293
|
+
ANNOTATION SUMMARY:
|
|
2294
|
+
{graph_summary}
|
|
2295
|
+
|
|
2296
|
+
USER QUERY: {query}
|
|
2297
|
+
|
|
2298
|
+
Your task is to:
|
|
2299
|
+
1. Understand what the user is asking for
|
|
2300
|
+
2. Identify which entities in the annotation graph are relevant
|
|
2301
|
+
3. Determine what measurements or analysis are needed
|
|
2302
|
+
4. Provide reasoning for your answer
|
|
2303
|
+
|
|
2304
|
+
Consider:
|
|
2305
|
+
- Entity types (lines, circles, contours, intersections)
|
|
2306
|
+
- Semantic labels (what each entity represents)
|
|
2307
|
+
- Spatial relationships between entities
|
|
2308
|
+
- Size/scale comparisons
|
|
2309
|
+
- Counting operations
|
|
2310
|
+
- Measurement requirements
|
|
2311
|
+
|
|
2312
|
+
Provide your analysis in a structured format:
|
|
2313
|
+
- Relevant entity IDs
|
|
2314
|
+
- Required measurements
|
|
2315
|
+
- Reasoning steps
|
|
2316
|
+
- Confidence level (0.0-1.0)
|
|
2317
|
+
"""
|
|
2318
|
+
|
|
2319
|
+
try:
|
|
2320
|
+
# Use CR-CA agent to process query
|
|
2321
|
+
response = self._crca_agent.run(task=query_prompt)
|
|
2322
|
+
|
|
2323
|
+
# Parse response
|
|
2324
|
+
if isinstance(response, dict):
|
|
2325
|
+
return {
|
|
2326
|
+
"reasoning": response.get("response", str(response)),
|
|
2327
|
+
"confidence": 0.7,
|
|
2328
|
+
"raw_response": response
|
|
2329
|
+
}
|
|
2330
|
+
else:
|
|
2331
|
+
return {
|
|
2332
|
+
"reasoning": str(response),
|
|
2333
|
+
"confidence": 0.7,
|
|
2334
|
+
"raw_response": response
|
|
2335
|
+
}
|
|
2336
|
+
except Exception as e:
|
|
2337
|
+
logger.error(f"Error processing query with CR-CA agent: {e}")
|
|
2338
|
+
return {
|
|
2339
|
+
"reasoning": f"Query processing encountered an error: {str(e)}",
|
|
2340
|
+
"confidence": 0.3,
|
|
2341
|
+
"raw_response": None
|
|
2342
|
+
}
|
|
2343
|
+
|
|
2344
|
+
def _summarize_annotation_graph(self, graph: AnnotationGraph) -> str:
|
|
2345
|
+
"""Create a summary of the annotation graph for query processing."""
|
|
2346
|
+
summary_lines = []
|
|
2347
|
+
|
|
2348
|
+
summary_lines.append(f"Total entities: {len(graph.entities)}")
|
|
2349
|
+
summary_lines.append(f"Total labels: {len(graph.labels)}")
|
|
2350
|
+
summary_lines.append(f"Total relations: {len(graph.relations)}")
|
|
2351
|
+
summary_lines.append(f"Contradictions: {len(graph.contradictions)}")
|
|
2352
|
+
|
|
2353
|
+
# Entity type breakdown
|
|
2354
|
+
type_counts = {}
|
|
2355
|
+
for entity in graph.entities:
|
|
2356
|
+
type_counts[entity.primitive_type] = type_counts.get(entity.primitive_type, 0) + 1
|
|
2357
|
+
summary_lines.append(f"Entity types: {dict(type_counts)}")
|
|
2358
|
+
|
|
2359
|
+
# Sample labels
|
|
2360
|
+
if graph.labels:
|
|
2361
|
+
summary_lines.append("\nSample labels:")
|
|
2362
|
+
for label in graph.labels[:10]: # First 10 labels
|
|
2363
|
+
entity = graph.get_entity_by_id(label.entity_id)
|
|
2364
|
+
if entity:
|
|
2365
|
+
summary_lines.append(f" - {label.label} (entity {label.entity_id}, type: {entity.primitive_type}, uncertainty: {label.uncertainty:.2f})")
|
|
2366
|
+
|
|
2367
|
+
# Relations
|
|
2368
|
+
if graph.relations:
|
|
2369
|
+
summary_lines.append("\nRelations:")
|
|
2370
|
+
for relation in graph.relations[:10]: # First 10 relations
|
|
2371
|
+
summary_lines.append(f" - {relation.source_id} --[{relation.relation_type}]--> {relation.target_id}")
|
|
2372
|
+
|
|
2373
|
+
return "\n".join(summary_lines)
|
|
2374
|
+
|
|
```python
    def _extract_relevant_entities(
        self,
        query: str,
        graph: AnnotationGraph
    ) -> List[PrimitiveEntity]:
        """Extract entities relevant to the query."""
        query_lower = query.lower()
        relevant_entities = []

        # Keywords that suggest entity types
        if "building" in query_lower or "structure" in query_lower:
            # Look for entities labeled as buildings
            for label in graph.labels:
                if "building" in label.label.lower() or "structure" in label.label.lower():
                    entity = graph.get_entity_by_id(label.entity_id)
                    if entity:
                        relevant_entities.append(entity)

        if "circle" in query_lower or "round" in query_lower:
            for entity in graph.entities:
                if entity.primitive_type == "circle":
                    relevant_entities.append(entity)

        if "line" in query_lower or "edge" in query_lower:
            for entity in graph.entities:
                if entity.primitive_type == "line":
                    relevant_entities.append(entity)

        # Size-related queries: find the largest entity by area/length
        if "largest" in query_lower or "biggest" in query_lower or "tallest" in query_lower:
            if relevant_entities:
                # Filter the already-matched entities down to the largest
                largest = max(relevant_entities, key=self._get_entity_size)
                relevant_entities = [largest]
            elif graph.entities:
                # No keyword matches yet: take the largest entity overall
                largest = max(graph.entities, key=self._get_entity_size)
                relevant_entities = [largest]

        if "smallest" in query_lower or "tiny" in query_lower:
            if relevant_entities:
                smallest = min(relevant_entities, key=self._get_entity_size)
                relevant_entities = [smallest]
            elif graph.entities:
                smallest = min(graph.entities, key=self._get_entity_size)
                relevant_entities = [smallest]

        # Count queries ("count", "how many") need no extra filtering:
        # all matching entities collected above are returned for counting.

        # If no specific matches, return a bounded slice of all entities
        if not relevant_entities:
            relevant_entities = graph.entities[:20]  # Limit to first 20 for performance

        return relevant_entities
```
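The keyword routing above is a hard-coded heuristic, and the label-based branch can append the same entity once per matching label. A sketch of a table-driven alternative (not the shipped implementation) that is easier to extend and de-duplicates by entity id:

```python
# Hypothetical table-driven variant of the keyword routing above.
KEYWORD_TO_TYPE = {
    ("circle", "round"): "circle",
    ("line", "edge"): "line",
}

def match_entities_by_keyword(query_lower, entities):
    matched, seen = [], set()
    for keywords, primitive_type in KEYWORD_TO_TYPE.items():
        if any(k in query_lower for k in keywords):
            for e in entities:
                if e.primitive_type == primitive_type and e.id not in seen:
                    seen.add(e.id)
                    matched.append(e)
    return matched
```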
```python
    def _get_entity_size(self, entity: PrimitiveEntity) -> float:
        """Calculate size metric for entity (area, length, etc.)."""
        if entity.primitive_type == "circle":
            radius = entity.metadata.get("radius", 0)
            return np.pi * radius * radius  # Area
        elif entity.primitive_type == "line":
            if len(entity.pixel_coords) >= 2:
                p1, p2 = entity.pixel_coords[0], entity.pixel_coords[1]
                return np.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2)  # Length
        elif entity.primitive_type == "contour":
            if len(entity.pixel_coords) >= 3:
                # Calculate polygon area using shoelace formula
                points = np.array(entity.pixel_coords)
                x = points[:, 0]
                y = points[:, 1]
                return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))

        return 0.0
```
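The contour branch computes polygon area with the shoelace formula,

$$A = \tfrac{1}{2}\left|\sum_{i=0}^{n-1}\left(x_i\,y_{i-1} - y_i\,x_{i-1}\right)\right|$$

with indices taken modulo $n$, which is exactly what `np.roll` implements. A quick numeric sanity check on a unit square, whose area must come out as 1.0:

```python
import numpy as np

# Unit square vertices in order; the shoelace area should be exactly 1.0.
pts = np.array([(0, 0), (1, 0), (1, 1), (0, 1)], dtype=float)
x, y = pts[:, 0], pts[:, 1]
area = 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
assert np.isclose(area, 1.0)
```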
```python
    def _perform_query_measurements(
        self,
        query: str,
        entities: List[PrimitiveEntity],
        graph: AnnotationGraph
    ) -> Dict[str, Any]:
        """Perform measurements requested by the query."""
        query_lower = query.lower()
        measurements = {}

        # Check if a measurement is requested
        measurement_keywords = ("measure", "size", "height", "width", "area", "length")
        if any(keyword in query_lower for keyword in measurement_keywords):
            for entity in entities:
                entity_id = entity.id

                if entity.primitive_type == "circle":
                    radius = entity.metadata.get("radius", 0)
                    measurements[entity_id] = {
                        "type": "circle",
                        "radius": float(radius),
                        "diameter": float(radius * 2),
                        "area": float(np.pi * radius * radius),
                        "circumference": float(2 * np.pi * radius)
                    }

                elif entity.primitive_type == "line":
                    if len(entity.pixel_coords) >= 2:
                        p1, p2 = entity.pixel_coords[0], entity.pixel_coords[1]
                        length = np.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2)
                        measurements[entity_id] = {
                            "type": "line",
                            "length": float(length),
                            "start_point": p1,
                            "end_point": p2
                        }

                elif entity.primitive_type == "contour":
                    if len(entity.pixel_coords) >= 3:
                        points = np.array(entity.pixel_coords)
                        x = points[:, 0]
                        y = points[:, 1]
                        area = 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))

                        # Bounding box
                        x_min, x_max = int(np.min(x)), int(np.max(x))
                        y_min, y_max = int(np.min(y)), int(np.max(y))
                        width = x_max - x_min
                        height = y_max - y_min

                        measurements[entity_id] = {
                            "type": "contour",
                            "area": float(area),
                            "width": float(width),
                            "height": float(height),
                            "bounding_box": {
                                "x_min": x_min,
                                "y_min": y_min,
                                "x_max": x_max,
                                "y_max": y_max
                            }
                        }

        return measurements
```
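For a circle of radius 5 px, the resulting measurement entry would be (values rounded here for illustration):

```python
{
    "type": "circle",
    "radius": 5.0,
    "diameter": 10.0,
    "area": 78.54,          # pi * 5**2
    "circumference": 31.42  # 2 * pi * 5
}
```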
```python
    def _get_entity_label(self, entity_id: str, graph: AnnotationGraph) -> Optional[str]:
        """Get the first label for an entity, if any."""
        labels = graph.get_labels_for_entity(entity_id)
        if labels:
            return labels[0].label
        return None

    def _generate_query_visualization(
        self,
        image: np.ndarray,
        graph: AnnotationGraph,
        relevant_entities: List[PrimitiveEntity]
    ) -> Optional[np.ndarray]:
        """Generate visualization highlighting relevant entities."""
        try:
            # Create an overlay with the relevant entities highlighted
            overlay = image.copy()

            # Highlight relevant entities in green
            for entity in relevant_entities:
                if entity.primitive_type == "line" and len(entity.pixel_coords) >= 2:
                    cv2.line(overlay, entity.pixel_coords[0], entity.pixel_coords[1], (0, 255, 0), 3)
                elif entity.primitive_type == "circle" and len(entity.pixel_coords) >= 1:
                    center = entity.pixel_coords[0]
                    radius = int(entity.metadata.get("radius", 10))
                    cv2.circle(overlay, center, radius, (0, 255, 0), 3)
                elif entity.primitive_type == "contour" and len(entity.pixel_coords) >= 3:
                    points = np.array(entity.pixel_coords, dtype=np.int32)
                    cv2.polylines(overlay, [points], True, (0, 255, 0), 3)

            return overlay
        except Exception as e:
            logger.error(f"Error generating visualization: {e}")
            return None
```
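A minimal usage sketch for the overlay, assuming a hypothetical annotator instance exposing these methods and OpenCV installed; note that the radius and contour points are explicitly cast to integers (cv2 drawing calls require integer pixel coordinates), while line endpoints are passed through as stored:

```python
import cv2

# `annotator`, `image`, `graph`, and `relevant` are assumed to exist.
overlay = annotator._generate_query_visualization(image, graph, relevant)
if overlay is not None:
    cv2.imwrite("query_overlay.png", overlay)  # entities outlined in green (BGR (0, 255, 0))
```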
```python
    def _generate_query_answer(
        self,
        query: str,
        query_result: Dict[str, Any],
        entities: List[PrimitiveEntity],
        measurements: Dict[str, Any]
    ) -> str:
        """Generate natural language answer to the query."""
        answer_parts = []
        query_lower = query.lower()

        # Count queries
        if "count" in query_lower or "how many" in query_lower:
            answer_parts.append(f"Found {len(entities)} matching entities.")

        # Size queries
        if "largest" in query_lower or "biggest" in query_lower:
            if entities:
                entity = entities[0]
                size = self._get_entity_size(entity)
                answer_parts.append(f"The largest entity is {entity.id} with size {size:.2f} pixels².")

        # Measurement queries
        if measurements:
            answer_parts.append("\nMeasurements:")
            for entity_id, meas in measurements.items():
                if meas["type"] == "circle":
                    answer_parts.append(f"  Entity {entity_id}: radius={meas['radius']:.2f}px, area={meas['area']:.2f}px²")
                elif meas["type"] == "line":
                    answer_parts.append(f"  Entity {entity_id}: length={meas['length']:.2f}px")
                elif meas["type"] == "contour":
                    answer_parts.append(
                        f"  Entity {entity_id}: area={meas['area']:.2f}px², "
                        f"width={meas['width']:.2f}px, height={meas['height']:.2f}px"
                    )

        # Add reasoning if available
        if query_result.get("reasoning"):
            answer_parts.append(f"\nReasoning: {query_result['reasoning']}")

        if not answer_parts:
            answer_parts.append("Query processed, but no specific answer could be generated.")

        return "\n".join(answer_parts)
```
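Taken together, these helpers form a simple heuristic query pipeline around the agent call. A hedged end-to-end sketch, with `annotator`, `graph`, and the agent result stubbed out:

```python
# Hypothetical glue code illustrating how the helpers compose.
query = "how many circles are there?"

entities = annotator._extract_relevant_entities(query, graph)
measurements = annotator._perform_query_measurements(query, entities, graph)
query_result = {"reasoning": "heuristic match on 'circle'"}  # stand-in for the agent response
answer = annotator._generate_query_answer(query, query_result, entities, measurements)
print(answer)  # e.g. "Found 2 matching entities."
```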