ts-classify 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/LICENSE.md ADDED
@@ -0,0 +1,20 @@
1
+ Copyright (c) 2026 Bart Riepe
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining
4
+ a copy of this software and associated documentation files (the
5
+ "Software"), to deal in the Software without restriction, including
6
+ without limitation the rights to use, copy, modify, merge, publish,
7
+ distribute, sublicense, and/or sell copies of the Software, and to
8
+ permit persons to whom the Software is furnished to do so, subject to
9
+ the following conditions:
10
+
11
+ The above copyright notice and this permission notice shall be
12
+ included in all copies or substantial portions of the Software.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package/README.md CHANGED
@@ -1,148 +1,95 @@
1
- # SVM
1
+ # ts-classify
2
2
 
3
- A fast Support Vector Machine implementation in Rust, targeting both native and WebAssembly.
3
+ Fast text classification via SVM, compiled to WebAssembly.
4
4
 
5
5
  ## Features
6
6
 
7
- - **Binary SVM** with two solvers:
8
- - **Coordinate Descent** (default) - O(N) per pass, fast for large datasets
9
- - **SMO** - O(N²), useful for kernel SVMs (future)
10
- - **Multiclass classification**:
11
- - **One-vs-Rest (OvR)** - K classifiers for K classes, fast prediction
12
- - **One-vs-One (OvO)** - K(K-1)/2 classifiers, often more accurate
13
7
  - **Sparse vectors** - efficient for high-dimensional data (text, trigrams)
14
- - **Parallel training** via rayon (native only)
15
- - **Binary serialization** via bincode
16
- - **WebAssembly support** with JS-friendly API
8
+ - **Binary serialization** - train once, load instantly
9
+ - **Written in Rust**, compiled to WASM for near-native performance in the browser
17
10
 
18
11
  ## Installation
19
12
 
20
- ```toml
21
- [dependencies]
22
- svm = { path = "." }
13
+ ```bash
14
+ npm install ts-classify
23
15
  ```
24
16
 
25
- ## Usage
26
-
27
- ### Binary SVM
28
-
29
- ```rust
30
- use svm::{SVM, SparseVec};
17
+ ## Classifiers
31
18
 
32
- // Sparse format: Vec<(index, value)>
33
- let samples: Vec<SparseVec> = vec![
34
- vec![(0, -2.0), (1, 0.0)],
35
- vec![(0, -1.0), (1, 1.0)],
36
- vec![(0, 1.0), (1, 0.0)],
37
- vec![(0, 2.0), (1, 1.0)],
38
- ];
39
- let labels = vec![-1.0, -1.0, 1.0, 1.0];
19
+ All classifiers use the same sparse data format: flat `Float64Array` of `[index, value, index, value, ...]` pairs.
40
20
 
41
- let mut svm = SVM::new();
42
- svm.train_sparse(&samples, &labels);
21
+ ```javascript
22
+ import init, { JsOneVsRestSVM, JsMulticlassSVM, JsNearestCentroid, JsSVM } from 'ts-classify';
43
23
 
44
- assert_eq!(svm.predict_sparse(&[(0, -3.0)]), -1.0);
45
- assert_eq!(svm.predict_sparse(&[(0, 3.0)]), 1.0);
24
+ await init();
46
25
  ```
47
26
 
48
- ### Multiclass (One-vs-Rest)
49
-
50
- ```rust
51
- use svm::{OneVsRestSVM, SparseVec};
27
+ ### JsOneVsRestSVM
52
28
 
53
- let samples: Vec<SparseVec> = vec![
54
- vec![(0, -2.0), (1, -2.0)], // class 0
55
- vec![(0, 2.0), (1, -2.0)], // class 1
56
- vec![(0, 0.0), (1, 2.0)], // class 2
57
- // ... more samples
58
- ];
59
- let labels = vec![0, 1, 2, /* ... */];
29
+ Multiclass SVM using one-vs-rest strategy. Trains K classifiers for K classes. Best for large numbers of classes.
60
30
 
61
- let mut svm = OneVsRestSVM::new();
62
- svm.train_sparse(&samples, &labels);
63
-
64
- let predicted_class = svm.predict_sparse(&[(0, 0.0), (1, 3.0)]);
31
+ ```javascript
32
+ const svm = new JsOneVsRestSVM();
33
+ svm.train(samplesFlat, sampleLengths, labels);
34
+ const prediction = svm.predict(sampleFlat);
35
+ const margins = svm.margins(sampleFlat); // [class0, margin0, class1, margin1, ...]
65
36
  ```
66
37
 
67
- ### Custom Configuration
38
+ ### JsMulticlassSVM
68
39
 
69
- ```rust
70
- use svm::{SVM, TrainConfig, Solver};
40
+ Multiclass SVM using one-vs-one strategy. Trains K(K-1)/2 classifiers, predicts by voting. Often more accurate for smaller numbers of classes.
71
41
 
72
- let config = TrainConfig::coordinate_descent()
73
- .with_c(0.5) // regularization (default: 1.0)
74
- .with_max_iter(500) // max iterations (default: 1000)
75
- .with_tol(0.001); // tolerance (default: 0.01)
76
-
77
- let mut svm = SVM::new();
78
- svm.train_sparse_with_config(&samples, &labels, &config);
79
-
80
- // Or use SMO solver
81
- let config = TrainConfig::smo()
82
- .with_c(1.0)
83
- .with_max_iter(10000);
42
+ ```javascript
43
+ const svm = new JsMulticlassSVM();
44
+ svm.train(samplesFlat, sampleLengths, labels);
45
+ const prediction = svm.predict(sampleFlat);
84
46
  ```
85
47
 
86
- ### Serialization
48
+ ### JsNearestCentroid
87
49
 
88
- ```rust
89
- // Save
90
- let bytes = svm.to_bytes()?;
91
- std::fs::write("model.bin", &bytes)?;
50
+ Computes the centroid of each class and predicts by cosine similarity. Very fast to train, no hyperparameters.
92
51
 
93
- // Load
94
- let bytes = std::fs::read("model.bin")?;
95
- let svm = SVM::from_bytes(&bytes)?;
52
+ ```javascript
53
+ const nc = new JsNearestCentroid();
54
+ nc.train(samplesFlat, sampleLengths, labels);
55
+ const prediction = nc.predict(sampleFlat);
56
+ const similarities = nc.margins(sampleFlat); // [class0, sim0, class1, sim1, ...]
96
57
  ```
97
58
 
98
- ## WebAssembly
99
-
100
- Build for WASM (without parallel feature):
59
+ ### JsSVM
101
60
 
102
- ```bash
103
- cargo build --release --target wasm32-unknown-unknown --no-default-features
104
- ```
61
+ Binary SVM for two-class problems. Labels are -1.0 or 1.0.
105
62
 
106
- Or use wasm-pack:
107
-
108
- ```bash
109
- wasm-pack build --target web --no-default-features
63
+ ```javascript
64
+ const svm = new JsSVM();
65
+ svm.train(samplesFlat, sampleLengths, labels); // labels: Float64Array of -1.0/1.0
66
+ const prediction = svm.predict(sampleFlat); // returns -1.0 or 1.0
67
+ const margin = svm.margin(sampleFlat); // raw decision value
110
68
  ```
111
69
 
112
- ### JavaScript API
70
+ ### Data format
113
71
 
114
72
  ```javascript
115
- import { JsOneVsRestSVM } from 'svm';
116
-
117
- const svm = new JsOneVsRestSVM();
118
-
119
73
  // Flat sparse format: [idx, val, idx, val, ...]
120
- const sampleFlat = new Float64Array([0, -2.0, 1, 0.0]); // single sample
121
74
  const samplesFlat = new Float64Array([
122
75
  0, -2.0, 1, 0.0, // sample 0: 2 pairs
123
76
  0, 1.0, 1, 1.0, // sample 1: 2 pairs
124
77
  ]);
125
78
  const sampleLengths = new Uint32Array([2, 2]); // pairs per sample
126
- const labels = new Int32Array([0, 1]);
79
+ const labels = new Int32Array([0, 1]); // class IDs (or Float64Array for JsSVM)
80
+ ```
127
81
 
128
- svm.train(samplesFlat, sampleLengths, labels);
82
+ ### Serialization
129
83
 
130
- const prediction = svm.predict(sampleFlat);
84
+ All classifiers support binary serialization:
131
85
 
132
- // Serialize/deserialize
86
+ ```javascript
133
87
  const bytes = svm.to_bytes();
134
88
  const loaded = JsOneVsRestSVM.from_bytes(bytes);
135
89
  ```
136
90
 
137
91
  ## Performance
138
92
 
139
- Coordinate descent vs SMO on 50-class OvR:
140
-
141
- | Solver | Train Time |
142
- |--------|------------|
143
- | Coordinate Descent | 6ms |
144
- | SMO | 29s |
145
-
146
93
  Scaling (coordinate descent, OvR):
147
94
 
148
95
  | Classes | Classifiers | Train | Predict |
@@ -152,16 +99,18 @@ Scaling (coordinate descent, OvR):
152
99
  | 100 | 100 | 1.4ms | 1.6µs |
153
100
  | 5000 | 5000 | ~50-100ms | ~80µs |
154
101
 
155
- ## Feature Flags
102
+ ## Building from source
156
103
 
157
- - `parallel` (default) - Enable parallel training via rayon. Disable for WASM.
104
+ Requires [Rust](https://rustup.rs/) and [wasm-pack](https://rustwasm.github.io/wasm-pack/installer/).
158
105
 
159
- ```toml
160
- # Native with parallel
161
- svm = { path = "." }
106
+ ```bash
107
+ wasm-pack build --target web --no-default-features
108
+ ```
162
109
 
163
- # WASM (no parallel)
164
- svm = { path = ".", default-features = false }
110
+ Run tests (native):
111
+
112
+ ```bash
113
+ cargo test --release
165
114
  ```
166
115
 
167
116
  ## License
package/package.json CHANGED
@@ -2,7 +2,7 @@
2
2
  "name": "ts-classify",
3
3
  "type": "module",
4
4
  "description": "Fast text classification with SVM and Nearest Centroid (WebAssembly)",
5
- "version": "0.1.0",
5
+ "version": "0.2.0",
6
6
  "license": "MIT",
7
7
  "repository": {
8
8
  "type": "git",
package/ts_classify.d.ts CHANGED
@@ -1,6 +1,44 @@
1
1
  /* tslint:disable */
2
2
  /* eslint-disable */
3
3
 
4
+ /**
5
+ * JS-friendly wrapper for One-vs-One multiclass SVM.
6
+ */
7
+ export class JsMulticlassSVM {
8
+ free(): void;
9
+ [Symbol.dispose](): void;
10
+ /**
11
+ * Deserialize from bytes.
12
+ */
13
+ static from_bytes(bytes: Uint8Array): JsMulticlassSVM;
14
+ constructor();
15
+ /**
16
+ * Number of classes.
17
+ */
18
+ num_classes(): number;
19
+ /**
20
+ * Number of binary classifiers.
21
+ */
22
+ num_classifiers(): number;
23
+ /**
24
+ * Predict class for a single sparse sample.
25
+ */
26
+ predict(sample_flat: Float64Array): number;
27
+ /**
28
+ * Serialize to bytes.
29
+ */
30
+ to_bytes(): Uint8Array;
31
+ /**
32
+ * Train on sparse data.
33
+ * labels: integer class IDs (0, 1, 2, ...)
34
+ */
35
+ train(samples_flat: Float64Array, sample_lengths: Uint32Array, labels: Int32Array): void;
36
+ /**
37
+ * Train with custom parameters.
38
+ */
39
+ train_with_config(samples_flat: Float64Array, sample_lengths: Uint32Array, labels: Int32Array, use_smo: boolean, c: number, max_iter: number, tol: number): void;
40
+ }
41
+
4
42
  /**
5
43
  * JS-friendly wrapper for Nearest Centroid classifier.
6
44
  */
@@ -119,9 +157,18 @@ export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembl
119
157
 
120
158
  export interface InitOutput {
121
159
  readonly memory: WebAssembly.Memory;
160
+ readonly __wbg_jsmulticlasssvm_free: (a: number, b: number) => void;
122
161
  readonly __wbg_jsnearestcentroid_free: (a: number, b: number) => void;
123
162
  readonly __wbg_jsonevsrestsvm_free: (a: number, b: number) => void;
124
163
  readonly __wbg_jssvm_free: (a: number, b: number) => void;
164
+ readonly jsmulticlasssvm_from_bytes: (a: number, b: number) => [number, number, number];
165
+ readonly jsmulticlasssvm_new: () => number;
166
+ readonly jsmulticlasssvm_num_classes: (a: number) => number;
167
+ readonly jsmulticlasssvm_num_classifiers: (a: number) => number;
168
+ readonly jsmulticlasssvm_predict: (a: number, b: number, c: number) => number;
169
+ readonly jsmulticlasssvm_to_bytes: (a: number) => [number, number, number, number];
170
+ readonly jsmulticlasssvm_train: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => void;
171
+ readonly jsmulticlasssvm_train_with_config: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number) => void;
125
172
  readonly jsnearestcentroid_from_bytes: (a: number, b: number) => [number, number, number];
126
173
  readonly jsnearestcentroid_margins: (a: number, b: number, c: number) => [number, number];
127
174
  readonly jsnearestcentroid_new: () => number;
package/ts_classify.js CHANGED
@@ -1,5 +1,124 @@
1
1
  /* @ts-self-types="./ts_classify.d.ts" */
2
2
 
3
+ /**
4
+ * JS-friendly wrapper for One-vs-One multiclass SVM.
5
+ */
6
+ export class JsMulticlassSVM {
7
+ static __wrap(ptr) {
8
+ ptr = ptr >>> 0;
9
+ const obj = Object.create(JsMulticlassSVM.prototype);
10
+ obj.__wbg_ptr = ptr;
11
+ JsMulticlassSVMFinalization.register(obj, obj.__wbg_ptr, obj);
12
+ return obj;
13
+ }
14
+ __destroy_into_raw() {
15
+ const ptr = this.__wbg_ptr;
16
+ this.__wbg_ptr = 0;
17
+ JsMulticlassSVMFinalization.unregister(this);
18
+ return ptr;
19
+ }
20
+ free() {
21
+ const ptr = this.__destroy_into_raw();
22
+ wasm.__wbg_jsmulticlasssvm_free(ptr, 0);
23
+ }
24
+ /**
25
+ * Deserialize from bytes.
26
+ * @param {Uint8Array} bytes
27
+ * @returns {JsMulticlassSVM}
28
+ */
29
+ static from_bytes(bytes) {
30
+ const ptr0 = passArray8ToWasm0(bytes, wasm.__wbindgen_malloc);
31
+ const len0 = WASM_VECTOR_LEN;
32
+ const ret = wasm.jsmulticlasssvm_from_bytes(ptr0, len0);
33
+ if (ret[2]) {
34
+ throw takeFromExternrefTable0(ret[1]);
35
+ }
36
+ return JsMulticlassSVM.__wrap(ret[0]);
37
+ }
38
+ constructor() {
39
+ const ret = wasm.jsmulticlasssvm_new();
40
+ this.__wbg_ptr = ret >>> 0;
41
+ JsMulticlassSVMFinalization.register(this, this.__wbg_ptr, this);
42
+ return this;
43
+ }
44
+ /**
45
+ * Number of classes.
46
+ * @returns {number}
47
+ */
48
+ num_classes() {
49
+ const ret = wasm.jsmulticlasssvm_num_classes(this.__wbg_ptr);
50
+ return ret >>> 0;
51
+ }
52
+ /**
53
+ * Number of binary classifiers.
54
+ * @returns {number}
55
+ */
56
+ num_classifiers() {
57
+ const ret = wasm.jsmulticlasssvm_num_classifiers(this.__wbg_ptr);
58
+ return ret >>> 0;
59
+ }
60
+ /**
61
+ * Predict class for a single sparse sample.
62
+ * @param {Float64Array} sample_flat
63
+ * @returns {number}
64
+ */
65
+ predict(sample_flat) {
66
+ const ptr0 = passArrayF64ToWasm0(sample_flat, wasm.__wbindgen_malloc);
67
+ const len0 = WASM_VECTOR_LEN;
68
+ const ret = wasm.jsmulticlasssvm_predict(this.__wbg_ptr, ptr0, len0);
69
+ return ret;
70
+ }
71
+ /**
72
+ * Serialize to bytes.
73
+ * @returns {Uint8Array}
74
+ */
75
+ to_bytes() {
76
+ const ret = wasm.jsmulticlasssvm_to_bytes(this.__wbg_ptr);
77
+ if (ret[3]) {
78
+ throw takeFromExternrefTable0(ret[2]);
79
+ }
80
+ var v1 = getArrayU8FromWasm0(ret[0], ret[1]).slice();
81
+ wasm.__wbindgen_free(ret[0], ret[1] * 1, 1);
82
+ return v1;
83
+ }
84
+ /**
85
+ * Train on sparse data.
86
+ * labels: integer class IDs (0, 1, 2, ...)
87
+ * @param {Float64Array} samples_flat
88
+ * @param {Uint32Array} sample_lengths
89
+ * @param {Int32Array} labels
90
+ */
91
+ train(samples_flat, sample_lengths, labels) {
92
+ const ptr0 = passArrayF64ToWasm0(samples_flat, wasm.__wbindgen_malloc);
93
+ const len0 = WASM_VECTOR_LEN;
94
+ const ptr1 = passArray32ToWasm0(sample_lengths, wasm.__wbindgen_malloc);
95
+ const len1 = WASM_VECTOR_LEN;
96
+ const ptr2 = passArray32ToWasm0(labels, wasm.__wbindgen_malloc);
97
+ const len2 = WASM_VECTOR_LEN;
98
+ wasm.jsmulticlasssvm_train(this.__wbg_ptr, ptr0, len0, ptr1, len1, ptr2, len2);
99
+ }
100
+ /**
101
+ * Train with custom parameters.
102
+ * @param {Float64Array} samples_flat
103
+ * @param {Uint32Array} sample_lengths
104
+ * @param {Int32Array} labels
105
+ * @param {boolean} use_smo
106
+ * @param {number} c
107
+ * @param {number} max_iter
108
+ * @param {number} tol
109
+ */
110
+ train_with_config(samples_flat, sample_lengths, labels, use_smo, c, max_iter, tol) {
111
+ const ptr0 = passArrayF64ToWasm0(samples_flat, wasm.__wbindgen_malloc);
112
+ const len0 = WASM_VECTOR_LEN;
113
+ const ptr1 = passArray32ToWasm0(sample_lengths, wasm.__wbindgen_malloc);
114
+ const len1 = WASM_VECTOR_LEN;
115
+ const ptr2 = passArray32ToWasm0(labels, wasm.__wbindgen_malloc);
116
+ const len2 = WASM_VECTOR_LEN;
117
+ wasm.jsmulticlasssvm_train_with_config(this.__wbg_ptr, ptr0, len0, ptr1, len1, ptr2, len2, use_smo, c, max_iter, tol);
118
+ }
119
+ }
120
+ if (Symbol.dispose) JsMulticlassSVM.prototype[Symbol.dispose] = JsMulticlassSVM.prototype.free;
121
+
3
122
  /**
4
123
  * JS-friendly wrapper for Nearest Centroid classifier.
5
124
  */
@@ -377,6 +496,9 @@ function __wbg_get_imports() {
377
496
  };
378
497
  }
379
498
 
499
+ const JsMulticlassSVMFinalization = (typeof FinalizationRegistry === 'undefined')
500
+ ? { register: () => {}, unregister: () => {} }
501
+ : new FinalizationRegistry(ptr => wasm.__wbg_jsmulticlasssvm_free(ptr >>> 0, 1));
380
502
  const JsNearestCentroidFinalization = (typeof FinalizationRegistry === 'undefined')
381
503
  ? { register: () => {}, unregister: () => {} }
382
504
  : new FinalizationRegistry(ptr => wasm.__wbg_jsnearestcentroid_free(ptr >>> 0, 1));
Binary file