anywidget-vector 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. anywidget_vector/__init__.py +1 -1
  2. anywidget_vector/backends/__init__.py +103 -0
  3. anywidget_vector/backends/chroma/__init__.py +27 -0
  4. anywidget_vector/backends/chroma/client.py +60 -0
  5. anywidget_vector/backends/chroma/converter.py +86 -0
  6. anywidget_vector/backends/grafeo/__init__.py +20 -0
  7. anywidget_vector/backends/grafeo/client.py +33 -0
  8. anywidget_vector/backends/grafeo/converter.py +46 -0
  9. anywidget_vector/backends/lancedb/__init__.py +22 -0
  10. anywidget_vector/backends/lancedb/client.py +56 -0
  11. anywidget_vector/backends/lancedb/converter.py +71 -0
  12. anywidget_vector/backends/pinecone/__init__.py +21 -0
  13. anywidget_vector/backends/pinecone/client.js +45 -0
  14. anywidget_vector/backends/pinecone/converter.py +62 -0
  15. anywidget_vector/backends/qdrant/__init__.py +26 -0
  16. anywidget_vector/backends/qdrant/client.js +61 -0
  17. anywidget_vector/backends/qdrant/converter.py +83 -0
  18. anywidget_vector/backends/weaviate/__init__.py +33 -0
  19. anywidget_vector/backends/weaviate/client.js +50 -0
  20. anywidget_vector/backends/weaviate/converter.py +81 -0
  21. anywidget_vector/static/icons.js +14 -0
  22. anywidget_vector/traitlets.py +84 -0
  23. anywidget_vector/ui/__init__.py +206 -0
  24. anywidget_vector/ui/canvas.js +521 -0
  25. anywidget_vector/ui/constants.js +64 -0
  26. anywidget_vector/ui/properties.js +158 -0
  27. anywidget_vector/ui/settings.js +265 -0
  28. anywidget_vector/ui/styles.css +348 -0
  29. anywidget_vector/ui/toolbar.js +117 -0
  30. anywidget_vector/widget.py +174 -1120
  31. {anywidget_vector-0.2.0.dist-info → anywidget_vector-0.2.1.dist-info}/METADATA +3 -3
  32. anywidget_vector-0.2.1.dist-info/RECORD +34 -0
  33. anywidget_vector-0.2.0.dist-info/RECORD +0 -6
  34. {anywidget_vector-0.2.0.dist-info → anywidget_vector-0.2.1.dist-info}/WHEEL +0 -0
@@ -2,771 +2,156 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import json
6
+ import math
5
7
  from typing import TYPE_CHECKING, Any
6
8
 
7
9
  import anywidget
8
10
  import traitlets
9
11
 
10
- if TYPE_CHECKING:
11
- from collections.abc import Callable
12
-
13
- _ESM = """
14
- import * as THREE from "https://esm.sh/three@0.160.0";
15
- import { OrbitControls } from "https://esm.sh/three@0.160.0/addons/controls/OrbitControls.js";
16
-
17
- // Color scales
18
- const COLOR_SCALES = {
19
- viridis: [[0.267,0.004,0.329],[0.282,0.140,0.458],[0.253,0.265,0.530],[0.206,0.371,0.553],[0.163,0.471,0.558],[0.127,0.566,0.551],[0.134,0.658,0.518],[0.267,0.749,0.441],[0.478,0.821,0.318],[0.741,0.873,0.150],[0.993,0.906,0.144]],
20
- plasma: [[0.050,0.030,0.528],[0.254,0.014,0.615],[0.417,0.001,0.658],[0.578,0.015,0.643],[0.716,0.135,0.538],[0.826,0.268,0.407],[0.906,0.411,0.271],[0.959,0.567,0.137],[0.981,0.733,0.106],[0.964,0.903,0.259],[0.940,0.975,0.131]],
21
- inferno: [[0.001,0.000,0.014],[0.046,0.031,0.186],[0.140,0.046,0.357],[0.258,0.039,0.406],[0.366,0.071,0.432],[0.478,0.107,0.429],[0.591,0.148,0.404],[0.706,0.206,0.347],[0.815,0.290,0.259],[0.905,0.411,0.145],[0.969,0.565,0.026]],
22
- magma: [[0.001,0.000,0.014],[0.035,0.028,0.144],[0.114,0.049,0.315],[0.206,0.053,0.431],[0.306,0.064,0.505],[0.413,0.086,0.531],[0.529,0.113,0.527],[0.654,0.158,0.501],[0.776,0.232,0.459],[0.878,0.338,0.418],[0.953,0.468,0.392]],
23
- cividis: [[0.000,0.135,0.304],[0.000,0.179,0.345],[0.117,0.222,0.360],[0.214,0.263,0.365],[0.293,0.304,0.370],[0.366,0.345,0.375],[0.437,0.387,0.382],[0.509,0.429,0.393],[0.582,0.473,0.409],[0.659,0.520,0.431],[0.739,0.570,0.461]],
24
- turbo: [[0.190,0.072,0.232],[0.254,0.265,0.530],[0.163,0.471,0.558],[0.134,0.658,0.518],[0.478,0.821,0.318],[0.741,0.873,0.150],[0.993,0.906,0.144],[0.988,0.652,0.198],[0.925,0.394,0.235],[0.796,0.177,0.214],[0.480,0.016,0.110]],
25
- };
26
-
27
- const CATEGORICAL_COLORS = [
28
- "#6366f1", "#f59e0b", "#10b981", "#ef4444", "#8b5cf6",
29
- "#06b6d4", "#f97316", "#84cc16", "#ec4899", "#14b8a6"
30
- ];
31
-
32
- // Shape geometries factory
33
- const SHAPES = {
34
- sphere: () => new THREE.SphereGeometry(1, 16, 16),
35
- cube: () => new THREE.BoxGeometry(1, 1, 1),
36
- cone: () => new THREE.ConeGeometry(0.7, 1.4, 16),
37
- tetrahedron: () => new THREE.TetrahedronGeometry(1),
38
- octahedron: () => new THREE.OctahedronGeometry(1),
39
- cylinder: () => new THREE.CylinderGeometry(0.5, 0.5, 1, 16),
40
- };
41
-
42
- // Distance metrics
43
- const DISTANCE_METRICS = {
44
- euclidean: (a, b) => {
45
- const dx = a.x - b.x, dy = a.y - b.y, dz = a.z - b.z;
46
- return Math.sqrt(dx*dx + dy*dy + dz*dz);
47
- },
48
- cosine: (a, b) => {
49
- const dot = a.x*b.x + a.y*b.y + a.z*b.z;
50
- const magA = Math.sqrt(a.x*a.x + a.y*a.y + a.z*a.z);
51
- const magB = Math.sqrt(b.x*b.x + b.y*b.y + b.z*b.z);
52
- if (magA === 0 || magB === 0) return 1;
53
- return 1 - (dot / (magA * magB)); // Convert similarity to distance
54
- },
55
- manhattan: (a, b) => {
56
- return Math.abs(a.x - b.x) + Math.abs(a.y - b.y) + Math.abs(a.z - b.z);
57
- },
58
- dot_product: (a, b) => {
59
- // Negative dot product as distance (higher dot = closer)
60
- return -(a.x*b.x + a.y*b.y + a.z*b.z);
61
- },
62
- };
63
-
64
- function computeDistance(p1, p2, metric) {
65
- const fn = DISTANCE_METRICS[metric] || DISTANCE_METRICS.euclidean;
66
- return fn(p1, p2);
67
- }
68
-
69
- function getColorFromScale(value, scaleName, domain) {
70
- const scale = COLOR_SCALES[scaleName] || COLOR_SCALES.viridis;
71
- const [min, max] = domain || [0, 1];
72
- const t = Math.max(0, Math.min(1, (value - min) / (max - min)));
73
- const idx = t * (scale.length - 1);
74
- const i = Math.floor(idx);
75
- const f = idx - i;
76
- if (i >= scale.length - 1) {
77
- const c = scale[scale.length - 1];
78
- return new THREE.Color(c[0], c[1], c[2]);
79
- }
80
- const c1 = scale[i], c2 = scale[i + 1];
81
- return new THREE.Color(
82
- c1[0] + f * (c2[0] - c1[0]),
83
- c1[1] + f * (c2[1] - c1[1]),
84
- c1[2] + f * (c2[2] - c1[2])
85
- );
86
- }
87
-
88
- function hashString(str) {
89
- let hash = 0;
90
- for (let i = 0; i < str.length; i++) {
91
- hash = ((hash << 5) - hash) + str.charCodeAt(i);
92
- hash |= 0;
93
- }
94
- return Math.abs(hash);
95
- }
96
-
97
- function getCategoricalColor(value) {
98
- const idx = hashString(String(value)) % CATEGORICAL_COLORS.length;
99
- return new THREE.Color(CATEGORICAL_COLORS[idx]);
100
- }
101
-
102
- function render({ model, el }) {
103
- let scene, camera, renderer, controls;
104
- let pointsGroup, connectionsGroup;
105
- let raycaster, mouse;
106
- let hoveredObject = null;
107
- let tooltip;
108
- let axesGroup, gridHelper;
109
- let animationId;
110
-
111
- init();
112
- animate();
113
-
114
- function init() {
115
- // Container
116
- const container = document.createElement("div");
117
- container.className = "anywidget-vector";
118
- container.style.width = model.get("width") + "px";
119
- container.style.height = model.get("height") + "px";
120
- container.style.position = "relative";
121
- el.appendChild(container);
122
-
123
- // Scene
124
- scene = new THREE.Scene();
125
- scene.background = new THREE.Color(model.get("background"));
126
-
127
- // Camera
128
- const aspect = model.get("width") / model.get("height");
129
- camera = new THREE.PerspectiveCamera(60, aspect, 0.01, 1000);
130
- const camPos = model.get("camera_position") || [2, 2, 2];
131
- camera.position.set(camPos[0], camPos[1], camPos[2]);
132
-
133
- // Renderer
134
- renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
135
- renderer.setSize(model.get("width"), model.get("height"));
136
- renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
137
- container.appendChild(renderer.domElement);
138
-
139
- // Controls
140
- controls = new OrbitControls(camera, renderer.domElement);
141
- controls.enableDamping = true;
142
- controls.dampingFactor = 0.05;
143
- const target = model.get("camera_target") || [0, 0, 0];
144
- controls.target.set(target[0], target[1], target[2]);
145
- controls.addEventListener("change", onCameraChange);
146
-
147
- // Lighting
148
- const ambient = new THREE.AmbientLight(0xffffff, 0.6);
149
- scene.add(ambient);
150
- const directional = new THREE.DirectionalLight(0xffffff, 0.8);
151
- directional.position.set(5, 10, 7);
152
- scene.add(directional);
153
-
154
- // Groups
155
- pointsGroup = new THREE.Group();
156
- scene.add(pointsGroup);
157
- connectionsGroup = new THREE.Group();
158
- scene.add(connectionsGroup);
159
- axesGroup = new THREE.Group();
160
- scene.add(axesGroup);
161
-
162
- // Setup
163
- setupAxesAndGrid();
164
- setupRaycaster(container);
165
- setupTooltip(container);
166
- createPoints();
167
- createConnections();
168
- bindModelEvents();
169
- }
170
-
171
- function setupAxesAndGrid() {
172
- // Clear existing
173
- while (axesGroup.children.length > 0) {
174
- axesGroup.remove(axesGroup.children[0]);
175
- }
176
- if (gridHelper) {
177
- scene.remove(gridHelper);
178
- gridHelper = null;
179
- }
180
-
181
- if (model.get("show_axes")) {
182
- const axes = new THREE.AxesHelper(1.2);
183
- axesGroup.add(axes);
184
-
185
- // Axis labels
186
- const labels = model.get("axis_labels") || { x: "X", y: "Y", z: "Z" };
187
- addAxisLabel(labels.x, [1.3, 0, 0], 0xff4444);
188
- addAxisLabel(labels.y, [0, 1.3, 0], 0x44ff44);
189
- addAxisLabel(labels.z, [0, 0, 1.3], 0x4444ff);
190
- }
191
-
192
- if (model.get("show_grid")) {
193
- gridHelper = new THREE.GridHelper(2, model.get("grid_divisions") || 10, 0x444444, 0x333333);
194
- scene.add(gridHelper);
195
- }
196
- }
197
-
198
- function addAxisLabel(text, position, color) {
199
- const canvas = document.createElement("canvas");
200
- const ctx = canvas.getContext("2d");
201
- canvas.width = 64;
202
- canvas.height = 32;
203
- ctx.font = "bold 24px Arial";
204
- ctx.fillStyle = "#" + color.toString(16).padStart(6, "0");
205
- ctx.textAlign = "center";
206
- ctx.fillText(text, 32, 24);
207
-
208
- const texture = new THREE.CanvasTexture(canvas);
209
- const material = new THREE.SpriteMaterial({ map: texture });
210
- const sprite = new THREE.Sprite(material);
211
- sprite.position.set(position[0], position[1], position[2]);
212
- sprite.scale.set(0.25, 0.125, 1);
213
- axesGroup.add(sprite);
214
- }
215
-
216
- function createPoints() {
217
- // Clear existing
218
- while (pointsGroup.children.length > 0) {
219
- const obj = pointsGroup.children[0];
220
- if (obj.geometry) obj.geometry.dispose();
221
- if (obj.material) obj.material.dispose();
222
- pointsGroup.remove(obj);
223
- }
224
-
225
- const points = model.get("points") || [];
226
- if (points.length === 0) return;
227
-
228
- const colorField = model.get("color_field");
229
- const colorScale = model.get("color_scale") || "viridis";
230
- const colorDomain = model.get("color_domain");
231
- const sizeField = model.get("size_field");
232
- const sizeRange = model.get("size_range") || [0.02, 0.1];
233
- const shapeField = model.get("shape_field");
234
- const shapeMap = model.get("shape_map") || {};
235
-
236
- // Compute color domain if needed
237
- let computedColorDomain = colorDomain;
238
- if (colorField && !colorDomain) {
239
- const values = points.map(p => p[colorField]).filter(v => typeof v === "number");
240
- if (values.length > 0) {
241
- computedColorDomain = [Math.min(...values), Math.max(...values)];
242
- }
243
- }
244
-
245
- // Compute size domain if needed
246
- let sizeDomain = null;
247
- if (sizeField) {
248
- const values = points.map(p => p[sizeField]).filter(v => typeof v === "number");
249
- if (values.length > 0) {
250
- sizeDomain = [Math.min(...values), Math.max(...values)];
251
- }
252
- }
253
-
254
- // Group points by shape for instanced rendering
255
- const useInstancing = model.get("use_instancing") && points.length > 100;
256
-
257
- if (useInstancing) {
258
- createInstancedPoints(points, {
259
- colorField, colorScale, computedColorDomain,
260
- sizeField, sizeRange, sizeDomain,
261
- shapeField, shapeMap
262
- });
263
- } else {
264
- createIndividualPoints(points, {
265
- colorField, colorScale, computedColorDomain,
266
- sizeField, sizeRange, sizeDomain,
267
- shapeField, shapeMap
268
- });
269
- }
270
- }
271
-
272
- function getPointColor(point, colorField, colorScale, colorDomain) {
273
- if (point.color) {
274
- return new THREE.Color(point.color);
275
- }
276
- if (colorField && point[colorField] !== undefined) {
277
- const value = point[colorField];
278
- if (typeof value === "number") {
279
- return getColorFromScale(value, colorScale, colorDomain);
280
- }
281
- return getCategoricalColor(value);
282
- }
283
- return new THREE.Color(0x6366f1);
284
- }
285
-
286
- function getPointSize(point, sizeField, sizeRange, sizeDomain) {
287
- if (point.size !== undefined) {
288
- return point.size;
289
- }
290
- if (sizeField && point[sizeField] !== undefined && sizeDomain) {
291
- const value = point[sizeField];
292
- const [min, max] = sizeDomain;
293
- const t = max > min ? (value - min) / (max - min) : 0.5;
294
- return sizeRange[0] + t * (sizeRange[1] - sizeRange[0]);
295
- }
296
- return sizeRange[0] + (sizeRange[1] - sizeRange[0]) * 0.5;
297
- }
298
-
299
- function getPointShape(point, shapeField, shapeMap) {
300
- if (point.shape && SHAPES[point.shape]) {
301
- return point.shape;
302
- }
303
- if (shapeField && point[shapeField] !== undefined) {
304
- const value = String(point[shapeField]);
305
- if (shapeMap[value] && SHAPES[shapeMap[value]]) {
306
- return shapeMap[value];
307
- }
308
- // Default shape rotation for unmapped categories
309
- const shapes = Object.keys(SHAPES);
310
- return shapes[hashString(value) % shapes.length];
311
- }
312
- return "sphere";
313
- }
314
-
315
- function createIndividualPoints(points, opts) {
316
- points.forEach((point, idx) => {
317
- const shape = getPointShape(point, opts.shapeField, opts.shapeMap);
318
- const geometry = SHAPES[shape]();
319
- const color = getPointColor(point, opts.colorField, opts.colorScale, opts.computedColorDomain);
320
- const material = new THREE.MeshPhongMaterial({ color });
321
- const mesh = new THREE.Mesh(geometry, material);
322
-
323
- const size = getPointSize(point, opts.sizeField, opts.sizeRange, opts.sizeDomain);
324
- mesh.scale.set(size, size, size);
325
- mesh.position.set(
326
- point.x ?? 0,
327
- point.y ?? 0,
328
- point.z ?? 0
329
- );
330
-
331
- mesh.userData = { pointIndex: idx, pointId: point.id || `point_${idx}` };
332
- pointsGroup.add(mesh);
333
- });
334
- }
335
-
336
- function createInstancedPoints(points, opts) {
337
- // Group by shape
338
- const groups = {};
339
- points.forEach((point, idx) => {
340
- const shape = getPointShape(point, opts.shapeField, opts.shapeMap);
341
- if (!groups[shape]) groups[shape] = [];
342
- groups[shape].push({ point, idx });
343
- });
344
-
345
- for (const [shape, items] of Object.entries(groups)) {
346
- const geometry = SHAPES[shape]();
347
- const material = new THREE.MeshPhongMaterial({ vertexColors: false });
348
- const instancedMesh = new THREE.InstancedMesh(geometry, material, items.length);
349
-
350
- const matrix = new THREE.Matrix4();
351
- const color = new THREE.Color();
352
- const colors = new Float32Array(items.length * 3);
353
-
354
- items.forEach(({ point, idx }, i) => {
355
- const size = getPointSize(point, opts.sizeField, opts.sizeRange, opts.sizeDomain);
356
- const pointColor = getPointColor(point, opts.colorField, opts.colorScale, opts.computedColorDomain);
357
-
358
- matrix.identity();
359
- matrix.makeScale(size, size, size);
360
- matrix.setPosition(point.x ?? 0, point.y ?? 0, point.z ?? 0);
361
- instancedMesh.setMatrixAt(i, matrix);
362
-
363
- colors[i * 3] = pointColor.r;
364
- colors[i * 3 + 1] = pointColor.g;
365
- colors[i * 3 + 2] = pointColor.b;
366
- });
367
-
368
- // Store color per instance using custom attribute
369
- geometry.setAttribute("color", new THREE.InstancedBufferAttribute(colors, 3));
370
- material.vertexColors = true;
371
-
372
- instancedMesh.instanceMatrix.needsUpdate = true;
373
- instancedMesh.userData = {
374
- isInstanced: true,
375
- pointIndices: items.map(({ idx }) => idx),
376
- pointIds: items.map(({ point, idx }) => point.id || `point_${idx}`)
377
- };
378
- pointsGroup.add(instancedMesh);
379
- }
380
- }
381
-
382
- function createConnections() {
383
- // Clear existing connections
384
- while (connectionsGroup.children.length > 0) {
385
- const obj = connectionsGroup.children[0];
386
- if (obj.geometry) obj.geometry.dispose();
387
- if (obj.material) obj.material.dispose();
388
- connectionsGroup.remove(obj);
389
- }
390
-
391
- const points = model.get("points") || [];
392
- const showConnections = model.get("show_connections");
393
- const kNeighbors = model.get("k_neighbors") || 0;
394
- const distanceThreshold = model.get("distance_threshold");
395
- const referencePoint = model.get("reference_point");
396
- const distanceMetric = model.get("distance_metric") || "euclidean";
397
- const connectionColor = model.get("connection_color") || "#ffffff";
398
- const connectionOpacity = model.get("connection_opacity") || 0.3;
399
-
400
- if (!showConnections || points.length < 2) return;
401
-
402
- const material = new THREE.LineBasicMaterial({
403
- color: new THREE.Color(connectionColor),
404
- transparent: true,
405
- opacity: connectionOpacity,
406
- });
407
-
408
- // If reference point is set, connect to k-nearest or within threshold
409
- if (referencePoint) {
410
- const refIdx = points.findIndex(p => p.id === referencePoint);
411
- if (refIdx === -1) return;
412
- const ref = points[refIdx];
413
-
414
- // Compute distances from reference
415
- const distances = points.map((p, i) => ({
416
- idx: i,
417
- point: p,
418
- dist: i === refIdx ? Infinity : computeDistance(ref, p, distanceMetric)
419
- })).filter(d => d.dist !== Infinity);
420
-
421
- // Sort by distance
422
- distances.sort((a, b) => a.dist - b.dist);
423
-
424
- // Select neighbors
425
- let neighbors;
426
- if (distanceThreshold !== null && distanceThreshold !== undefined) {
427
- neighbors = distances.filter(d => d.dist <= distanceThreshold);
428
- } else if (kNeighbors > 0) {
429
- neighbors = distances.slice(0, kNeighbors);
430
- } else {
431
- return;
432
- }
433
-
434
- // Draw lines
435
- neighbors.forEach(n => {
436
- const geometry = new THREE.BufferGeometry();
437
- const positions = new Float32Array([
438
- ref.x ?? 0, ref.y ?? 0, ref.z ?? 0,
439
- n.point.x ?? 0, n.point.y ?? 0, n.point.z ?? 0
440
- ]);
441
- geometry.setAttribute("position", new THREE.BufferAttribute(positions, 3));
442
- const line = new THREE.Line(geometry, material);
443
- connectionsGroup.add(line);
444
- });
445
- } else if (kNeighbors > 0) {
446
- // Connect each point to its k-nearest neighbors
447
- points.forEach((p, i) => {
448
- const distances = points.map((other, j) => ({
449
- idx: j,
450
- point: other,
451
- dist: i === j ? Infinity : computeDistance(p, other, distanceMetric)
452
- })).filter(d => d.dist !== Infinity);
453
-
454
- distances.sort((a, b) => a.dist - b.dist);
455
- const neighbors = distances.slice(0, kNeighbors);
456
-
457
- neighbors.forEach(n => {
458
- // Only draw if i < n.idx to avoid duplicate lines
459
- if (i < n.idx) {
460
- const geometry = new THREE.BufferGeometry();
461
- const positions = new Float32Array([
462
- p.x ?? 0, p.y ?? 0, p.z ?? 0,
463
- n.point.x ?? 0, n.point.y ?? 0, n.point.z ?? 0
464
- ]);
465
- geometry.setAttribute("position", new THREE.BufferAttribute(positions, 3));
466
- const line = new THREE.Line(geometry, material);
467
- connectionsGroup.add(line);
468
- }
469
- });
470
- });
471
- }
472
- }
473
-
474
- function setupRaycaster(container) {
475
- raycaster = new THREE.Raycaster();
476
- mouse = new THREE.Vector2();
477
-
478
- container.addEventListener("mousemove", onMouseMove);
479
- container.addEventListener("click", onClick);
480
- }
481
-
482
- function onMouseMove(event) {
483
- const rect = event.target.getBoundingClientRect();
484
- mouse.x = ((event.clientX - rect.left) / rect.width) * 2 - 1;
485
- mouse.y = -((event.clientY - rect.top) / rect.height) * 2 + 1;
12
+ from anywidget_vector.backends import is_python_backend
13
+ from anywidget_vector.backends.chroma.client import execute_query as chroma_query
14
+ from anywidget_vector.backends.grafeo.client import execute_query as grafeo_query
15
+ from anywidget_vector.backends.lancedb.client import execute_query as lancedb_query
16
+ from anywidget_vector.ui import get_css, get_esm
486
17
 
487
- raycaster.setFromCamera(mouse, camera);
488
- const intersects = raycaster.intersectObjects(pointsGroup.children, true);
489
-
490
- if (intersects.length > 0) {
491
- const hit = intersects[0];
492
- const points = model.get("points") || [];
493
- let pointIndex, pointId;
494
-
495
- if (hit.object.userData.isInstanced) {
496
- const instanceId = hit.instanceId;
497
- pointIndex = hit.object.userData.pointIndices[instanceId];
498
- pointId = hit.object.userData.pointIds[instanceId];
499
- } else {
500
- pointIndex = hit.object.userData.pointIndex;
501
- pointId = hit.object.userData.pointId;
502
- }
503
-
504
- const point = points[pointIndex];
505
- if (point && (!hoveredObject || hoveredObject.pointId !== pointId)) {
506
- hoveredObject = { pointIndex, pointId };
507
- model.set("hovered_point", point);
508
- model.save_changes();
509
- showTooltip(event, point);
510
- }
511
- } else if (hoveredObject) {
512
- hoveredObject = null;
513
- model.set("hovered_point", null);
514
- model.save_changes();
515
- hideTooltip();
516
- }
517
- }
518
-
519
- function onClick(event) {
520
- const rect = event.target.getBoundingClientRect();
521
- mouse.x = ((event.clientX - rect.left) / rect.width) * 2 - 1;
522
- mouse.y = -((event.clientY - rect.top) / rect.height) * 2 + 1;
523
-
524
- raycaster.setFromCamera(mouse, camera);
525
- const intersects = raycaster.intersectObjects(pointsGroup.children, true);
526
-
527
- if (intersects.length > 0) {
528
- const hit = intersects[0];
529
- const points = model.get("points") || [];
530
- let pointIndex, pointId;
531
-
532
- if (hit.object.userData.isInstanced) {
533
- const instanceId = hit.instanceId;
534
- pointIndex = hit.object.userData.pointIndices[instanceId];
535
- pointId = hit.object.userData.pointIds[instanceId];
536
- } else {
537
- pointIndex = hit.object.userData.pointIndex;
538
- pointId = hit.object.userData.pointId;
539
- }
540
-
541
- const point = points[pointIndex];
542
- const selectionMode = model.get("selection_mode") || "click";
543
- const currentSelection = model.get("selected_points") || [];
544
-
545
- if (selectionMode === "click") {
546
- model.set("selected_points", [pointId]);
547
- } else {
548
- // Toggle in multi-select mode
549
- if (currentSelection.includes(pointId)) {
550
- model.set("selected_points", currentSelection.filter(id => id !== pointId));
551
- } else {
552
- model.set("selected_points", [...currentSelection, pointId]);
553
- }
554
- }
555
- model.save_changes();
556
- } else {
557
- // Click on empty space - clear selection
558
- model.set("selected_points", []);
559
- model.save_changes();
560
- }
561
- }
562
-
563
- function setupTooltip(container) {
564
- tooltip = document.createElement("div");
565
- tooltip.className = "anywidget-vector-tooltip";
566
- tooltip.style.cssText = `
567
- position: absolute;
568
- display: none;
569
- background: rgba(0, 0, 0, 0.85);
570
- color: white;
571
- padding: 8px 12px;
572
- border-radius: 4px;
573
- font-size: 12px;
574
- pointer-events: none;
575
- z-index: 1000;
576
- max-width: 250px;
577
- box-shadow: 0 2px 8px rgba(0,0,0,0.3);
578
- font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
579
- `;
580
- container.appendChild(tooltip);
581
- }
582
-
583
- function showTooltip(event, point) {
584
- if (!model.get("show_tooltip")) return;
585
-
586
- const fields = model.get("tooltip_fields") || ["label", "x", "y", "z"];
587
- let html = "";
588
-
589
- if (point.label) {
590
- html += `<div style="font-weight: 600; margin-bottom: 4px;">${point.label}</div>`;
591
- }
592
-
593
- const rows = fields
594
- .filter(f => f !== "label" && point[f] !== undefined)
595
- .map(f => {
596
- let value = point[f];
597
- if (typeof value === "number") {
598
- value = value.toFixed(3);
599
- }
600
- return `<div style="display: flex; justify-content: space-between; gap: 12px;"><span style="color: #999;">${f}:</span><span>${value}</span></div>`;
601
- });
602
-
603
- html += rows.join("");
604
- tooltip.innerHTML = html;
605
- tooltip.style.display = "block";
606
-
607
- const rect = event.target.getBoundingClientRect();
608
- const x = event.clientX - rect.left + 15;
609
- const y = event.clientY - rect.top + 15;
610
- tooltip.style.left = x + "px";
611
- tooltip.style.top = y + "px";
612
- }
613
-
614
- function hideTooltip() {
615
- tooltip.style.display = "none";
616
- }
617
-
618
- function onCameraChange() {
619
- model.set("camera_position", [camera.position.x, camera.position.y, camera.position.z]);
620
- model.set("camera_target", [controls.target.x, controls.target.y, controls.target.z]);
621
- model.save_changes();
622
- }
623
-
624
- function bindModelEvents() {
625
- model.on("change:points", () => { createPoints(); createConnections(); });
626
- model.on("change:background", () => {
627
- scene.background = new THREE.Color(model.get("background"));
628
- });
629
- model.on("change:show_axes", setupAxesAndGrid);
630
- model.on("change:show_grid", setupAxesAndGrid);
631
- model.on("change:color_field", createPoints);
632
- model.on("change:color_scale", createPoints);
633
- model.on("change:color_domain", createPoints);
634
- model.on("change:size_field", createPoints);
635
- model.on("change:size_range", createPoints);
636
- model.on("change:shape_field", createPoints);
637
- model.on("change:shape_map", createPoints);
638
-
639
- // Distance/connection related
640
- model.on("change:show_connections", createConnections);
641
- model.on("change:k_neighbors", createConnections);
642
- model.on("change:distance_threshold", createConnections);
643
- model.on("change:reference_point", createConnections);
644
- model.on("change:distance_metric", createConnections);
645
- model.on("change:connection_color", createConnections);
646
- model.on("change:connection_opacity", createConnections);
647
-
648
- model.on("change:camera_position", () => {
649
- const pos = model.get("camera_position");
650
- if (pos) camera.position.set(pos[0], pos[1], pos[2]);
651
- });
652
- model.on("change:camera_target", () => {
653
- const target = model.get("camera_target");
654
- if (target) controls.target.set(target[0], target[1], target[2]);
655
- });
656
- }
657
-
658
- function animate() {
659
- animationId = requestAnimationFrame(animate);
660
- controls.update();
661
- renderer.render(scene, camera);
662
- }
663
-
664
- function cleanup() {
665
- cancelAnimationFrame(animationId);
666
- controls.dispose();
667
- renderer.dispose();
668
- scene.traverse((obj) => {
669
- if (obj.geometry) obj.geometry.dispose();
670
- if (obj.material) {
671
- if (Array.isArray(obj.material)) {
672
- obj.material.forEach(m => m.dispose());
673
- } else {
674
- obj.material.dispose();
675
- }
676
- }
677
- });
678
- }
679
-
680
- return cleanup;
681
- }
682
-
683
- export default { render };
684
- """
685
-
686
- _CSS = """
687
- .anywidget-vector {
688
- font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
689
- border-radius: 8px;
690
- overflow: hidden;
691
- }
692
- .anywidget-vector canvas {
693
- display: block;
694
- }
695
- """
18
+ if TYPE_CHECKING:
19
+ pass
696
20
 
697
21
 
698
22
  class VectorSpace(anywidget.AnyWidget):
699
- """Interactive 3D vector visualization widget using Three.js."""
23
+ """Interactive 3D vector visualization widget using Three.js.
700
24
 
701
- _esm = _ESM
702
- _css = _CSS
25
+ Supports multiple vector database backends with native query formats:
26
+ - Qdrant, Pinecone, Weaviate (browser-side REST)
27
+ - Chroma, LanceDB, Grafeo (Python-side)
28
+ """
703
29
 
704
- # Data
30
+ _esm = get_esm()
31
+ _css = get_css()
32
+
33
+ # === Data ===
705
34
  points = traitlets.List(trait=traitlets.Dict()).tag(sync=True)
706
35
 
707
- # Display
36
+ # === Display ===
708
37
  width = traitlets.Int(default_value=800).tag(sync=True)
709
38
  height = traitlets.Int(default_value=600).tag(sync=True)
710
39
  background = traitlets.Unicode(default_value="#1a1a2e").tag(sync=True)
711
40
 
712
- # Axes and grid
41
+ # === Axes and Grid ===
713
42
  show_axes = traitlets.Bool(default_value=True).tag(sync=True)
714
43
  show_grid = traitlets.Bool(default_value=True).tag(sync=True)
715
44
  axis_labels = traitlets.Dict(default_value={"x": "X", "y": "Y", "z": "Z"}).tag(sync=True)
716
45
  grid_divisions = traitlets.Int(default_value=10).tag(sync=True)
717
46
 
718
- # Color
47
+ # === Color Mapping ===
719
48
  color_field = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
720
49
  color_scale = traitlets.Unicode(default_value="viridis").tag(sync=True)
721
50
  color_domain = traitlets.List(default_value=None, allow_none=True).tag(sync=True)
722
51
 
723
- # Size
52
+ # === Size Mapping ===
724
53
  size_field = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
725
54
  size_range = traitlets.List(default_value=[0.02, 0.1]).tag(sync=True)
726
55
 
727
- # Shape
56
+ # === Shape Mapping ===
728
57
  shape_field = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
729
58
  shape_map = traitlets.Dict(default_value={}).tag(sync=True)
730
59
 
731
- # Camera
60
+ # === Camera ===
732
61
  camera_position = traitlets.List(default_value=[2, 2, 2]).tag(sync=True)
733
62
  camera_target = traitlets.List(default_value=[0, 0, 0]).tag(sync=True)
734
63
 
735
- # Interaction
64
+ # === Interaction ===
736
65
  selected_points = traitlets.List(default_value=[]).tag(sync=True)
737
66
  hovered_point = traitlets.Dict(default_value=None, allow_none=True).tag(sync=True)
738
67
  selection_mode = traitlets.Unicode(default_value="click").tag(sync=True)
739
68
 
740
- # Tooltip
69
+ # === Tooltip ===
741
70
  show_tooltip = traitlets.Bool(default_value=True).tag(sync=True)
742
71
  tooltip_fields = traitlets.List(default_value=["label", "x", "y", "z"]).tag(sync=True)
743
72
 
744
- # Performance
73
+ # === Performance ===
745
74
  use_instancing = traitlets.Bool(default_value=True).tag(sync=True)
746
75
  point_budget = traitlets.Int(default_value=100000).tag(sync=True)
747
76
 
748
- # Distance metrics and connections
749
- distance_metric = traitlets.Unicode(default_value="euclidean").tag(
750
- sync=True
751
- ) # euclidean, cosine, manhattan, dot_product
77
+ # === Distance and Connections ===
78
+ distance_metric = traitlets.Unicode(default_value="euclidean").tag(sync=True)
752
79
  show_connections = traitlets.Bool(default_value=False).tag(sync=True)
753
- k_neighbors = traitlets.Int(default_value=0).tag(sync=True) # k-nearest neighbors to connect
754
- distance_threshold = traitlets.Float(default_value=None, allow_none=True).tag(
755
- sync=True
756
- ) # max distance for connections
757
- reference_point = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True) # point ID to measure from
80
+ k_neighbors = traitlets.Int(default_value=0).tag(sync=True)
81
+ distance_threshold = traitlets.Float(default_value=None, allow_none=True).tag(sync=True)
82
+ reference_point = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
758
83
  connection_color = traitlets.Unicode(default_value="#ffffff").tag(sync=True)
759
84
  connection_opacity = traitlets.Float(default_value=0.3).tag(sync=True)
760
85
 
761
- def __init__(
762
- self,
763
- points: list[dict[str, Any]] | None = None,
764
- **kwargs: Any,
765
- ) -> None:
86
+ # === UI ===
87
+ show_toolbar = traitlets.Bool(default_value=False).tag(sync=True)
88
+ show_settings = traitlets.Bool(default_value=False).tag(sync=True)
89
+ show_properties = traitlets.Bool(default_value=False).tag(sync=True)
90
+
91
+ # === Backend ===
92
+ backend = traitlets.Unicode(default_value="qdrant").tag(sync=True)
93
+ backend_config = traitlets.Dict(default_value={}).tag(sync=True)
94
+ connection_status = traitlets.Unicode(default_value="disconnected").tag(sync=True)
95
+
96
+ # === Query (native format per backend) ===
97
+ query_input = traitlets.Unicode(default_value="").tag(sync=True)
98
+ query_error = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
99
+ _execute_query = traitlets.Int(default_value=0).tag(sync=True)
100
+
101
+ def __init__(self, points: list[dict[str, Any]] | None = None, **kwargs: Any) -> None:
766
102
  super().__init__(points=points or [], **kwargs)
767
- self._click_callbacks: list[Callable] = []
768
- self._hover_callbacks: list[Callable] = []
769
- self._selection_callbacks: list[Callable] = []
103
+ self._backend_client: Any = None
104
+ self.observe(self._on_execute_query, names=["_execute_query"])
105
+
106
+ # === Backend Configuration ===
107
+
108
+ def set_backend(self, backend: str, client: Any = None, **config: Any) -> VectorSpace:
109
+ """Configure backend for querying.
110
+
111
+ Args:
112
+ backend: Backend name (qdrant, pinecone, weaviate, chroma, lancedb, grafeo)
113
+ client: Client object for Python-side backends
114
+ **config: Connection config (url, apiKey, collection, etc.)
115
+
116
+ Returns:
117
+ Self for chaining
118
+ """
119
+ self.backend = backend
120
+ self._backend_client = client
121
+ self.backend_config = config
122
+ self.show_toolbar = True
123
+ self.show_settings = True
124
+ return self
125
+
126
+ def _on_execute_query(self, change: dict[str, Any]) -> None:
127
+ """Handle query execution for Python-side backends."""
128
+ if change["new"] == 0 or not is_python_backend(self.backend):
129
+ return
130
+ try:
131
+ self.connection_status = "connecting"
132
+ results = self._execute_python_query()
133
+ if results:
134
+ self.points = results
135
+ self.connection_status = "connected"
136
+ except Exception as e:
137
+ self.query_error = str(e)
138
+ self.connection_status = "error"
139
+
140
+ def _execute_python_query(self) -> list[dict[str, Any]]:
141
+ """Execute query using Python-side backend."""
142
+ if not self._backend_client:
143
+ raise ValueError("Backend not configured. Call set_backend() first.")
144
+
145
+ query = self.query_input
146
+
147
+ if self.backend == "chroma":
148
+ return chroma_query(self._backend_client, query)
149
+ elif self.backend == "lancedb":
150
+ return lancedb_query(self._backend_client, query)
151
+ elif self.backend == "grafeo":
152
+ return grafeo_query(self._backend_client, query)
153
+
154
+ raise ValueError(f"Unknown Python backend: {self.backend}")
770
155
 
771
156
  # === Factory Methods ===
772
157
 
@@ -787,441 +172,115 @@ class VectorSpace(anywidget.AnyWidget):
787
172
  positions: Any,
788
173
  *,
789
174
  ids: list[str] | None = None,
790
- colors: Any = None,
791
- sizes: Any = None,
792
175
  labels: list[str] | None = None,
793
- metadata: list[dict[str, Any]] | None = None,
794
176
  **kwargs: Any,
795
177
  ) -> VectorSpace:
796
- """Create from arrays of positions and optional attributes."""
178
+ """Create from arrays of positions."""
797
179
  pos_list = _to_list(positions)
798
- n = len(pos_list)
799
-
800
180
  points = []
801
- for i in range(n):
802
- point: dict[str, Any] = {
181
+ for i, p in enumerate(pos_list):
182
+ point = {
803
183
  "id": ids[i] if ids else f"point_{i}",
804
- "x": float(pos_list[i][0]),
805
- "y": float(pos_list[i][1]),
806
- "z": float(pos_list[i][2]) if len(pos_list[i]) > 2 else 0.0,
184
+ "x": float(p[0]),
185
+ "y": float(p[1]),
186
+ "z": float(p[2]) if len(p) > 2 else 0.0,
807
187
  }
808
- if colors is not None:
809
- color_list = _to_list(colors)
810
- point["color"] = color_list[i] if i < len(color_list) else None
811
- if sizes is not None:
812
- size_list = _to_list(sizes)
813
- point["size"] = float(size_list[i]) if i < len(size_list) else None
814
- if labels is not None and i < len(labels):
188
+ if labels:
815
189
  point["label"] = labels[i]
816
- if metadata is not None and i < len(metadata):
817
- point.update(metadata[i])
818
- points.append(point)
819
-
820
- return cls(points=points, **kwargs)
821
-
822
- @classmethod
823
- def from_numpy(
824
- cls,
825
- arr: Any,
826
- *,
827
- labels: list[str] | None = None,
828
- **kwargs: Any,
829
- ) -> VectorSpace:
830
- """Create from numpy array (N, 3) or (N, D)."""
831
- arr_list = _to_list(arr)
832
- return cls.from_arrays(arr_list, labels=labels, **kwargs)
833
-
834
- @classmethod
835
- def from_dataframe(
836
- cls,
837
- df: Any,
838
- *,
839
- x: str = "x",
840
- y: str = "y",
841
- z: str = "z",
842
- id_col: str | None = None,
843
- color_col: str | None = None,
844
- size_col: str | None = None,
845
- shape_col: str | None = None,
846
- label_col: str | None = None,
847
- include_cols: list[str] | None = None,
848
- **kwargs: Any,
849
- ) -> VectorSpace:
850
- """Create from pandas DataFrame with column mapping."""
851
- points = []
852
- for i, row in enumerate(df.to_dict("records")):
853
- point: dict[str, Any] = {
854
- "id": str(row[id_col]) if id_col and id_col in row else f"point_{i}",
855
- "x": float(row[x]) if x in row else 0.0,
856
- "y": float(row[y]) if y in row else 0.0,
857
- "z": float(row[z]) if z in row else 0.0,
858
- }
859
- if label_col and label_col in row:
860
- point["label"] = str(row[label_col])
861
- if color_col and color_col in row:
862
- point[color_col] = row[color_col]
863
- if size_col and size_col in row:
864
- point[size_col] = row[size_col]
865
- if shape_col and shape_col in row:
866
- point[shape_col] = row[shape_col]
867
- if include_cols:
868
- for col in include_cols:
869
- if col in row:
870
- point[col] = row[col]
871
190
  points.append(point)
872
-
873
- # Auto-set field mappings
874
- if color_col and "color_field" not in kwargs:
875
- kwargs["color_field"] = color_col
876
- if size_col and "size_field" not in kwargs:
877
- kwargs["size_field"] = size_col
878
- if shape_col and "shape_field" not in kwargs:
879
- kwargs["shape_field"] = shape_col
880
-
881
191
  return cls(points=points, **kwargs)
882
192
 
883
- # === Vector DB Adapters ===
884
-
885
193
  @classmethod
886
- def from_qdrant(
887
- cls,
888
- client: Any,
889
- collection: str,
890
- *,
891
- limit: int = 1000,
892
- with_vectors: bool = True,
893
- scroll_filter: Any = None,
894
- **kwargs: Any,
895
- ) -> VectorSpace:
896
- """Create from Qdrant collection."""
897
- records, _ = client.scroll(
898
- collection_name=collection,
899
- limit=limit,
900
- with_vectors=with_vectors,
901
- scroll_filter=scroll_filter,
902
- )
903
- points = []
904
- for record in records:
905
- vec = record.vector if hasattr(record, "vector") else None
906
- point: dict[str, Any] = {"id": str(record.id)}
907
- if vec and len(vec) >= 3:
908
- point["x"], point["y"], point["z"] = float(vec[0]), float(vec[1]), float(vec[2])
909
- if hasattr(record, "payload") and record.payload:
910
- point.update(record.payload)
911
- points.append(point)
194
+ def from_dataframe(cls, df: Any, *, x: str = "x", y: str = "y", z: str = "z", **kwargs: Any) -> VectorSpace:
195
+ """Create from pandas DataFrame."""
196
+ points = [
197
+ {"id": f"point_{i}", "x": float(row[x]), "y": float(row[y]), "z": float(row.get(z, 0)), **row}
198
+ for i, row in enumerate(df.to_dict("records"))
199
+ ]
912
200
  return cls(points=points, **kwargs)
913
201
 
914
- @classmethod
915
- def from_chroma(
916
- cls,
917
- collection: Any,
918
- *,
919
- n_results: int = 1000,
920
- where: dict[str, Any] | None = None,
921
- include: list[str] | None = None,
922
- **kwargs: Any,
923
- ) -> VectorSpace:
924
- """Create from ChromaDB collection."""
925
- include = include or ["embeddings", "metadatas"]
926
- result = collection.get(limit=n_results, where=where, include=include)
927
- points = []
928
- ids = result.get("ids", [])
929
- embeddings = result.get("embeddings", [])
930
- metadatas = result.get("metadatas", [])
931
-
932
- for i, id_ in enumerate(ids):
933
- point: dict[str, Any] = {"id": str(id_)}
934
- if embeddings and i < len(embeddings) and embeddings[i]:
935
- vec = embeddings[i]
936
- if len(vec) >= 3:
937
- point["x"], point["y"], point["z"] = float(vec[0]), float(vec[1]), float(vec[2])
938
- if metadatas and i < len(metadatas) and metadatas[i]:
939
- point.update(metadatas[i])
940
- points.append(point)
941
- return cls(points=points, **kwargs)
942
-
943
- @classmethod
944
- def from_lancedb(
945
- cls,
946
- table: Any,
947
- *,
948
- limit: int = 1000,
949
- **kwargs: Any,
950
- ) -> VectorSpace:
951
- """Create from LanceDB table."""
952
- df = table.to_pandas()
953
- if len(df) > limit:
954
- df = df.head(limit)
955
- return cls.from_dataframe(df, **kwargs)
956
-
957
- # === Dimensionality Reduction Adapters ===
958
-
959
- @classmethod
960
- def from_umap(
961
- cls,
962
- embedding: Any,
963
- *,
964
- labels: list[str] | None = None,
965
- metadata: list[dict[str, Any]] | None = None,
966
- **kwargs: Any,
967
- ) -> VectorSpace:
968
- """Create from UMAP embedding (N, 3)."""
969
- return cls.from_arrays(embedding, labels=labels, metadata=metadata, **kwargs)
970
-
971
- @classmethod
972
- def from_tsne(
973
- cls,
974
- embedding: Any,
975
- *,
976
- labels: list[str] | None = None,
977
- metadata: list[dict[str, Any]] | None = None,
978
- **kwargs: Any,
979
- ) -> VectorSpace:
980
- """Create from t-SNE embedding (N, 3)."""
981
- return cls.from_arrays(embedding, labels=labels, metadata=metadata, **kwargs)
982
-
983
- @classmethod
984
- def from_pca(
985
- cls,
986
- embedding: Any,
987
- *,
988
- labels: list[str] | None = None,
989
- explained_variance: list[float] | None = None,
990
- **kwargs: Any,
991
- ) -> VectorSpace:
992
- """Create from PCA embedding (N, 3) with optional variance info."""
993
- if explained_variance and len(explained_variance) >= 3:
994
- kwargs.setdefault(
995
- "axis_labels",
996
- {
997
- "x": f"PC1 ({explained_variance[0]:.1%})",
998
- "y": f"PC2 ({explained_variance[1]:.1%})",
999
- "z": f"PC3 ({explained_variance[2]:.1%})",
1000
- },
1001
- )
1002
- return cls.from_arrays(embedding, labels=labels, **kwargs)
1003
-
1004
- # === Event Callbacks ===
1005
-
1006
- def on_click(self, callback: Callable[[str, dict[str, Any]], None]) -> Callable:
1007
- """Register callback for point click: callback(point_id, point_data)."""
1008
- self._click_callbacks.append(callback)
1009
-
1010
- def observer(change: dict[str, Any]) -> None:
1011
- selected = change["new"]
1012
- if selected and len(selected) > 0:
1013
- point_id = selected[-1] if isinstance(selected, list) else selected
1014
- point_data = next((p for p in self.points if p.get("id") == point_id), {})
1015
- for cb in self._click_callbacks:
1016
- cb(point_id, point_data)
1017
-
1018
- self.observe(observer, names=["selected_points"])
1019
- return callback
1020
-
1021
- def on_hover(self, callback: Callable[[str | None, dict[str, Any] | None], None]) -> Callable:
1022
- """Register callback for hover: callback(point_id, point_data)."""
1023
- self._hover_callbacks.append(callback)
1024
-
1025
- def observer(change: dict[str, Any]) -> None:
1026
- point = change["new"]
1027
- if point:
1028
- for cb in self._hover_callbacks:
1029
- cb(point.get("id"), point)
1030
- else:
1031
- for cb in self._hover_callbacks:
1032
- cb(None, None)
1033
-
1034
- self.observe(observer, names=["hovered_point"])
1035
- return callback
1036
-
1037
- def on_selection(self, callback: Callable[[list[str], list[dict[str, Any]]], None]) -> Callable:
1038
- """Register callback for selection changes: callback(point_ids, points_data)."""
1039
- self._selection_callbacks.append(callback)
1040
-
1041
- def observer(change: dict[str, Any]) -> None:
1042
- point_ids = change["new"] or []
1043
- point_data = [p for p in self.points if p.get("id") in point_ids]
1044
- for cb in self._selection_callbacks:
1045
- cb(point_ids, point_data)
1046
-
1047
- self.observe(observer, names=["selected_points"])
1048
- return callback
1049
-
1050
- # === Camera Control ===
1051
-
1052
- def reset_camera(self) -> None:
1053
- """Reset camera to default position."""
1054
- self.camera_position = [2, 2, 2]
1055
- self.camera_target = [0, 0, 0]
1056
-
1057
- def focus_on(self, point_ids: list[str]) -> None:
1058
- """Focus camera on specific points."""
1059
- if not point_ids:
1060
- return
1061
- matching = [p for p in self.points if p.get("id") in point_ids]
1062
- if not matching:
1063
- return
1064
- cx = sum(p.get("x", 0) for p in matching) / len(matching)
1065
- cy = sum(p.get("y", 0) for p in matching) / len(matching)
1066
- cz = sum(p.get("z", 0) for p in matching) / len(matching)
1067
- self.camera_target = [cx, cy, cz]
1068
- self.camera_position = [cx + 1.5, cy + 1.5, cz + 1.5]
1069
-
1070
- # === Selection ===
1071
-
1072
- def select(self, point_ids: list[str]) -> None:
1073
- """Programmatically select points."""
1074
- self.selected_points = point_ids
1075
-
1076
- def clear_selection(self) -> None:
1077
- """Clear all selections."""
1078
- self.selected_points = []
1079
-
1080
202
  # === Distance Methods ===
1081
203
 
1082
- def compute_distances(
1083
- self,
1084
- reference_id: str,
1085
- metric: str | None = None,
1086
- vector_field: str = "vector",
1087
- ) -> dict[str, float]:
1088
- """Compute distances from a reference point to all other points.
1089
-
1090
- Args:
1091
- reference_id: ID of the reference point
1092
- metric: Distance metric (euclidean, cosine, manhattan, dot_product).
1093
- If None, uses self.distance_metric
1094
- vector_field: Field containing the full vector (for high-dim distance).
1095
- Falls back to x,y,z if not present.
1096
-
1097
- Returns:
1098
- Dict mapping point ID to distance from reference
1099
- """
204
+ def compute_distances(self, reference_id: str, metric: str | None = None) -> dict[str, float]:
205
+ """Compute distances from reference point to all others."""
1100
206
  metric = metric or self.distance_metric
1101
- ref_point = next((p for p in self.points if p.get("id") == reference_id), None)
1102
- if not ref_point:
207
+ ref = next((p for p in self.points if p.get("id") == reference_id), None)
208
+ if not ref:
1103
209
  return {}
1104
-
1105
- ref_vec = self._get_vector(ref_point, vector_field)
1106
- distances = {}
1107
-
1108
- for point in self.points:
1109
- if point.get("id") == reference_id:
1110
- continue
1111
- vec = self._get_vector(point, vector_field)
1112
- distances[point.get("id")] = self._compute_distance(ref_vec, vec, metric)
1113
-
1114
- return distances
210
+ return {p.get("id"): self._distance(ref, p, metric) for p in self.points if p.get("id") != reference_id}
1115
211
 
1116
212
  def find_neighbors(
1117
- self,
1118
- reference_id: str,
1119
- k: int | None = None,
1120
- threshold: float | None = None,
1121
- metric: str | None = None,
1122
- vector_field: str = "vector",
213
+ self, reference_id: str, k: int | None = None, threshold: float | None = None
1123
214
  ) -> list[tuple[str, float]]:
1124
- """Find nearest neighbors of a reference point.
1125
-
1126
- Args:
1127
- reference_id: ID of the reference point
1128
- k: Number of neighbors to return (if None, uses threshold)
1129
- threshold: Maximum distance (if None, uses k)
1130
- metric: Distance metric to use
1131
- vector_field: Field containing the full vector
1132
-
1133
- Returns:
1134
- List of (point_id, distance) tuples, sorted by distance
1135
- """
1136
- distances = self.compute_distances(reference_id, metric, vector_field)
1137
- sorted_distances = sorted(distances.items(), key=lambda x: x[1])
1138
-
215
+ """Find nearest neighbors of a reference point."""
216
+ distances = sorted(self.compute_distances(reference_id).items(), key=lambda x: x[1])
1139
217
  if threshold is not None:
1140
- return [(pid, d) for pid, d in sorted_distances if d <= threshold]
1141
- elif k is not None:
1142
- return sorted_distances[:k]
1143
- else:
1144
- return sorted_distances
1145
-
1146
- def color_by_distance(
1147
- self,
1148
- reference_id: str,
1149
- metric: str | None = None,
1150
- vector_field: str = "vector",
1151
- ) -> None:
1152
- """Color points by distance from a reference point.
1153
-
1154
- Updates points with a '_distance' field and sets color_field to use it.
1155
- """
1156
- distances = self.compute_distances(reference_id, metric, vector_field)
218
+ return [(pid, d) for pid, d in distances if d <= threshold]
219
+ return distances[:k] if k else distances
1157
220
 
1158
- # Update points with distance field
1159
- updated_points = []
1160
- for point in self.points:
1161
- point_copy = dict(point)
1162
- pid = point.get("id")
1163
- if pid == reference_id:
1164
- point_copy["_distance"] = 0.0
1165
- elif pid in distances:
1166
- point_copy["_distance"] = distances[pid]
1167
- updated_points.append(point_copy)
1168
-
1169
- self.points = updated_points
221
+ def color_by_distance(self, reference_id: str) -> None:
222
+ """Color points by distance from reference."""
223
+ distances = self.compute_distances(reference_id)
224
+ self.points = [{**p, "_distance": distances.get(p.get("id"), 0)} for p in self.points]
1170
225
  self.color_field = "_distance"
1171
226
  self.reference_point = reference_id
1172
227
 
1173
- def show_neighbors(
1174
- self,
1175
- reference_id: str,
1176
- k: int | None = None,
1177
- threshold: float | None = None,
1178
- ) -> None:
1179
- """Show connections to nearest neighbors of a reference point."""
228
+ def show_neighbors(self, reference_id: str, k: int | None = None, threshold: float | None = None) -> None:
229
+ """Show connections to nearest neighbors."""
1180
230
  self.reference_point = reference_id
1181
231
  self.show_connections = True
1182
- if k is not None:
232
+ if k:
1183
233
  self.k_neighbors = k
1184
- if threshold is not None:
234
+ if threshold:
1185
235
  self.distance_threshold = threshold
1186
236
 
1187
- def _get_vector(self, point: dict[str, Any], vector_field: str) -> list[float]:
1188
- """Extract vector from point, falling back to x,y,z."""
1189
- if vector_field in point and point[vector_field]:
1190
- vec = point[vector_field]
1191
- return list(vec) if hasattr(vec, "__iter__") else [vec]
1192
- return [point.get("x", 0), point.get("y", 0), point.get("z", 0)]
1193
-
1194
- def _compute_distance(self, v1: list[float], v2: list[float], metric: str) -> float:
1195
- """Compute distance between two vectors."""
1196
- import math
1197
-
1198
- # Ensure same length
1199
- n = min(len(v1), len(v2))
1200
- v1, v2 = v1[:n], v2[:n]
1201
-
237
+ def _distance(self, p1: dict, p2: dict, metric: str) -> float:
238
+ """Compute distance between two points."""
239
+ x1, y1, z1 = p1.get("x", 0), p1.get("y", 0), p1.get("z", 0)
240
+ x2, y2, z2 = p2.get("x", 0), p2.get("y", 0), p2.get("z", 0)
1202
241
  if metric == "euclidean":
1203
- return math.sqrt(sum((a - b) ** 2 for a, b in zip(v1, v2)))
242
+ return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2)
1204
243
  elif metric == "cosine":
1205
- dot = sum(a * b for a, b in zip(v1, v2))
1206
- mag1 = math.sqrt(sum(a * a for a in v1))
1207
- mag2 = math.sqrt(sum(b * b for b in v2))
1208
- if mag1 == 0 or mag2 == 0:
1209
- return 1.0
1210
- return 1 - (dot / (mag1 * mag2))
244
+ dot = x1 * x2 + y1 * y2 + z1 * z2
245
+ m1, m2 = math.sqrt(x1 * x1 + y1 * y1 + z1 * z1), math.sqrt(x2 * x2 + y2 * y2 + z2 * z2)
246
+ return 1 - (dot / (m1 * m2)) if m1 and m2 else 1
1211
247
  elif metric == "manhattan":
1212
- return sum(abs(a - b) for a, b in zip(v1, v2))
1213
- elif metric == "dot_product":
1214
- return -sum(a * b for a, b in zip(v1, v2))
1215
- else:
1216
- # Default to euclidean
1217
- return math.sqrt(sum((a - b) ** 2 for a, b in zip(v1, v2)))
248
+ return abs(x1 - x2) + abs(y1 - y2) + abs(z1 - z2)
249
+ return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2)
250
+
251
+ # === Camera ===
252
+
253
+ # === Selection ===
254
+
255
+ def select(self, point_ids: list[str]) -> None:
256
+ """Programmatically select points by ID."""
257
+ self.selected_points = point_ids
258
+
259
+ def clear_selection(self) -> None:
260
+ """Clear all selected points."""
261
+ self.selected_points = []
262
+
263
+ # === Camera ===
264
+
265
+ def reset_camera(self) -> None:
266
+ """Reset camera to default."""
267
+ self.camera_position = [2, 2, 2]
268
+ self.camera_target = [0, 0, 0]
269
+
270
+ def focus_on(self, point_ids: list[str]) -> None:
271
+ """Focus camera on specific points."""
272
+ pts = [p for p in self.points if p.get("id") in point_ids]
273
+ if pts:
274
+ cx = sum(p.get("x", 0) for p in pts) / len(pts)
275
+ cy = sum(p.get("y", 0) for p in pts) / len(pts)
276
+ cz = sum(p.get("z", 0) for p in pts) / len(pts)
277
+ self.camera_target = [cx, cy, cz]
278
+ self.camera_position = [cx + 1.5, cy + 1.5, cz + 1.5]
1218
279
 
1219
280
  # === Export ===
1220
281
 
1221
282
  def to_json(self) -> str:
1222
- """Export points data as JSON."""
1223
- import json
1224
-
283
+ """Export points as JSON."""
1225
284
  return json.dumps(self.points)
1226
285
 
1227
286
 
@@ -1236,7 +295,9 @@ def _normalize_points(data: list[Any]) -> list[dict[str, Any]]:
1236
295
  def _normalize_point(point: Any, index: int) -> dict[str, Any]:
1237
296
  """Convert a single point to standard format."""
1238
297
  if isinstance(point, dict):
1239
- return _ensure_point_id(point, index)
298
+ if "id" not in point:
299
+ point = {**point, "id": f"point_{index}"}
300
+ return point
1240
301
  if hasattr(point, "__iter__") and hasattr(point, "__len__") and len(point) >= 2:
1241
302
  return {
1242
303
  "id": f"point_{index}",
@@ -1247,13 +308,6 @@ def _normalize_point(point: Any, index: int) -> dict[str, Any]:
1247
308
  raise ValueError(f"Cannot normalize point: {point}")
1248
309
 
1249
310
 
1250
- def _ensure_point_id(point: dict[str, Any], index: int) -> dict[str, Any]:
1251
- """Ensure point has an ID."""
1252
- if "id" not in point:
1253
- point = {**point, "id": f"point_{index}"}
1254
- return point
1255
-
1256
-
1257
311
  def _to_list(obj: Any) -> list[Any]:
1258
312
  """Convert numpy arrays or other iterables to lists."""
1259
313
  if hasattr(obj, "tolist"):