lithicore 0.4.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lithicore/__init__.py +111 -0
- lithicore/_annotations.py +181 -0
- lithicore/_assistant.py +311 -0
- lithicore/_batch.py +97 -0
- lithicore/_classification.py +1411 -0
- lithicore/_cli.py +224 -0
- lithicore/_comparison.py +101 -0
- lithicore/_edge_detection.py +56 -0
- lithicore/_figure.py +383 -0
- lithicore/_landmarks.py +152 -0
- lithicore/_metrics.py +70 -0
- lithicore/_models.py +172 -0
- lithicore/_orientation.py +109 -0
- lithicore/_ph_features.py +349 -0
- lithicore/_photo_preprocessing.py +483 -0
- lithicore/_photogrammetry.py +695 -0
- lithicore/_platform_angle.py +121 -0
- lithicore/_scale_detection.py +460 -0
- lithicore/_scar_detection.py +318 -0
- lithicore/_validation.py +137 -0
- lithicore-0.4.1b0.dist-info/METADATA +14 -0
- lithicore-0.4.1b0.dist-info/RECORD +25 -0
- lithicore-0.4.1b0.dist-info/WHEEL +5 -0
- lithicore-0.4.1b0.dist-info/entry_points.txt +2 -0
- lithicore-0.4.1b0.dist-info/top_level.txt +1 -0
lithicore/__init__.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""lithicore — 3D lithic artefact morphological measurement library.
|
|
2
|
+
|
|
3
|
+
exports: orient_auto(mesh, config) -> tuple[trimesh.Trimesh, np.ndarray]
|
|
4
|
+
orient_manual(mesh, points, config) -> tuple[trimesh.Trimesh, np.ndarray]
|
|
5
|
+
extract_metrics(mesh, config) -> list[MeasurementResult]
|
|
6
|
+
detect_edges(mesh, config) -> np.ndarray
|
|
7
|
+
platform_angles(mesh, config) -> tuple[MeasurementResult, MeasurementResult]
|
|
8
|
+
validate_mesh(mesh) -> MeshQualityReport
|
|
9
|
+
repair_mesh(mesh) -> trimesh.Trimesh
|
|
10
|
+
batch_process(directory, config) -> list[ArtefactResult]
|
|
11
|
+
PhotogrammetryConfig, PhotogrammetryResult, PhotogrammetryError,
|
|
12
|
+
ColmapNotFoundError, ColmapStageError, InsufficientPhotosError,
|
|
13
|
+
PhotogrammetryCancelledError, colmap_available, run_pipeline
|
|
14
|
+
ScaleResult, detect_scale_aruco, apply_scale_to_mesh
|
|
15
|
+
used_by: lithicope GUI, CLI users
|
|
16
|
+
rules: No GUI imports. Every public function takes a mesh + config and returns typed results.
|
|
17
|
+
agent: deepseek-v4-flash | 2026-05-26 | Initial scaffolding
|
|
18
|
+
deepseek-v4-flash | 2026-05-26 | All modules wired — full public API exported
|
|
19
|
+
deepseek-v4-flash | 2026-05-26 | Added FigureConfig + generate_figure export
|
|
20
|
+
deepseek-v4-pro | 2026-06-12 | Added AssistantResult to __all__ and imports
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
# pylint: disable=unused-import
|
|
24
|
+
try:
|
|
25
|
+
from lithicore._models import (
|
|
26
|
+
MeasurementConfig,
|
|
27
|
+
MeasurementResult,
|
|
28
|
+
ArtefactResult,
|
|
29
|
+
Landmark,
|
|
30
|
+
MeshQualityReport,
|
|
31
|
+
MeshGrade,
|
|
32
|
+
ClassificationResult,
|
|
33
|
+
FeatureImportance,
|
|
34
|
+
LithicFeatureVector,
|
|
35
|
+
AssistantResult,
|
|
36
|
+
)
|
|
37
|
+
from lithicore._orientation import orient_auto, orient_manual
|
|
38
|
+
from lithicore._metrics import extract_metrics
|
|
39
|
+
from lithicore._edge_detection import detect_edges
|
|
40
|
+
from lithicore._platform_angle import platform_angles
|
|
41
|
+
from lithicore._validation import validate_mesh, repair_mesh
|
|
42
|
+
from lithicore._batch import batch_process
|
|
43
|
+
from lithicore._figure import FigureConfig, generate_figure
|
|
44
|
+
from lithicore._comparison import compare_meshes, ComparisonResult
|
|
45
|
+
from lithicore._photogrammetry import (
|
|
46
|
+
PhotogrammetryConfig,
|
|
47
|
+
PhotogrammetryResult,
|
|
48
|
+
PhotogrammetryError,
|
|
49
|
+
ColmapNotFoundError,
|
|
50
|
+
ColmapStageError,
|
|
51
|
+
InsufficientPhotosError,
|
|
52
|
+
PhotogrammetryCancelledError,
|
|
53
|
+
run_pipeline,
|
|
54
|
+
colmap_available,
|
|
55
|
+
clean_point_cloud,
|
|
56
|
+
ProgressCallback,
|
|
57
|
+
)
|
|
58
|
+
from lithicore._scale_detection import (
|
|
59
|
+
ScaleResult,
|
|
60
|
+
detect_scale_aruco,
|
|
61
|
+
apply_scale_to_mesh,
|
|
62
|
+
)
|
|
63
|
+
from lithicore._photo_preprocessing import (
|
|
64
|
+
PreprocessingConfig,
|
|
65
|
+
PreprocessingResult,
|
|
66
|
+
preprocess_photos,
|
|
67
|
+
compute_laplacian_variance,
|
|
68
|
+
compute_blur_scores,
|
|
69
|
+
)
|
|
70
|
+
from lithicore._annotations import (
|
|
71
|
+
Annotation,
|
|
72
|
+
AnnotationSet,
|
|
73
|
+
)
|
|
74
|
+
from lithicore._classification import (
|
|
75
|
+
ClassifierModel,
|
|
76
|
+
extract_features,
|
|
77
|
+
train_model,
|
|
78
|
+
extract_diagnostic_coordinates,
|
|
79
|
+
)
|
|
80
|
+
from lithicore._assistant import (
|
|
81
|
+
AssistantEngine,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
__all__ = [
|
|
85
|
+
"MeasurementConfig", "MeasurementResult", "ArtefactResult", "Landmark",
|
|
86
|
+
"MeshQualityReport", "MeshGrade",
|
|
87
|
+
"orient_auto", "orient_manual",
|
|
88
|
+
"extract_metrics", "detect_edges", "platform_angles",
|
|
89
|
+
"validate_mesh", "repair_mesh", "batch_process",
|
|
90
|
+
"FigureConfig", "generate_figure",
|
|
91
|
+
"compare_meshes", "ComparisonResult",
|
|
92
|
+
"PhotogrammetryConfig", "PhotogrammetryResult",
|
|
93
|
+
"PhotogrammetryError", "ColmapNotFoundError", "ColmapStageError",
|
|
94
|
+
"InsufficientPhotosError", "PhotogrammetryCancelledError",
|
|
95
|
+
"run_pipeline", "colmap_available", "clean_point_cloud",
|
|
96
|
+
"ProgressCallback",
|
|
97
|
+
"ScaleResult", "detect_scale_aruco", "apply_scale_to_mesh",
|
|
98
|
+
"PreprocessingConfig", "PreprocessingResult",
|
|
99
|
+
"preprocess_photos", "compute_laplacian_variance", "compute_blur_scores",
|
|
100
|
+
"Annotation", "AnnotationSet",
|
|
101
|
+
"ClassificationResult", "FeatureImportance", "LithicFeatureVector",
|
|
102
|
+
"ClassifierModel", "extract_features", "train_model",
|
|
103
|
+
"extract_diagnostic_coordinates",
|
|
104
|
+
"AssistantEngine", "AssistantResult",
|
|
105
|
+
]
|
|
106
|
+
except ImportError as _exc:
|
|
107
|
+
# Forward reference — all modules exist since Phases 2-4
|
|
108
|
+
raise ImportError(
|
|
109
|
+
f"lithicore module import failed: {_exc}. "
|
|
110
|
+
"Try: pip install --no-deps -e lithicore"
|
|
111
|
+
) from _exc
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""_annotations.py — 3D mesh annotation data model.
|
|
2
|
+
|
|
3
|
+
exports: Annotation
|
|
4
|
+
AnnotationSet
|
|
5
|
+
used_by: lithicope annotation panel
|
|
6
|
+
rules: Pure dataclasses with JSON serialization. No GUI imports.
|
|
7
|
+
Coordinates are (x, y, z) floats matching mesh vertex space.
|
|
8
|
+
agent: deepseek-v4-flash | 2026-05-27 | Initial implementation
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import hashlib
|
|
14
|
+
import json
|
|
15
|
+
from dataclasses import asdict, dataclass, field, replace
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class Annotation:
    """A single annotation attached to a 3D mesh point.

    Attributes:
        point: (x, y, z) coordinates on the mesh surface. Any non-tuple
            iterable (e.g. the list produced by ``json.loads``) is converted
            to a tuple on construction.
        title: Short label for the annotation.
        description: Multi-line descriptive notes.
        category: Type classification — e.g. "scar", "ridge",
            "notch", "cortex", "flake", "breakage", "other".
        measurement_mm: Optional numeric measurement at this point.
        confidence: Estimated reliability (0 = uncertain, 1 = certain).
        author: Name or identifier of the annotator.
        timestamp: ISO 8601 datetime string of creation/last edit.
        attached_photos: List of file paths to associated images.
        sub_annotations: Child annotations nested under this one.
    """

    point: tuple[float, float, float]
    title: str
    description: str = ""
    category: str = ""
    measurement_mm: float = 0.0
    confidence: float = 1.0
    author: str = ""
    timestamp: str = ""
    attached_photos: list[str] = field(default_factory=list)
    sub_annotations: list["Annotation"] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Normalise ``point`` to a tuple.

        JSON has no tuple type, so a point that went through
        ``json.dumps``/``json.loads`` arrives as a list. Without this
        conversion a reloaded annotation would not compare equal to the
        original and would contradict the declared field type.
        """
        if not isinstance(self.point, tuple):
            self.point = tuple(self.point)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class AnnotationSet:
    """A collection of annotations for a single artefact mesh.

    Attributes:
        format_version: Schema version for forward compatibility.
        artefact_label: Human-readable artefact identifier.
        mesh_path: Relative or absolute path to the associated mesh file.
        mesh_checksum: Checksum of the mesh file for validation (see
            ``compute_checksum`` for the format produced by this class).
        author: Name of the person who created this set.
        created: ISO 8601 datetime of initial creation.
        annotations: All top-level annotations for this artefact.
    """

    format_version: int = 1
    artefact_label: str = ""
    mesh_path: str = ""
    mesh_checksum: str = ""
    author: str = ""
    created: str = ""
    annotations: list[Annotation] = field(default_factory=list)

    def to_json(self) -> str:
        """Serialize this annotation set (including nested annotations) to JSON."""
        # default=str stringifies any value json cannot encode natively.
        return json.dumps(asdict(self), indent=2, default=str)

    @classmethod
    def from_json(cls, data: str) -> AnnotationSet:
        """Deserialize a JSON string produced by ``to_json`` into an AnnotationSet."""
        raw = json.loads(data)
        # asdict() flattened the Annotation objects into dicts; rebuild them.
        raw["annotations"] = [
            cls._annotation_from_dict(item) for item in raw.get("annotations", [])
        ]
        return cls(**raw)

    @staticmethod
    def _annotation_from_dict(raw: dict) -> Annotation:
        """Recursively build an Annotation from a dict, handling sub-annotations.

        The input dict is not modified.
        """
        values = dict(raw)  # copy so the caller's dict keeps its keys
        children = values.pop("sub_annotations", [])
        if "point" in values:
            # JSON arrays decode to lists; the model stores points as tuples.
            values["point"] = tuple(values["point"])
        ann = Annotation(**values)
        ann.sub_annotations = [
            AnnotationSet._annotation_from_dict(child) for child in children
        ]
        return ann

    @staticmethod
    def _point_key(point: tuple[float, float, float]) -> tuple[float, float, float]:
        """Round a 3D point to 3 decimal places for stable merge matching."""
        return (round(point[0], 3), round(point[1], 3), round(point[2], 3))

    @staticmethod
    def _resolve_conflict(current: Annotation, incoming: Annotation) -> Annotation:
        """Combine two annotations that share a position but differ in text.

        The incoming title/description/confidence win; the title gets an
        author suffix. Category, measurement and author fall back to the
        current entry when the incoming value is empty. Photos and
        sub-annotations of both sides are kept (current first, duplicates
        removed, order preserved).
        """
        suffix = f" ({incoming.author})" if incoming.author else " (imported)"

        if incoming.timestamp and current.timestamp:
            # String comparison; assumes both use the same ISO 8601 layout.
            timestamp = max(incoming.timestamp, current.timestamp)
        else:
            timestamp = incoming.timestamp or current.timestamp

        children = list(current.sub_annotations)
        for child in incoming.sub_annotations:
            if child not in children:
                children.append(child)

        # replace() starts from the incoming annotation, so point, description
        # and confidence (including a legitimate 0.0) are taken from it as-is.
        return replace(
            incoming,
            title=incoming.title + suffix,
            category=incoming.category or current.category,
            measurement_mm=incoming.measurement_mm or current.measurement_mm,
            author=incoming.author or current.author,
            timestamp=timestamp,
            attached_photos=list(
                dict.fromkeys(current.attached_photos + incoming.attached_photos)
            ),
            sub_annotations=children,
        )

    def merge(self, other: AnnotationSet) -> tuple[AnnotationSet, list[str]]:
        """Merge another AnnotationSet into a new set.

        Annotations are matched by 3D position rounded to 3 decimal places.

        * Incoming annotations at a position this set does not have are
          appended.
        * Same position, same title and description: the existing entry is
          kept; if the incoming timestamp is newer, a copy with the newer
          timestamp (and incoming author, if any) takes its place.
        * Same position, different title or description: the existing entry
          is replaced by a combined annotation whose title carries an author
          suffix, and a warning is recorded.

        Neither ``self`` nor ``other`` nor any of their Annotation objects is
        modified; unchanged annotations are shared by reference with the
        result.

        Args:
            other: The incoming annotation set to merge.

        Returns:
            A tuple of (merged AnnotationSet, list of warning strings).
        """
        messages: list[str] = []

        if self.author and other.author:
            combined_author = f"{self.author}+{other.author}"
        else:
            combined_author = self.author or other.author

        merged = AnnotationSet(
            format_version=self.format_version,
            artefact_label=self.artefact_label or other.artefact_label,
            mesh_path=self.mesh_path or other.mesh_path,
            mesh_checksum=self.mesh_checksum or other.mesh_checksum,
            author=combined_author,
            created=self.created or other.created,
            annotations=list(self.annotations),
        )

        # Position key -> index in merged.annotations of this set's first
        # entry at that position. Appended incoming entries are deliberately
        # not indexed, so two new incoming annotations at one spot both stay.
        slot_by_key: dict[tuple[float, float, float], int] = {}
        for index, ann in enumerate(self.annotations):
            slot_by_key.setdefault(self._point_key(ann.point), index)

        for ann in other.annotations:
            slot = slot_by_key.get(self._point_key(ann.point))
            if slot is None:
                merged.annotations.append(ann)
                continue

            current = merged.annotations[slot]
            if current.title != ann.title or current.description != ann.description:
                merged.annotations[slot] = self._resolve_conflict(current, ann)
                messages.append(
                    f"Merged annotation at ({ann.point[0]:.3f}, {ann.point[1]:.3f}, "
                    f"{ann.point[2]:.3f}): conflicting data resolved with author suffix"
                )
            elif ann.timestamp and (
                not current.timestamp or ann.timestamp > current.timestamp
            ):
                # Same content, newer edit: store an updated copy rather than
                # mutating the object that still belongs to ``self``.
                merged.annotations[slot] = replace(
                    current,
                    timestamp=ann.timestamp,
                    author=ann.author or current.author,
                )

        return merged, messages

    def compute_checksum(self, mesh_path: Path) -> str:
        """Return ``"sha256:<hex digest>"`` for the file at ``mesh_path``.

        The file is read in 64 KiB chunks so large meshes are not loaded
        into memory at once.
        """
        digest = hashlib.sha256()
        with open(mesh_path, "rb") as handle:
            for chunk in iter(lambda: handle.read(65536), b""):
                digest.update(chunk)
        return f"sha256:{digest.hexdigest()}"
|
lithicore/_assistant.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
"""_assistant.py — AI-powered natural language query engine for lithic collections.
|
|
2
|
+
|
|
3
|
+
exports: AssistantEngine
|
|
4
|
+
used_by: lithicope assistant panel
|
|
5
|
+
rules: No GUI imports. DuckDB queries are read-only SELECT. LLM is optional dependency.
|
|
6
|
+
All functions safe to call when model not loaded (returns error result).
|
|
7
|
+
agent: deepseek-v4-flash | 2026-05-27 | Initial implementation
|
|
8
|
+
deepseek-v4-pro | 2026-06-12 | Added _validate_sql() safety gate rejecting non-SELECT, multi-statement, and dangerous keywords
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import os
|
|
14
|
+
import time
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Callable, Optional
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pandas as pd
|
|
20
|
+
|
|
21
|
+
from lithicore._models import AssistantResult
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
MODEL_DIR = Path.home() / ".dibble" / "models" / "assistant"
|
|
25
|
+
MODEL_FILENAME = "qwen3-4b-q4_k_m.gguf"
|
|
26
|
+
GRAMMAR_DIR = Path(__file__).resolve().parent.parent / "data" / "grammars"
|
|
27
|
+
SQL_GRAMMAR_PATH = GRAMMAR_DIR / "sql_query.gbnf"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
SCHEMA_DESCRIPTION = """
|
|
31
|
+
Table: artifacts
|
|
32
|
+
Columns:
|
|
33
|
+
- length_mm (FLOAT) — maximum length in mm
|
|
34
|
+
- width_mm (FLOAT) — maximum width in mm
|
|
35
|
+
- thickness_mm (FLOAT) — maximum thickness in mm
|
|
36
|
+
- surface_area_mm2 (FLOAT) — total surface area
|
|
37
|
+
- volume_mm3 (FLOAT) — total volume
|
|
38
|
+
- elongation (FLOAT) — length/width ratio
|
|
39
|
+
- flatness (FLOAT) — width/thickness ratio
|
|
40
|
+
- compactness (FLOAT) — volume/length^3
|
|
41
|
+
- relative_thickness (FLOAT) — thickness/length
|
|
42
|
+
- scar_count (INTEGER) — number of flake scars detected
|
|
43
|
+
- mean_scar_area_mm2 (FLOAT) — average scar area
|
|
44
|
+
- platform_angle_deg (FLOAT) — platform angle in degrees
|
|
45
|
+
- edge_angle_mean_deg (FLOAT) — mean edge angle
|
|
46
|
+
- edge_angle_std_deg (FLOAT) — edge angle std deviation
|
|
47
|
+
- curvature_index (FLOAT) — dorsal curvature
|
|
48
|
+
- cross_section_profile (FLOAT) — 0=flat, 1=triangular, 2=round
|
|
49
|
+
- symmetry_score (FLOAT) — bilateral symmetry (0-1)
|
|
50
|
+
- com_z_ratio (FLOAT) — centre of mass height ratio
|
|
51
|
+
- dorsal_ridge_count (INTEGER) — number of parallel ridges
|
|
52
|
+
- surface_roughness (FLOAT) — texture metric
|
|
53
|
+
- typology (TEXT) — predicted type
|
|
54
|
+
- artefact_label (TEXT) — user-assigned name
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
FEW_SHOT_EXAMPLES = """
|
|
58
|
+
-- Example 1: "show me all blades longer than 100mm"
|
|
59
|
+
SELECT * FROM artifacts WHERE typology = 'Blade' AND length_mm > 100 LIMIT 50;
|
|
60
|
+
|
|
61
|
+
-- Example 2: "what's the average platform angle of crested blades?"
|
|
62
|
+
SELECT AVG(platform_angle_deg) FROM artifacts WHERE typology = 'Crested blade';
|
|
63
|
+
|
|
64
|
+
-- Example 3: "find the 5 most symmetrical handaxes"
|
|
65
|
+
SELECT * FROM artifacts WHERE typology = 'Handaxe' ORDER BY symmetry_score DESC LIMIT 5;
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class AssistantEngine:
    """LLM-powered natural language query engine for lithic collections.

    Workflow: ``load_model()`` once, then ``query()`` per user question.
    ``query()`` turns the question into SQL, checks it with
    ``_validate_sql()``, runs it against the supplied DataFrame through
    DuckDB, and asks the model for a short summary of the result.
    """

    def __init__(self) -> None:
        self._llm = None
        # Compiled GBNF grammar object, or None when no grammar is available.
        self._grammar: Optional[object] = None
        self._model_available = False

    def load_model(self, progress_cb: Optional[Callable] = None) -> None:
        """Load the LLM model. Downloads from HuggingFace if not cached.

        Args:
            progress_cb: Optional callable invoked as
                ``progress_cb(stage, fraction, message)`` with stage one of
                "download", "loading", "error" or "ready".

        Failures are reported through ``progress_cb`` only; the method
        returns without raising and ``is_loaded()`` stays False.
        """

        def report(stage: str, fraction: float, message: str) -> None:
            if progress_cb:
                progress_cb(stage, fraction, message)

        try:
            import llama_cpp
        except ImportError:
            report("error", 0.0, "llama-cpp-python not installed")
            return

        model_path = MODEL_DIR / MODEL_FILENAME
        MODEL_DIR.mkdir(parents=True, exist_ok=True)

        # Same runtime options on both paths. The download path previously
        # used the library's default context size, while the prompt is built
        # for the 4096-token context used on the cached path.
        llm_options = {
            "n_ctx": 4096,
            "n_threads": os.cpu_count() or 4,
            "verbose": False,
        }

        if not model_path.exists():
            report("download", 0.0, "Downloading model (~2.5GB)...")
            # NOTE(review): from_pretrained is not given a local_dir, so the
            # download presumably lands in the HuggingFace cache rather than
            # MODEL_DIR and this branch runs on every start — confirm.
            try:
                self._llm = llama_cpp.Llama.from_pretrained(
                    repo_id="Qwen/Qwen3-4B-GGUF",
                    filename="*q4_k_m.gguf",
                    **llm_options,
                )
            except Exception as exc:
                report("error", 0.0, f"Download failed: {exc}")
                return
        else:
            report("loading", 0.0, "Loading model...")
            try:
                self._llm = llama_cpp.Llama(model_path=str(model_path), **llm_options)
            except Exception as exc:
                report("error", 0.0, f"Model load failed: {exc}")
                return

        self._grammar = None
        if SQL_GRAMMAR_PATH.exists():
            # create_completion() expects a compiled LlamaGrammar, not the
            # raw grammar text, so compile it here once.
            try:
                self._grammar = llama_cpp.LlamaGrammar.from_string(
                    SQL_GRAMMAR_PATH.read_text(encoding="utf-8"),
                    verbose=False,
                )
            except Exception as exc:
                # Generation still works unconstrained; _validate_sql() is
                # the safety gate that does not depend on the grammar.
                report("loading", 0.9, f"SQL grammar ignored: {exc}")

        self._model_available = True
        report("ready", 1.0, "AI assistant ready")

    def is_loaded(self) -> bool:
        """Check if the model is ready for queries."""
        return self._model_available and self._llm is not None

    def _validate_sql(self, sql: str) -> bool:
        """Validate that a generated SQL string is a single, safe SELECT.

        Rejects anything that does not start with the keyword SELECT,
        anything containing a statement separator other than one trailing
        ``;``, and anything using a data-modifying or administrative keyword.
        Text inside single-quoted string literals is ignored, so a value such
        as ``'Update; site 3'`` does not trigger a rejection. This is a hard
        safety gate independent of the GBNF grammar.

        Returns:
            True when the query passes every check, otherwise False.
        """
        import re  # local import, matching the other lazy imports in this class

        if not isinstance(sql, str):
            return False

        # Blank out single-quoted literals ('' is an escaped quote). An
        # unterminated quote is left in place and therefore still checked.
        scrubbed = re.sub(r"'(?:[^']|'')*'", "''", sql).strip().upper()

        # \b so that e.g. "SELECTED" is not accepted as SELECT.
        if not re.match(r"SELECT\b", scrubbed):
            return False

        body = scrubbed[:-1] if scrubbed.endswith(";") else scrubbed
        if ";" in body:
            return False  # Multi-statement — potential injection

        forbidden = {"DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE",
                     "TRUNCATE", "EXEC", "EXECUTE", "ATTACH", "DETACH",
                     "PRAGMA", "IMPORT", "EXPORT"}
        # Identifier-shaped tokens: catches keywords glued to punctuation
        # (e.g. "DELETE;") that a whitespace split would miss, without
        # matching inside longer names such as UPDATED_AT.
        tokens = set(re.findall(r"[A-Z_][A-Z0-9_]*", body))
        return not (tokens & forbidden)

    def query(self, user_text: str, collection_df: pd.DataFrame) -> AssistantResult:
        """Run the full query loop: natural language -> SQL -> execute -> explain.

        Args:
            user_text: The user's natural language query.
            collection_df: In-memory DataFrame of the current artefact
                collection. It is exposed to the generated SQL as the table
                ``artifacts``.

        Returns:
            AssistantResult with natural_language explanation, sql_query and
            row_count on success, or with ``error`` set on failure.
        """
        if not self.is_loaded():
            return AssistantResult(
                error="AI model not loaded. Use Assistant > Download Model to set up.",
                processing_time_s=0.0,
            )

        if collection_df.empty:
            return AssistantResult(
                natural_language="No artefacts in the current collection to query.",
                processing_time_s=0.0,
            )

        import duckdb

        start = time.time()

        # Step 1: Build prompt
        prompt = self._build_sql_prompt(user_text, collection_df)

        # Step 2: Generate SQL with GBNF grammar
        sql = self._generate_sql(prompt)
        if sql is None:
            return AssistantResult(
                error="Failed to generate a valid SQL query.",
                processing_time_s=round(time.time() - start, 2),
            )

        if not self._validate_sql(sql):
            return AssistantResult(
                sql_query=sql,
                error="Generated SQL query was rejected by safety validator.",
                processing_time_s=round(time.time() - start, 2),
            )

        # Step 3: Self-correcting execution loop on a private in-memory
        # database where the DataFrame is registered under the table name
        # the prompt tells the model to use.
        result_df = None
        last_error = ""
        con = duckdb.connect()
        try:
            con.register("artifacts", collection_df)
            for attempt in range(3):
                try:
                    result_df = con.execute(sql).df()
                    break
                except Exception as exc:
                    last_error = str(exc)
                if attempt == 2:
                    break
                candidate = self._fix_sql(prompt, sql, last_error)
                if candidate is None:
                    return AssistantResult(
                        sql_query=sql,
                        error="Failed to generate a valid SQL query.",
                        processing_time_s=round(time.time() - start, 2),
                    )
                # Regenerated SQL goes through the same gate as the first try.
                if not self._validate_sql(candidate):
                    return AssistantResult(
                        sql_query=candidate,
                        error="Generated SQL query was rejected by safety validator.",
                        processing_time_s=round(time.time() - start, 2),
                    )
                sql = candidate
        finally:
            con.close()

        elapsed = time.time() - start

        if result_df is None:
            return AssistantResult(
                sql_query=sql,
                error=f"SQL execution failed after 3 attempts: {last_error}",
                processing_time_s=round(elapsed, 2),
            )

        # Step 4: Summarize results
        summary = self._summarize_results(user_text, result_df)

        return AssistantResult(
            natural_language=summary or f"Found {len(result_df)} matching artefacts.",
            sql_query=sql,
            row_count=len(result_df),
            processing_time_s=round(elapsed, 2),
        )

    def _build_sql_prompt(self, user_text: str, df: pd.DataFrame) -> str:
        """Build the SQL-generation prompt: instructions, schema, examples, row count, user query."""
        return (
            "You are a SQL query generator for a lithic (stone tool) analysis database. "
            "Generate ONLY a DuckDB SQL SELECT query. No explanations, no markdown.\n\n"
            f"Schema:\n{SCHEMA_DESCRIPTION}\n\n"
            f"Examples:\n{FEW_SHOT_EXAMPLES}\n\n"
            f"Table has {len(df)} rows.\n\n"
            f"User query: {user_text}\n\n"
            "SQL:"
        )

    def _generate_sql(self, prompt: str) -> Optional[str]:
        """Generate SQL for ``prompt``, constrained by the grammar when one is loaded.

        Returns:
            The generated statement terminated by ``;``, or None when the
            model call fails or produces no text.
        """
        try:
            output = self._llm.create_completion(
                prompt,
                max_tokens=256,
                temperature=0.1,
                grammar=self._grammar,
                stop=[";"],
            )
            text = output["choices"][0]["text"].strip()
        except Exception:
            return None

        if not text:
            return None
        # Generation stops at the first ";", so re-append the terminator.
        if not text.endswith(";"):
            text += ";"
        return text

    def _fix_sql(self, prompt: str, bad_sql: str, error: str) -> Optional[str]:
        """Ask the model for a corrected query given the failed SQL and its error."""
        fix_prompt = (
            f"{prompt}\n\n"
            f"Previous (failed) SQL: {bad_sql}\n"
            f"Error: {error}\n"
            "Fixed SQL:"
        )
        return self._generate_sql(fix_prompt)

    def _summarize_results(self, user_text: str, df: pd.DataFrame) -> Optional[str]:
        """Generate a natural language summary of query results.

        Returns:
            A fixed message for an empty result, the model's summary
            otherwise, or None when the model call fails.
        """
        if df.empty:
            return "No artefacts matched your query."

        # Mean/min/max of at most the first five numeric columns.
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        stats = {}
        for col in list(numeric_cols)[:5]:
            stats[col] = {
                "mean": round(float(df[col].mean()), 1),
                "min": round(float(df[col].min()), 1),
                "max": round(float(df[col].max()), 1),
            }

        type_counts = {}
        if "typology" in df.columns:
            type_counts = df["typology"].value_counts().head(5).to_dict()

        summary_prompt = (
            "Summarize these lithic analysis results in 1-3 concise sentences.\n"
            f"User asked: {user_text}\n"
            f"Found {len(df)} matching artefacts.\n"
            f"Summary stats: {stats}\n"
            f"Typology breakdown: {type_counts}\n\n"
            "Response:"
        )

        try:
            output = self._llm.create_completion(
                summary_prompt,
                max_tokens=200,
                temperature=0.3,
                stop=["\n\n"],
            )
            return output["choices"][0]["text"].strip()
        except Exception:
            return None

    @staticmethod
    def get_model_status() -> dict:
        """Return model status info for the UI.

        Keys: ``installed`` (model file exists), ``size_mb`` (file size, 0
        when missing), ``path`` (expected model path) and ``grammar_exists``.
        """
        model_path = MODEL_DIR / MODEL_FILENAME
        installed = model_path.exists()
        size_mb = round(model_path.stat().st_size / (1024 * 1024), 1) if installed else 0
        return {
            "installed": installed,
            "size_mb": size_mb,
            "path": str(model_path),
            "grammar_exists": SQL_GRAMMAR_PATH.exists(),
        }
|