rtmlib-ts 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. package/.gitattributes +1 -0
  2. package/README.md +202 -0
  3. package/dist/core/base.d.ts +20 -0
  4. package/dist/core/base.d.ts.map +1 -0
  5. package/dist/core/base.js +40 -0
  6. package/dist/core/file.d.ts +11 -0
  7. package/dist/core/file.d.ts.map +1 -0
  8. package/dist/core/file.js +111 -0
  9. package/dist/core/modelCache.d.ts +35 -0
  10. package/dist/core/modelCache.d.ts.map +1 -0
  11. package/dist/core/modelCache.js +161 -0
  12. package/dist/core/posePostprocessing.d.ts +12 -0
  13. package/dist/core/posePostprocessing.d.ts.map +1 -0
  14. package/dist/core/posePostprocessing.js +76 -0
  15. package/dist/core/postprocessing.d.ts +10 -0
  16. package/dist/core/postprocessing.d.ts.map +1 -0
  17. package/dist/core/postprocessing.js +70 -0
  18. package/dist/core/preprocessing.d.ts +14 -0
  19. package/dist/core/preprocessing.d.ts.map +1 -0
  20. package/dist/core/preprocessing.js +79 -0
  21. package/dist/index.d.ts +27 -0
  22. package/dist/index.d.ts.map +1 -0
  23. package/dist/index.js +31 -0
  24. package/dist/models/rtmpose.d.ts +25 -0
  25. package/dist/models/rtmpose.d.ts.map +1 -0
  26. package/dist/models/rtmpose.js +185 -0
  27. package/dist/models/rtmpose3d.d.ts +28 -0
  28. package/dist/models/rtmpose3d.d.ts.map +1 -0
  29. package/dist/models/rtmpose3d.js +184 -0
  30. package/dist/models/yolo12.d.ts +23 -0
  31. package/dist/models/yolo12.d.ts.map +1 -0
  32. package/dist/models/yolo12.js +165 -0
  33. package/dist/models/yolox.d.ts +18 -0
  34. package/dist/models/yolox.d.ts.map +1 -0
  35. package/dist/models/yolox.js +167 -0
  36. package/dist/solution/animalDetector.d.ts +229 -0
  37. package/dist/solution/animalDetector.d.ts.map +1 -0
  38. package/dist/solution/animalDetector.js +663 -0
  39. package/dist/solution/body.d.ts +16 -0
  40. package/dist/solution/body.d.ts.map +1 -0
  41. package/dist/solution/body.js +52 -0
  42. package/dist/solution/bodyWithFeet.d.ts +16 -0
  43. package/dist/solution/bodyWithFeet.d.ts.map +1 -0
  44. package/dist/solution/bodyWithFeet.js +52 -0
  45. package/dist/solution/customDetector.d.ts +137 -0
  46. package/dist/solution/customDetector.d.ts.map +1 -0
  47. package/dist/solution/customDetector.js +342 -0
  48. package/dist/solution/hand.d.ts +14 -0
  49. package/dist/solution/hand.d.ts.map +1 -0
  50. package/dist/solution/hand.js +20 -0
  51. package/dist/solution/index.d.ts +10 -0
  52. package/dist/solution/index.d.ts.map +1 -0
  53. package/dist/solution/index.js +9 -0
  54. package/dist/solution/objectDetector.d.ts +172 -0
  55. package/dist/solution/objectDetector.d.ts.map +1 -0
  56. package/dist/solution/objectDetector.js +606 -0
  57. package/dist/solution/pose3dDetector.d.ts +145 -0
  58. package/dist/solution/pose3dDetector.d.ts.map +1 -0
  59. package/dist/solution/pose3dDetector.js +611 -0
  60. package/dist/solution/poseDetector.d.ts +198 -0
  61. package/dist/solution/poseDetector.d.ts.map +1 -0
  62. package/dist/solution/poseDetector.js +622 -0
  63. package/dist/solution/poseTracker.d.ts +22 -0
  64. package/dist/solution/poseTracker.d.ts.map +1 -0
  65. package/dist/solution/poseTracker.js +106 -0
  66. package/dist/solution/wholebody.d.ts +19 -0
  67. package/dist/solution/wholebody.d.ts.map +1 -0
  68. package/dist/solution/wholebody.js +82 -0
  69. package/dist/solution/wholebody3d.d.ts +22 -0
  70. package/dist/solution/wholebody3d.d.ts.map +1 -0
  71. package/dist/solution/wholebody3d.js +75 -0
  72. package/dist/types/index.d.ts +52 -0
  73. package/dist/types/index.d.ts.map +1 -0
  74. package/dist/types/index.js +5 -0
  75. package/dist/visualization/draw.d.ts +57 -0
  76. package/dist/visualization/draw.d.ts.map +1 -0
  77. package/dist/visualization/draw.js +400 -0
  78. package/dist/visualization/skeleton/coco133.d.ts +350 -0
  79. package/dist/visualization/skeleton/coco133.d.ts.map +1 -0
  80. package/dist/visualization/skeleton/coco133.js +120 -0
  81. package/dist/visualization/skeleton/coco17.d.ts +180 -0
  82. package/dist/visualization/skeleton/coco17.d.ts.map +1 -0
  83. package/dist/visualization/skeleton/coco17.js +48 -0
  84. package/dist/visualization/skeleton/halpe26.d.ts +278 -0
  85. package/dist/visualization/skeleton/halpe26.d.ts.map +1 -0
  86. package/dist/visualization/skeleton/halpe26.js +70 -0
  87. package/dist/visualization/skeleton/hand21.d.ts +196 -0
  88. package/dist/visualization/skeleton/hand21.d.ts.map +1 -0
  89. package/dist/visualization/skeleton/hand21.js +51 -0
  90. package/dist/visualization/skeleton/index.d.ts +10 -0
  91. package/dist/visualization/skeleton/index.d.ts.map +1 -0
  92. package/dist/visualization/skeleton/index.js +9 -0
  93. package/dist/visualization/skeleton/openpose134.d.ts +357 -0
  94. package/dist/visualization/skeleton/openpose134.d.ts.map +1 -0
  95. package/dist/visualization/skeleton/openpose134.js +116 -0
  96. package/dist/visualization/skeleton/openpose18.d.ts +177 -0
  97. package/dist/visualization/skeleton/openpose18.d.ts.map +1 -0
  98. package/dist/visualization/skeleton/openpose18.js +47 -0
  99. package/docs/ANIMAL_DETECTOR.md +450 -0
  100. package/docs/CUSTOM_DETECTOR.md +568 -0
  101. package/docs/OBJECT_DETECTOR.md +373 -0
  102. package/docs/POSE3D_DETECTOR.md +458 -0
  103. package/docs/POSE_DETECTOR.md +442 -0
  104. package/examples/README.md +119 -0
  105. package/examples/index.html +746 -0
  106. package/package.json +51 -0
  107. package/playground/README.md +114 -0
  108. package/playground/app/favicon.ico +0 -0
  109. package/playground/app/globals.css +17 -0
  110. package/playground/app/layout.tsx +19 -0
  111. package/playground/app/page.tsx +1338 -0
  112. package/playground/eslint.config.mjs +18 -0
  113. package/playground/next.config.ts +34 -0
  114. package/playground/package-lock.json +6723 -0
  115. package/playground/package.json +27 -0
  116. package/playground/postcss.config.mjs +7 -0
  117. package/playground/tsconfig.json +34 -0
  118. package/src/core/base.ts +66 -0
  119. package/src/core/file.ts +141 -0
  120. package/src/core/modelCache.ts +189 -0
  121. package/src/core/posePostprocessing.ts +91 -0
  122. package/src/core/postprocessing.ts +93 -0
  123. package/src/core/preprocessing.ts +127 -0
  124. package/src/index.ts +69 -0
  125. package/src/models/rtmpose.ts +265 -0
  126. package/src/models/rtmpose3d.ts +289 -0
  127. package/src/models/yolo12.ts +220 -0
  128. package/src/models/yolox.ts +214 -0
  129. package/src/solution/animalDetector.ts +955 -0
  130. package/src/solution/body.ts +89 -0
  131. package/src/solution/bodyWithFeet.ts +89 -0
  132. package/src/solution/customDetector.ts +474 -0
  133. package/src/solution/hand.ts +52 -0
  134. package/src/solution/index.ts +10 -0
  135. package/src/solution/objectDetector.ts +816 -0
  136. package/src/solution/pose3dDetector.ts +890 -0
  137. package/src/solution/poseDetector.ts +892 -0
  138. package/src/solution/poseTracker.ts +172 -0
  139. package/src/solution/wholebody.ts +130 -0
  140. package/src/solution/wholebody3d.ts +125 -0
  141. package/src/types/index.ts +62 -0
  142. package/src/visualization/draw.ts +543 -0
  143. package/src/visualization/skeleton/coco133.ts +131 -0
  144. package/src/visualization/skeleton/coco17.ts +49 -0
  145. package/src/visualization/skeleton/halpe26.ts +71 -0
  146. package/src/visualization/skeleton/hand21.ts +52 -0
  147. package/src/visualization/skeleton/index.ts +10 -0
  148. package/src/visualization/skeleton/openpose134.ts +125 -0
  149. package/src/visualization/skeleton/openpose18.ts +48 -0
  150. package/tsconfig.json +32 -0
@@ -0,0 +1,955 @@
1
+ /**
2
+ * AnimalDetector - Animal detection and pose estimation API
3
+ * Supports 30 animal classes with ViTPose++ pose model
4
+ *
5
+ * @example
6
+ * ```typescript
7
+ * // Initialize with default models
8
+ * const detector = new AnimalDetector();
9
+ * await detector.init();
10
+ *
11
+ * // Detect animals
12
+ * const animals = await detector.detectFromCanvas(canvas);
13
+ * console.log(`Found ${animals.length} animals`);
14
+ *
15
+ * // With custom models
16
+ * const detector2 = new AnimalDetector({
17
+ * detModel: 'path/to/yolox_animal.onnx',
18
+ * poseModel: 'path/to/vitpose_animal.onnx',
19
+ * });
20
+ * ```
21
+ */
22
+
23
+ import * as ort from 'onnxruntime-web';
24
+ import { getCachedModel, isModelCached } from '../core/modelCache';
25
+
26
// Configure ONNX Runtime Web at module load time.
// NOTE(review): this mutates the global ort.env, so it affects every
// onnxruntime-web consumer in the page, not just AnimalDetector.
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0/dist/'; // pinned CDN build for the .wasm binaries — keep in sync with the installed onnxruntime-web version
ort.env.wasm.simd = true;   // enable SIMD-accelerated WASM kernels where the browser supports them
ort.env.wasm.proxy = false; // run inference on the calling thread (no worker proxy)
30
+
31
/**
 * 30 Animal class names supported by AnimalDetector.
 *
 * Order matters: the detector's numeric class id is an index into this array
 * (see postprocessYOLO), and setClasses() matches user-supplied names against
 * these exact strings (lowercased).
 */
export const ANIMAL_CLASSES: string[] = [
  'gorilla',
  'spider-monkey',
  'howling-monkey',
  'zebra',
  'elephant',
  'hippo',
  'raccon', // (sic) — keep this spelling; class-id mapping and setClasses() matching depend on it
  'rhino',
  'giraffe',
  'tiger',
  'deer',
  'lion',
  'panda',
  'cheetah',
  'black-bear',
  'polar-bear',
  'antelope',
  'fox',
  'buffalo',
  'cow',
  'wolf',
  'dog',
  'sheep',
  'cat',
  'horse',
  'rabbit',
  'pig',
  'chimpanzee',
  'monkey',
  'orangutan',
];
66
+
67
/**
 * Available ViTPose++ models for animal pose estimation.
 * All models are trained on 6 datasets and support 30 animal classes.
 *
 * inputSize is [height, width] — estimatePose destructures it that way when
 * building the NCHW input tensor.
 */
export const VITPOSE_MODELS = {
  /** ViTPose++-s: Fastest, 74.2 AP on AP10K */
  'vitpose-s': {
    name: 'ViTPose++-s',
    url: 'https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/onnx/apt36k/vitpose-s-apt36k.onnx',
    inputSize: [256, 192] as [number, number],
    ap: 74.2,
    description: 'Fastest inference, suitable for real-time applications',
  },
  /** ViTPose++-b: Balanced, 75.9 AP on AP10K */
  'vitpose-b': {
    name: 'ViTPose++-b',
    url: 'https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/onnx/apt36k/vitpose-b-apt36k.onnx',
    inputSize: [256, 192] as [number, number],
    ap: 75.9,
    description: 'Balanced speed and accuracy',
  },
  /** ViTPose++-l: Most accurate, 80.8 AP on AP10K */
  // NOTE(review): this entry is labelled 'l' but the URL points at
  // vitpose-h-apt36k.onnx — confirm whether vitpose-l-apt36k.onnx was intended
  // or the huge variant is deliberately served under this key.
  'vitpose-l': {
    name: 'ViTPose++-l',
    url: 'https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/onnx/apt36k/vitpose-h-apt36k.onnx',
    inputSize: [256, 192] as [number, number],
    ap: 80.8,
    description: 'Highest accuracy, slower inference',
  },
} as const;

export type VitPoseModelType = keyof typeof VITPOSE_MODELS;
99
+
100
/**
 * Configuration options for AnimalDetector.
 */
export interface AnimalDetectorConfig {
  /** Path/URL of the animal detection model (optional - uses default if not specified) */
  detModel?: string;
  /** Path/URL of the pose model; overrides poseModelType when given (optional) */
  poseModel?: string;
  /** ViTPose++ model variant (optional - uses vitpose-b if not specified) */
  poseModelType?: VitPoseModelType;
  /** Detection input size as [height, width] (default: [640, 640]) */
  detInputSize?: [number, number];
  /** Pose input size as [height, width] (default: [256, 192]) */
  poseInputSize?: [number, number];
  /** Detection confidence threshold (default: 0.5) */
  detConfidence?: number;
  /** NMS IoU threshold (default: 0.45) */
  nmsThreshold?: number;
  /** Pose keypoint confidence threshold (default: 0.3) */
  poseConfidence?: number;
  /** Execution backend (default: 'webgpu' — see DEFAULT_CONFIG) */
  backend?: 'wasm' | 'webgpu';
  /** Enable model caching (default: true) */
  cache?: boolean;
  /** Animal classes to detect, by name from ANIMAL_CLASSES (null = all) */
  classes?: string[] | null;
}
127
+
128
/**
 * Detected animal with bounding box and keypoints.
 */
export interface DetectedAnimal {
  /** Bounding box in source-image pixel coordinates, clamped to image bounds */
  bbox: {
    x1: number;
    y1: number;
    x2: number;
    y2: number;
    /** Detector confidence for this box */
    confidence: number;
  };
  /** Animal class ID — index into ANIMAL_CLASSES */
  classId: number;
  /** Animal class name from ANIMAL_CLASSES (falls back to `animal_<id>` for unknown ids) */
  className: string;
  /** 17 keypoints in COCO-17 order (see KEYPOINT_NAMES) */
  keypoints: AnimalKeypoint[];
  /** Per-keypoint confidence scores (0-1), same order as keypoints */
  scores: number[];
}
149
+
150
/**
 * Single keypoint with coordinates and visibility.
 */
export interface AnimalKeypoint {
  /** X in source-image pixels */
  x: number;
  /** Y in source-image pixels */
  y: number;
  /** Confidence score (0-1) */
  score: number;
  /** Visibility flag — presumably score >= poseConfidence; TODO confirm against postprocessPose */
  visible: boolean;
  /** Keypoint label (see KEYPOINT_NAMES) */
  name: string;
}
160
+
161
/**
 * Detection statistics, attached to the array returned by detect()
 * as a non-standard `stats` property.
 */
export interface AnimalDetectionStats {
  /** Number of detected animals */
  animalCount: number;
  /** Detections per class name */
  classCounts: Record<string, number>;
  /** Detection inference time, rounded to whole ms */
  detTime: number;
  /** Pose estimation time for all animals, rounded to whole ms */
  poseTime: number;
  /** Total processing time, rounded to whole ms */
  totalTime: number;
}
176
+
177
/**
 * COCO17 keypoint names (used for animal pose).
 * NOTE(review): these are the human COCO-17 labels reused as positional
 * labels for the 17 animal keypoints — verify ordering against the pose
 * model's documentation.
 */
const KEYPOINT_NAMES = [
  'nose',
  'left_eye',
  'right_eye',
  'left_ear',
  'right_ear',
  'left_shoulder',
  'right_shoulder',
  'left_elbow',
  'right_elbow',
  'left_wrist',
  'right_wrist',
  'left_hip',
  'right_hip',
  'left_knee',
  'right_knee',
  'left_ankle',
  'right_ankle',
];
199
+
200
/**
 * Default configuration - uses ViTPose++-b model.
 * poseModel is deliberately undefined here; the constructor resolves it
 * (and the matching poseInputSize) from poseModelType.
 */
const DEFAULT_CONFIG: Required<Omit<AnimalDetectorConfig, 'poseModel' | 'poseModelType'>> & {
  poseModel?: string;
  poseModelType: VitPoseModelType;
} = {
  detModel: 'https://huggingface.co/demon2233/rtmlib-ts/resolve/main/yolo/yolov12n.onnx',
  poseModel: undefined, // Will be set from poseModelType
  poseModelType: 'vitpose-b',
  detInputSize: [640, 640],   // [height, width]
  poseInputSize: [256, 192],  // [height, width]; replaced by the chosen ViTPose variant's inputSize in the constructor
  detConfidence: 0.5,
  nmsThreshold: 0.45,
  poseConfidence: 0.3,
  backend: 'webgpu', // Default to WebGPU for better performance
  cache: true,
  classes: null,
};
219
+
220
export class AnimalDetector {
  // Fully-resolved configuration (DEFAULT_CONFIG merged with user overrides).
  private config: Required<AnimalDetectorConfig>;
  private detSession: ort.InferenceSession | null = null;
  private poseSession: ort.InferenceSession | null = null;
  // Set once init() completes; detect() initializes lazily while false.
  private initialized = false;
  // Numeric class ids to keep, or null to accept every class.
  private classFilter: Set<number> | null = null;

  // Pre-allocated buffers (created in init(), reused across frames)
  private canvas: HTMLCanvasElement | null = null;        // detection letterbox canvas
  private ctx: CanvasRenderingContext2D | null = null;
  private poseCanvas: HTMLCanvasElement | null = null;    // per-animal pose crop canvas
  private poseCtx: CanvasRenderingContext2D | null = null;
  private poseTensorBuffer: Float32Array | null = null;   // reusable CHW float buffer for pose crops
  private detInputSize: [number, number] = [640, 640];
  private poseInputSize: [number, number] = [256, 192];
235
+
236
+ constructor(config: AnimalDetectorConfig = {}) {
237
+ // Resolve pose model URL from poseModelType if poseModel not explicitly provided
238
+ let finalConfig = { ...DEFAULT_CONFIG, ...config };
239
+
240
+ if (!config.poseModel && config.poseModelType) {
241
+ const vitposeModel = VITPOSE_MODELS[config.poseModelType];
242
+ finalConfig.poseModel = vitposeModel.url;
243
+ finalConfig.poseInputSize = vitposeModel.inputSize;
244
+ } else if (!config.poseModel && !config.poseModelType) {
245
+ // Use default vitpose-b
246
+ finalConfig.poseModel = VITPOSE_MODELS['vitpose-b'].url;
247
+ finalConfig.poseInputSize = VITPOSE_MODELS['vitpose-b'].inputSize;
248
+ }
249
+
250
+ this.config = finalConfig as Required<AnimalDetectorConfig>;
251
+ this.updateClassFilter();
252
+ }
253
+
254
+ /**
255
+ * Update class filter based on config
256
+ */
257
+ private updateClassFilter(): void {
258
+ if (!this.config.classes) {
259
+ this.classFilter = null;
260
+ return;
261
+ }
262
+
263
+ this.classFilter = new Set<number>();
264
+ this.config.classes.forEach((className) => {
265
+ const classId = ANIMAL_CLASSES.indexOf(className.toLowerCase());
266
+ if (classId !== -1) {
267
+ this.classFilter!.add(classId);
268
+ } else {
269
+ console.warn(`[AnimalDetector] Unknown class: ${className}`);
270
+ }
271
+ });
272
+ }
273
+
274
  /**
   * Set which animal classes to detect (names from ANIMAL_CLASSES;
   * null = detect all). Takes effect on the next detect() call.
   */
  setClasses(classes: string[] | null): void {
    this.config.classes = classes;
    this.updateClassFilter();
  }
281
+
282
+ /**
283
+ * Get list of available animal classes
284
+ */
285
+ getAvailableClasses(): string[] {
286
+ return [...ANIMAL_CLASSES];
287
+ }
288
+
289
+ /**
290
+ * Get information about the current ViTPose++ model
291
+ */
292
+ getPoseModelInfo() {
293
+ const modelType = (this.config as any).poseModelType as VitPoseModelType;
294
+ if (modelType && VITPOSE_MODELS[modelType]) {
295
+ return VITPOSE_MODELS[modelType];
296
+ }
297
+ return null;
298
+ }
299
+
300
+ /**
301
+ * Initialize both detection and pose models
302
+ */
303
+ async init(): Promise<void> {
304
+ if (this.initialized) return;
305
+
306
+ try {
307
+ // Load detection model
308
+ console.log(`[AnimalDetector] Loading detection model from: ${this.config.detModel}`);
309
+ let detBuffer: ArrayBuffer;
310
+
311
+ if (this.config.cache) {
312
+ const detCached = await isModelCached(this.config.detModel);
313
+ console.log(`[AnimalDetector] Det model cache ${detCached ? 'hit' : 'miss'}`);
314
+ detBuffer = await getCachedModel(this.config.detModel);
315
+ } else {
316
+ const detResponse = await fetch(this.config.detModel);
317
+ if (!detResponse.ok) {
318
+ throw new Error(`Failed to fetch det model: HTTP ${detResponse.status}`);
319
+ }
320
+ detBuffer = await detResponse.arrayBuffer();
321
+ }
322
+
323
+ this.detSession = await ort.InferenceSession.create(detBuffer, {
324
+ executionProviders: [this.config.backend],
325
+ graphOptimizationLevel: 'all',
326
+ });
327
+ console.log(`[AnimalDetector] Detection model loaded, size: ${(detBuffer.byteLength / 1024 / 1024).toFixed(2)} MB`);
328
+
329
+ // Load pose model
330
+ console.log(`[AnimalDetector] Loading pose model from: ${this.config.poseModel}`);
331
+ let poseBuffer: ArrayBuffer;
332
+
333
+ if (this.config.cache) {
334
+ const poseCached = await isModelCached(this.config.poseModel);
335
+ console.log(`[AnimalDetector] Pose model cache ${poseCached ? 'hit' : 'miss'}`);
336
+ poseBuffer = await getCachedModel(this.config.poseModel);
337
+ } else {
338
+ const poseResponse = await fetch(this.config.poseModel);
339
+ if (!poseResponse.ok) {
340
+ throw new Error(`Failed to fetch pose model: HTTP ${poseResponse.status}`);
341
+ }
342
+ poseBuffer = await poseResponse.arrayBuffer();
343
+ }
344
+
345
+ this.poseSession = await ort.InferenceSession.create(poseBuffer, {
346
+ executionProviders: [this.config.backend],
347
+ graphOptimizationLevel: 'all',
348
+ });
349
+ console.log(`[AnimalDetector] Pose model loaded, size: ${(poseBuffer.byteLength / 1024 / 1024).toFixed(2)} MB`);
350
+
351
+ // Pre-allocate resources
352
+ const [detW, detH] = this.config.detInputSize;
353
+ this.detInputSize = [detW, detH];
354
+
355
+ const [poseW, poseH] = this.config.poseInputSize;
356
+ this.poseInputSize = [poseW, poseH];
357
+
358
+ // Main canvas for detection
359
+ this.canvas = document.createElement('canvas');
360
+ this.canvas.width = detW;
361
+ this.canvas.height = detH;
362
+ this.ctx = this.canvas.getContext('2d', {
363
+ willReadFrequently: true,
364
+ alpha: false
365
+ })!;
366
+
367
+ // Pose crop canvas
368
+ this.poseCanvas = document.createElement('canvas');
369
+ this.poseCanvas.width = poseW;
370
+ this.poseCanvas.height = poseH;
371
+ this.poseCtx = this.poseCanvas.getContext('2d', {
372
+ willReadFrequently: true,
373
+ alpha: false
374
+ })!;
375
+
376
+ // Pre-allocate pose tensor buffer
377
+ this.poseTensorBuffer = new Float32Array(3 * poseW * poseH);
378
+
379
+ this.initialized = true;
380
+ console.log(`[AnimalDetector] ✅ Initialized (det:${detW}x${detH}, pose:${poseW}x${poseH})`);
381
+ } catch (error) {
382
+ console.error('[AnimalDetector] ❌ Initialization failed:', error);
383
+ throw error;
384
+ }
385
+ }
386
+
387
+ /**
388
+ * Detect animals from HTMLCanvasElement
389
+ */
390
+ async detectFromCanvas(canvas: HTMLCanvasElement): Promise<DetectedAnimal[]> {
391
+ const ctx = canvas.getContext('2d');
392
+ if (!ctx) {
393
+ throw new Error('Could not get 2D context from canvas');
394
+ }
395
+
396
+ const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
397
+ return this.detect(new Uint8Array(imageData.data.buffer), canvas.width, canvas.height);
398
+ }
399
+
400
+ /**
401
+ * Detect animals from HTMLVideoElement
402
+ */
403
+ async detectFromVideo(
404
+ video: HTMLVideoElement,
405
+ targetCanvas?: HTMLCanvasElement
406
+ ): Promise<DetectedAnimal[]> {
407
+ if (video.readyState < 2) {
408
+ throw new Error('Video not ready. Ensure video is loaded and playing.');
409
+ }
410
+
411
+ const canvas = targetCanvas || document.createElement('canvas');
412
+ canvas.width = video.videoWidth;
413
+ canvas.height = video.videoHeight;
414
+
415
+ const ctx = canvas.getContext('2d');
416
+ if (!ctx) {
417
+ throw new Error('Could not get 2D context from canvas');
418
+ }
419
+
420
+ ctx.drawImage(video, 0, 0, canvas.width, canvas.height);
421
+ const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
422
+
423
+ return this.detect(new Uint8Array(imageData.data.buffer), canvas.width, canvas.height);
424
+ }
425
+
426
+ /**
427
+ * Detect animals from HTMLImageElement
428
+ */
429
+ async detectFromImage(
430
+ image: HTMLImageElement,
431
+ targetCanvas?: HTMLCanvasElement
432
+ ): Promise<DetectedAnimal[]> {
433
+ if (!image.complete || !image.naturalWidth) {
434
+ throw new Error('Image not loaded. Ensure image is fully loaded.');
435
+ }
436
+
437
+ const canvas = targetCanvas || document.createElement('canvas');
438
+ canvas.width = image.naturalWidth;
439
+ canvas.height = image.naturalHeight;
440
+
441
+ const ctx = canvas.getContext('2d');
442
+ if (!ctx) {
443
+ throw new Error('Could not get 2D context from canvas');
444
+ }
445
+
446
+ ctx.drawImage(image, 0, 0);
447
+ const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
448
+
449
+ return this.detect(new Uint8Array(imageData.data.buffer), canvas.width, canvas.height);
450
+ }
451
+
452
+ /**
453
+ * Detect animals from ImageBitmap
454
+ */
455
+ async detectFromBitmap(
456
+ bitmap: ImageBitmap,
457
+ targetCanvas?: HTMLCanvasElement
458
+ ): Promise<DetectedAnimal[]> {
459
+ const canvas = targetCanvas || document.createElement('canvas');
460
+ canvas.width = bitmap.width;
461
+ canvas.height = bitmap.height;
462
+
463
+ const ctx = canvas.getContext('2d');
464
+ if (!ctx) {
465
+ throw new Error('Could not get 2D context from canvas');
466
+ }
467
+
468
+ ctx.drawImage(bitmap, 0, 0);
469
+ const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
470
+
471
+ return this.detect(new Uint8Array(imageData.data.buffer), canvas.width, canvas.height);
472
+ }
473
+
474
+ /**
475
+ * Detect animals from File
476
+ */
477
+ async detectFromFile(
478
+ file: File,
479
+ targetCanvas?: HTMLCanvasElement
480
+ ): Promise<DetectedAnimal[]> {
481
+ return new Promise((resolve, reject) => {
482
+ const img = new Image();
483
+ img.onload = async () => {
484
+ try {
485
+ const results = await this.detectFromImage(img, targetCanvas);
486
+ resolve(results);
487
+ } catch (error) {
488
+ reject(error);
489
+ }
490
+ };
491
+ img.onerror = () => reject(new Error('Failed to load image from file'));
492
+ img.src = URL.createObjectURL(file);
493
+ });
494
+ }
495
+
496
+ /**
497
+ * Detect animals from Blob
498
+ */
499
+ async detectFromBlob(
500
+ blob: Blob,
501
+ targetCanvas?: HTMLCanvasElement
502
+ ): Promise<DetectedAnimal[]> {
503
+ const bitmap = await createImageBitmap(blob);
504
+ const results = await this.detectFromBitmap(bitmap, targetCanvas);
505
+ bitmap.close();
506
+ return results;
507
+ }
508
+
509
  /**
   * Detect animals from raw RGBA image data.
   *
   * Pipeline: (1) YOLO detection over the whole frame, then (2) one pose
   * inference per detected animal, run sequentially. Initializes lazily if
   * init() has not been called.
   *
   * @param imageData RGBA bytes (e.g. from ImageData.data), length = width*height*4
   * @param width     image width in pixels
   * @param height    image height in pixels
   * @returns detected animals; the returned array also carries an
   *          AnimalDetectionStats object on a non-standard `stats` property.
   */
  async detect(
    imageData: Uint8Array,
    width: number,
    height: number
  ): Promise<DetectedAnimal[]> {
    if (!this.initialized) {
      await this.init();
    }

    const startTime = performance.now();

    // Step 1: Detect animals
    const detStart = performance.now();
    const detections = await this.detectAnimals(imageData, width, height);
    const detTime = performance.now() - detStart;

    // Step 2: Estimate poses for each animal (sequential — each crop is a
    // separate pose-session run)
    const poseStart = performance.now();
    const animals: DetectedAnimal[] = [];

    for (const det of detections) {
      const keypoints = await this.estimatePose(imageData, width, height, det.bbox);
      animals.push({
        bbox: det.bbox,
        classId: det.classId,
        className: det.className,
        keypoints,
        // scores duplicates the per-keypoint confidences for convenience
        scores: keypoints.map(k => k.score),
      });
    }

    const poseTime = performance.now() - poseStart;
    const totalTime = performance.now() - startTime;

    // Calculate per-class detection counts
    const classCounts: Record<string, number> = {};
    animals.forEach(animal => {
      classCounts[animal.className] = (classCounts[animal.className] || 0) + 1;
    });

    // Attach stats to the array itself (documented part of the return
    // contract; the cast is needed because arrays have no `stats` member).
    (animals as any).stats = {
      animalCount: animals.length,
      classCounts,
      detTime: Math.round(detTime),
      poseTime: Math.round(poseTime),
      totalTime: Math.round(totalTime),
    } as AnimalDetectionStats;

    return animals;
  }
563
+
564
  /**
   * Detect animals using YOLO.
   *
   * Letterboxes the frame to the model input, runs the detection session,
   * then maps boxes back to source-image pixels and applies NMS.
   * Assumes the model emits already-decoded rows of
   * (x1, y1, x2, y2, confidence, classId) — see postprocessYOLO.
   */
  private async detectAnimals(
    imageData: Uint8Array,
    width: number,
    height: number
  ): Promise<Array<{
    bbox: { x1: number; y1: number; x2: number; y2: number; confidence: number };
    classId: number;
    className: string;
  }>> {
    // detInputSize is [height, width]
    const [inputH, inputW] = this.config.detInputSize;

    const { tensor, paddingX, paddingY, scaleX, scaleY } = this.preprocessYOLO(
      imageData,
      width,
      height,
      [inputW, inputH]
    );

    // NCHW float input; feed name taken from the session rather than assumed
    const inputTensor = new ort.Tensor('float32', tensor, [1, 3, inputH, inputW]);
    const inputName = this.detSession!.inputNames[0];

    const feeds: Record<string, ort.Tensor> = {};
    feeds[inputName] = inputTensor;

    const results = await this.detSession!.run(feeds);
    const output = results[this.detSession!.outputNames[0]];

    // output.dims[1] is treated as the detection count — presumably the
    // export shape is [1, numDetections, 6]; verify against the model.
    return this.postprocessYOLO(
      output.data as Float32Array,
      output.dims[1],
      width,
      height,
      paddingX,
      paddingY,
      scaleX,
      scaleY
    );
  }
605
+
606
+ /**
607
+ * Estimate pose for a single animal
608
+ */
609
+ private async estimatePose(
610
+ imageData: Uint8Array,
611
+ imgWidth: number,
612
+ imgHeight: number,
613
+ bbox: { x1: number; y1: number; x2: number; y2: number; confidence: number }
614
+ ): Promise<AnimalKeypoint[]> {
615
+ const [inputH, inputW] = this.config.poseInputSize;
616
+
617
+ const { tensor, center, scale } = this.preprocessPose(
618
+ imageData,
619
+ imgWidth,
620
+ imgHeight,
621
+ bbox,
622
+ [inputW, inputH]
623
+ );
624
+
625
+ const inputTensor = new ort.Tensor('float32', tensor, [1, 3, inputH, inputW]);
626
+ const results = await this.poseSession!.run({ input: inputTensor });
627
+
628
+ return this.postprocessPose(
629
+ results.simcc_x.data as Float32Array,
630
+ results.simcc_y.data as Float32Array,
631
+ results.simcc_x.dims as number[],
632
+ results.simcc_y.dims as number[],
633
+ center,
634
+ scale
635
+ );
636
+ }
637
+
638
  /**
   * Letterbox the source frame into the detection input and build a CHW
   * float tensor with values scaled to [0, 1] (black padding bars).
   *
   * @param inputSize [width, height] of the model input (note: callers pass
   *                  [inputW, inputH], the reverse of config order)
   * @returns tensor plus the letterbox offsets (paddingX/paddingY, in input
   *          pixels) and scale factors that map drawn pixels back to source
   *          pixels (scaleX = imgWidth / drawWidth)
   */
  private preprocessYOLO(
    imageData: Uint8Array,
    imgWidth: number,
    imgHeight: number,
    inputSize: [number, number]
  ): {
    tensor: Float32Array;
    paddingX: number;
    paddingY: number;
    scaleX: number;
    scaleY: number;
  } {
    const [inputW, inputH] = inputSize;

    // Lazily (re)create the letterbox canvas if init() hasn't run
    if (!this.canvas || !this.ctx) {
      this.canvas = document.createElement('canvas');
      this.canvas.width = inputW;
      this.canvas.height = inputH;
      this.ctx = this.canvas.getContext('2d', { willReadFrequently: true, alpha: false })!;
    }

    const ctx = this.ctx;
    // Clear to black so the padding bars are zero-valued
    ctx.fillStyle = '#000000';
    ctx.fillRect(0, 0, inputW, inputH);

    // Fit the frame inside the input while preserving aspect ratio,
    // centering it along the padded axis
    const aspectRatio = imgWidth / imgHeight;
    const targetAspectRatio = inputW / inputH;

    let drawWidth: number, drawHeight: number, offsetX: number, offsetY: number;

    if (aspectRatio > targetAspectRatio) {
      drawWidth = inputW;
      drawHeight = Math.floor(inputW / aspectRatio);
      offsetX = 0;
      offsetY = Math.floor((inputH - drawHeight) / 2);
    } else {
      drawHeight = inputH;
      drawWidth = Math.floor(inputH * aspectRatio);
      offsetX = Math.floor((inputW - drawWidth) / 2);
      offsetY = 0;
    }

    // Raw RGBA bytes must go through a scratch canvas to be drawn scaled.
    // NOTE(review): this scratch canvas is allocated on every call — a
    // candidate for pre-allocation like this.canvas/this.poseCanvas.
    const srcCanvas = document.createElement('canvas');
    const srcCtx = srcCanvas.getContext('2d')!;
    srcCanvas.width = imgWidth;
    srcCanvas.height = imgHeight;

    const srcImageData = srcCtx.createImageData(imgWidth, imgHeight);
    srcImageData.data.set(imageData);
    srcCtx.putImageData(srcImageData, 0, 0);

    ctx.drawImage(srcCanvas, 0, 0, imgWidth, imgHeight, offsetX, offsetY, drawWidth, drawHeight);

    // Repack RGBA -> planar CHW (RGB only, alpha dropped), normalized to [0,1]
    const paddedData = ctx.getImageData(0, 0, inputW, inputH);
    const tensor = new Float32Array(inputW * inputH * 3);

    for (let i = 0; i < paddedData.data.length; i += 4) {
      const pixelIdx = i / 4;
      tensor[pixelIdx] = paddedData.data[i] / 255;
      tensor[pixelIdx + inputW * inputH] = paddedData.data[i + 1] / 255;
      tensor[pixelIdx + 2 * inputW * inputH] = paddedData.data[i + 2] / 255;
    }

    // Factors to map letterboxed input coordinates back to source pixels
    const scaleX = imgWidth / drawWidth;
    const scaleY = imgHeight / drawHeight;

    return { tensor, paddingX: offsetX, paddingY: offsetY, scaleX, scaleY };
  }
706
+
707
  /**
   * Decode detection output rows into class-filtered, confidence-filtered,
   * NMS-suppressed boxes in source-image pixel coordinates.
   *
   * Assumes each detection is 6 consecutive floats
   * (x1, y1, x2, y2, confidence, classId) in letterboxed input space —
   * i.e. the export includes a decode step; verify against the model.
   *
   * @param paddingX/paddingY letterbox offsets from preprocessYOLO
   * @param scaleX/scaleY     input→source scale factors from preprocessYOLO
   */
  private postprocessYOLO(
    output: Float32Array,
    numDetections: number,
    imgWidth: number,
    imgHeight: number,
    paddingX: number,
    paddingY: number,
    scaleX: number,
    scaleY: number
  ): Array<{
    bbox: { x1: number; y1: number; x2: number; y2: number; confidence: number };
    classId: number;
    className: string;
  }> {
    const detections: Array<{
      bbox: { x1: number; y1: number; x2: number; y2: number; confidence: number };
      classId: number;
      className: string;
    }> = [];

    for (let i = 0; i < numDetections; i++) {
      const idx = i * 6; // 6 floats per detection row
      const x1 = output[idx];
      const y1 = output[idx + 1];
      const x2 = output[idx + 2];
      const y2 = output[idx + 3];
      const confidence = output[idx + 4];
      const classId = Math.round(output[idx + 5]);

      // Drop low-confidence boxes and classes outside the active filter
      if (confidence < this.config.detConfidence) continue;
      if (this.classFilter && !this.classFilter.has(classId)) continue;

      // Undo the letterbox: remove padding, then scale back to source pixels
      const tx1 = (x1 - paddingX) * scaleX;
      const ty1 = (y1 - paddingY) * scaleY;
      const tx2 = (x2 - paddingX) * scaleX;
      const ty2 = (y2 - paddingY) * scaleY;

      detections.push({
        bbox: {
          // Clamp to the image bounds
          x1: Math.max(0, tx1),
          y1: Math.max(0, ty1),
          x2: Math.min(imgWidth, tx2),
          y2: Math.min(imgHeight, ty2),
          confidence,
        },
        classId,
        className: ANIMAL_CLASSES[classId] || `animal_${classId}`,
      });
    }

    // Final non-maximum suppression (applyNMS defined later in this file)
    return this.applyNMS(detections, this.config.nmsThreshold);
  }
759
+
760
+ private preprocessPose(
761
+ imageData: Uint8Array,
762
+ imgWidth: number,
763
+ imgHeight: number,
764
+ bbox: { x1: number; y1: number; x2: number; y2: number; confidence: number },
765
+ inputSize: [number, number]
766
+ ): { tensor: Float32Array; center: [number, number]; scale: [number, number] } {
767
+ const [inputW, inputH] = inputSize;
768
+ const bboxWidth = bbox.x2 - bbox.x1;
769
+ const bboxHeight = bbox.y2 - bbox.y1;
770
+
771
+ const center: [number, number] = [
772
+ bbox.x1 + bboxWidth / 2,
773
+ bbox.y1 + bboxHeight / 2,
774
+ ];
775
+
776
+ const bboxAspectRatio = bboxWidth / bboxHeight;
777
+ const modelAspectRatio = inputW / inputH;
778
+
779
+ let scaleW: number, scaleH: number;
780
+ if (bboxAspectRatio > modelAspectRatio) {
781
+ scaleW = bboxWidth * 1.25;
782
+ scaleH = scaleW / modelAspectRatio;
783
+ } else {
784
+ scaleH = bboxHeight * 1.25;
785
+ scaleW = scaleH * modelAspectRatio;
786
+ }
787
+
788
+ const scale: [number, number] = [scaleW, scaleH];
789
+
790
+ if (!this.poseCanvas || !this.poseCtx) {
791
+ this.poseCanvas = document.createElement('canvas');
792
+ this.poseCanvas.width = inputW;
793
+ this.poseCanvas.height = inputH;
794
+ this.poseCtx = this.poseCanvas.getContext('2d', { willReadFrequently: true, alpha: false })!;
795
+ this.poseTensorBuffer = new Float32Array(3 * inputW * inputH);
796
+ }
797
+
798
+ const ctx = this.poseCtx;
799
+ ctx.clearRect(0, 0, inputW, inputH);
800
+
801
+ const srcCanvas = document.createElement('canvas');
802
+ const srcCtx = srcCanvas.getContext('2d')!;
803
+ srcCanvas.width = imgWidth;
804
+ srcCanvas.height = imgHeight;
805
+
806
+ const srcImageData = srcCtx.createImageData(imgWidth, imgHeight);
807
+ srcImageData.data.set(imageData);
808
+ srcCtx.putImageData(srcImageData, 0, 0);
809
+
810
+ const srcX = center[0] - scaleW / 2;
811
+ const srcY = center[1] - scaleH / 2;
812
+ ctx.drawImage(srcCanvas, srcX, srcY, scaleW, scaleH, 0, 0, inputW, inputH);
813
+
814
+ const croppedData = ctx.getImageData(0, 0, inputW, inputH);
815
+ const tensor = this.poseTensorBuffer!;
816
+ const len = croppedData.data.length;
817
+ const planeSize = inputW * inputH;
818
+
819
+ const mean0 = 123.675, mean1 = 116.28, mean2 = 103.53;
820
+ const stdInv0 = 1 / 58.395, stdInv1 = 1 / 57.12, stdInv2 = 1 / 57.375;
821
+
822
+ for (let i = 0; i < len; i += 16) {
823
+ const p1 = i / 4, p2 = p1 + 1, p3 = p1 + 2, p4 = p1 + 3;
824
+
825
+ tensor[p1] = (croppedData.data[i] - mean0) * stdInv0;
826
+ tensor[p2] = (croppedData.data[i + 4] - mean0) * stdInv0;
827
+ tensor[p3] = (croppedData.data[i + 8] - mean0) * stdInv0;
828
+ tensor[p4] = (croppedData.data[i + 12] - mean0) * stdInv0;
829
+
830
+ tensor[p1 + planeSize] = (croppedData.data[i + 1] - mean1) * stdInv1;
831
+ tensor[p2 + planeSize] = (croppedData.data[i + 5] - mean1) * stdInv1;
832
+ tensor[p3 + planeSize] = (croppedData.data[i + 9] - mean1) * stdInv1;
833
+ tensor[p4 + planeSize] = (croppedData.data[i + 13] - mean1) * stdInv1;
834
+
835
+ tensor[p1 + planeSize * 2] = (croppedData.data[i + 2] - mean2) * stdInv2;
836
+ tensor[p2 + planeSize * 2] = (croppedData.data[i + 6] - mean2) * stdInv2;
837
+ tensor[p3 + planeSize * 2] = (croppedData.data[i + 10] - mean2) * stdInv2;
838
+ tensor[p4 + planeSize * 2] = (croppedData.data[i + 14] - mean2) * stdInv2;
839
+ }
840
+
841
+ return { tensor, center, scale };
842
+ }
843
+
844
+ private postprocessPose(
845
+ simccX: Float32Array,
846
+ simccY: Float32Array,
847
+ shapeX: number[],
848
+ shapeY: number[],
849
+ center: [number, number],
850
+ scale: [number, number]
851
+ ): AnimalKeypoint[] {
852
+ const numKeypoints = shapeX[1];
853
+ const wx = shapeX[2];
854
+ const wy = shapeY[2];
855
+
856
+ const keypoints: AnimalKeypoint[] = [];
857
+
858
+ for (let k = 0; k < numKeypoints; k++) {
859
+ let maxX = -Infinity, argmaxX = 0;
860
+ for (let i = 0; i < wx; i++) {
861
+ const val = simccX[k * wx + i];
862
+ if (val > maxX) { maxX = val; argmaxX = i; }
863
+ }
864
+
865
+ let maxY = -Infinity, argmaxY = 0;
866
+ for (let i = 0; i < wy; i++) {
867
+ const val = simccY[k * wy + i];
868
+ if (val > maxY) { maxY = val; argmaxY = i; }
869
+ }
870
+
871
+ const score = 0.5 * (maxX + maxY);
872
+ const visible = score > this.config.poseConfidence;
873
+
874
+ const normX = argmaxX / wx;
875
+ const normY = argmaxY / wy;
876
+
877
+ const x = (normX - 0.5) * scale[0] + center[0];
878
+ const y = (normY - 0.5) * scale[1] + center[1];
879
+
880
+ keypoints.push({
881
+ x,
882
+ y,
883
+ score,
884
+ visible,
885
+ name: KEYPOINT_NAMES[k] || `keypoint_${k}`,
886
+ });
887
+ }
888
+
889
+ return keypoints;
890
+ }
891
+
892
+ private applyNMS<T extends { bbox: { x1: number; y1: number; x2: number; y2: number; confidence: number } }>(
893
+ detections: T[],
894
+ iouThreshold: number
895
+ ): T[] {
896
+ if (detections.length === 0) return [];
897
+
898
+ detections.sort((a, b) => b.bbox.confidence - a.bbox.confidence);
899
+
900
+ const selected: T[] = [];
901
+ const used = new Set<number>();
902
+
903
+ for (let i = 0; i < detections.length; i++) {
904
+ if (used.has(i)) continue;
905
+
906
+ selected.push(detections[i]);
907
+ used.add(i);
908
+
909
+ for (let j = i + 1; j < detections.length; j++) {
910
+ if (used.has(j)) continue;
911
+
912
+ const iou = this.calculateIoU(detections[i].bbox, detections[j].bbox);
913
+ if (iou > iouThreshold) {
914
+ used.add(j);
915
+ }
916
+ }
917
+ }
918
+
919
+ return selected;
920
+ }
921
+
922
+ private calculateIoU(
923
+ box1: { x1: number; y1: number; x2: number; y2: number },
924
+ box2: { x1: number; y1: number; x2: number; y2: number }
925
+ ): number {
926
+ const x1 = Math.max(box1.x1, box2.x1);
927
+ const y1 = Math.max(box1.y1, box2.y1);
928
+ const x2 = Math.min(box1.x2, box2.x2);
929
+ const y2 = Math.min(box1.y2, box2.y2);
930
+
931
+ if (x2 <= x1 || y2 <= y1) return 0;
932
+
933
+ const intersection = (x2 - x1) * (y2 - y1);
934
+ const area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
935
+ const area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);
936
+ const union = area1 + area2 - intersection;
937
+
938
+ return intersection / union;
939
+ }
940
+
941
+ /**
942
+ * Dispose resources
943
+ */
944
+ dispose(): void {
945
+ if (this.detSession) {
946
+ this.detSession.release();
947
+ this.detSession = null;
948
+ }
949
+ if (this.poseSession) {
950
+ this.poseSession.release();
951
+ this.poseSession = null;
952
+ }
953
+ this.initialized = false;
954
+ }
955
+ }