learning_model 1.0.38 → 1.0.39

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,6 @@ import * as tf from '@tensorflow/tfjs';
2
2
  import { io } from '@tensorflow/tfjs-core';
3
3
  interface LearningInterface {
4
4
  model: tf.LayersModel | null;
5
- labels: string[];
6
5
  isRunning: boolean;
7
6
  isReady: boolean;
8
7
  onProgress(progress: number): void;
@@ -7,21 +7,18 @@ declare class LearningMobilenet implements LearningInterface {
7
7
  batchSize: number;
8
8
  learningRate: number;
9
9
  validateRate: number;
10
- labels: string[];
11
- modelURL: string;
12
10
  isRunning: boolean;
13
11
  isReady: boolean;
14
12
  isTrainedDone: boolean;
15
13
  limitSize: number;
16
14
  mobilenetModule: tf.LayersModel | null;
17
- imageTensors: tf.Tensor<tf.Rank>[];
15
+ imageExamples: Float32Array[][];
16
+ classNumber: string[];
18
17
  readonly MOBILE_NET_INPUT_WIDTH = 224;
19
18
  readonly MOBILE_NET_INPUT_HEIGHT = 224;
20
19
  readonly MOBILE_NET_INPUT_CHANNEL = 3;
21
20
  readonly IMAGE_NORMALIZATION_FACTOR = 255;
22
- constructor({ modelURL, // 디폴트 mobilenet 이미지
23
- epochs, batchSize, limitSize, learningRate, validateRate, }?: {
24
- modelURL?: string;
21
+ constructor({ epochs, batchSize, limitSize, learningRate, validateRate, }?: {
25
22
  epochs?: number;
26
23
  batchSize?: number;
27
24
  limitSize?: number;
@@ -38,6 +35,8 @@ declare class LearningMobilenet implements LearningInterface {
38
35
  jsonURL: string;
39
36
  labels: Array<string>;
40
37
  }): Promise<void>;
38
+ private registerClassNumber;
39
+ private _convertToTfDataset;
41
40
  addData(label: string, data: any): Promise<void>;
42
41
  init(): Promise<void>;
43
42
  train(): Promise<tf.History>;
@@ -45,9 +44,6 @@ declare class LearningMobilenet implements LearningInterface {
45
44
  saveModel(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<void>;
46
45
  running(): boolean;
47
46
  ready(): boolean;
48
- private _preprocessedTargetData;
49
- private _preprocessedInputData;
50
- private loadModel;
51
47
  private _createModel;
52
48
  }
53
49
  export default LearningMobilenet;
@@ -38,10 +38,14 @@ var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, ge
38
38
  };
39
39
  Object.defineProperty(exports, "__esModule", { value: true });
40
40
  const tf = __importStar(require("@tensorflow/tfjs"));
41
+ const tfjs_1 = require("@tensorflow/tfjs");
42
+ const tf_1 = require("../utils/tf");
43
+ const canvas_1 = require("../utils/canvas");
44
+ const dataset_1 = require("../utils/dataset");
41
45
  class LearningMobilenet {
42
- constructor({ modelURL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json', // 디폴트 mobilenet 이미지
43
- epochs = 10, batchSize = 16, limitSize = 2, learningRate = 0.001, validateRate = 0.2, } = {}) {
44
- this.imageTensors = [];
46
+ constructor({ epochs = 50, batchSize = 16, limitSize = 2, learningRate = 0.001, validateRate = 0.15, } = {}) {
47
+ this.imageExamples = [];
48
+ this.classNumber = [];
45
49
  this.MOBILE_NET_INPUT_WIDTH = 224;
46
50
  this.MOBILE_NET_INPUT_HEIGHT = 224;
47
51
  this.MOBILE_NET_INPUT_CHANNEL = 3;
@@ -58,14 +62,12 @@ class LearningMobilenet {
58
62
  this.batchSize = batchSize;
59
63
  this.learningRate = learningRate;
60
64
  this.validateRate = validateRate;
61
- this.labels = [];
62
- this.modelURL = modelURL;
63
65
  this.isRunning = false;
64
66
  this.isReady = false;
65
67
  this.isTrainedDone = false;
66
68
  this.limitSize = limitSize;
67
69
  this.mobilenetModule = null;
68
- this.imageTensors = [];
70
+ this.classNumber = [];
69
71
  }
70
72
  //
71
73
  // 기존의 모델 로드
@@ -76,7 +78,9 @@ class LearningMobilenet {
76
78
  }
77
79
  try {
78
80
  this.model = yield tf.loadLayersModel(jsonURL);
79
- this.labels = labels;
81
+ for (var i = 0; i < labels.length; i++) {
82
+ this.registerClassNumber(labels[i]);
83
+ }
80
84
  this.isReady = true;
81
85
  this.model.summary();
82
86
  }
@@ -86,33 +90,68 @@ class LearningMobilenet {
86
90
  }
87
91
  });
88
92
  }
93
+ registerClassNumber(value) {
94
+ // 중복 값의 인덱스를 찾습니다.
95
+ const existingIndex = this.classNumber.indexOf(value);
96
+ // 중복 값이 있다면 해당 인덱스를 반환합니다.
97
+ if (existingIndex !== -1) {
98
+ return existingIndex;
99
+ }
100
+ // 중복 값이 없다면 새로운 항목을 추가하고 그 인덱스를 반환합니다.
101
+ this.classNumber.push(value);
102
+ return this.classNumber.length - 1;
103
+ }
104
+ _convertToTfDataset() {
105
+ for (let i = 0; i < this.imageExamples.length; i++) {
106
+ this.imageExamples[i] = (0, dataset_1.fisherYates)(this.imageExamples[i]);
107
+ }
108
+ const trainDataset = [];
109
+ const validationDataset = [];
110
+ for (let i = 0; i < this.imageExamples.length; i++) {
111
+ const classLength = this.imageExamples[i].length;
112
+ // 클래스의 전체 데이터 수를 사용하여 학습 및 검증 데이터 수 계산
113
+ const numValidation = Math.ceil(this.validateRate * classLength);
114
+ const numTrain = classLength - numValidation;
115
+ // One-Hot 인코딩을 사용하여 라벨 생성
116
+ const y = (0, dataset_1.flatOneHot)(i, this.classNumber.length);
117
+ // numTrain과 numValidation에 따라 데이터를 학습 및 검증 데이터로 분할
118
+ const classTrain = this.imageExamples[i].slice(0, numTrain).map(dataArray => ({ data: dataArray, label: y }));
119
+ trainDataset.push(...classTrain);
120
+ const classValidation = this.imageExamples[i].slice(numTrain, numTrain + numValidation).map(dataArray => ({ data: dataArray, label: y }));
121
+ validationDataset.push(...classValidation);
122
+ }
123
+ // Shuffle entire datasets
124
+ const shuffledTrainDataset = (0, dataset_1.fisherYates)(trainDataset);
125
+ const shuffledValidationDataset = (0, dataset_1.fisherYates)(validationDataset);
126
+ // Convert to tf.data.Dataset
127
+ const trainX = tf.data.array(shuffledTrainDataset.map(sample => sample.data));
128
+ const validationX = tf.data.array(shuffledValidationDataset.map(sample => sample.data));
129
+ const trainY = tf.data.array(shuffledTrainDataset.map(sample => sample.label));
130
+ const validationY = tf.data.array(shuffledValidationDataset.map(sample => sample.label));
131
+ return {
132
+ trainDataset: tf.data.zip({ xs: trainX, ys: trainY }),
133
+ validationDataset: tf.data.zip({ xs: validationX, ys: validationY })
134
+ };
135
+ }
89
136
  // 학습 데이타 등록
90
137
  addData(label, data) {
91
138
  return __awaiter(this, void 0, void 0, function* () {
92
139
  try {
93
- const imgTensor = tf.tidy(() => {
94
- if (this.mobilenetModule !== null) {
95
- console.log('predict before', data);
96
- const t = tf.browser.fromPixels(data);
97
- // Resize the image to match the model's input shape
98
- const resizedTensor = tf.image.resizeBilinear(t, [224, 224]);
99
- const normalizedTensor = resizedTensor.div(255.0);
100
- // Add an extra dimension to the tensor
101
- const expandedTensor = normalizedTensor.expandDims(0);
102
- console.log('predict extend', expandedTensor);
103
- const predict = this.mobilenetModule.predict(expandedTensor);
104
- console.log('predict', predict);
105
- return predict;
140
+ if (this.mobilenetModule !== null) {
141
+ const cap = (0, tf_1.isTensor)(data) ? data : (0, tf_1.capture)(data, false);
142
+ const predict = this.mobilenetModule.predict(cap);
143
+ const activation = predict.dataSync();
144
+ const classIndex = this.registerClassNumber(label);
145
+ if (!this.imageExamples[classIndex]) {
146
+ this.imageExamples[classIndex] = [];
106
147
  }
107
- else {
108
- throw new Error('mobilenetModule is null');
148
+ this.imageExamples[classIndex].push(activation);
149
+ if (this.classNumber.length >= this.limitSize) {
150
+ this.isReady = true;
109
151
  }
110
- });
111
- console.log('imgTensor', imgTensor);
112
- this.imageTensors.push(imgTensor);
113
- this.labels.push(label);
114
- if (this.labels.length >= this.limitSize) {
115
- this.isReady = true;
152
+ }
153
+ else {
154
+ throw new Error('mobilenetModule is null');
116
155
  }
117
156
  return Promise.resolve();
118
157
  }
@@ -125,7 +164,7 @@ class LearningMobilenet {
125
164
  init() {
126
165
  return __awaiter(this, void 0, void 0, function* () {
127
166
  try {
128
- this.mobilenetModule = yield this.loadModel();
167
+ this.mobilenetModule = yield (0, tf_1.loadModel)();
129
168
  }
130
169
  catch (error) {
131
170
  console.log('init Error', error);
@@ -144,12 +183,10 @@ class LearningMobilenet {
144
183
  onTrainBegin: (log) => {
145
184
  this.isTrainedDone = false;
146
185
  this.onTrainBegin(log);
147
- console.log('Training has started.');
148
186
  },
149
187
  onTrainEnd: (log) => {
150
188
  this.isTrainedDone = true;
151
189
  this.onTrainEnd(log);
152
- console.log('Training has ended.');
153
190
  this.isRunning = false;
154
191
  },
155
192
  onBatchBegin: (batch, logs) => {
@@ -163,8 +200,6 @@ class LearningMobilenet {
163
200
  },
164
201
  onEpochEnd: (epoch, logs) => {
165
202
  this.onEpochEnd(epoch, logs);
166
- console.log(`Epoch ${epoch + 1} has ended.`);
167
- console.log('Loss:', logs);
168
203
  this.onLoss(logs.loss);
169
204
  this.onProgress(epoch + 1);
170
205
  this.onEvents({
@@ -175,21 +210,24 @@ class LearningMobilenet {
175
210
  };
176
211
  try {
177
212
  this.isRunning = true;
178
- if (this.labels.length < this.limitSize) {
213
+ if (this.classNumber.length < this.limitSize) {
179
214
  return Promise.reject(new Error('Please train Data need over 2 data length'));
180
215
  }
181
- this.model = yield this._createModel(this.labels.length);
182
- const inputData = this._preprocessedInputData(this.model);
183
- console.log('this.imageTensors', this.imageTensors, inputData);
184
- const targetData = this._preprocessedTargetData();
185
- console.log('targetData', targetData);
186
- const history = yield this.model.fit(inputData, targetData, {
216
+ const datasets = this._convertToTfDataset();
217
+ const trainData = datasets.trainDataset.batch(this.batchSize);
218
+ const validationData = datasets.validationDataset.batch(this.batchSize);
219
+ const optimizer = tf.train.adam(this.learningRate);
220
+ const trainModel = yield this._createModel(optimizer);
221
+ const history = yield trainModel.fitDataset(trainData, {
187
222
  epochs: this.epochs,
188
- batchSize: this.batchSize,
189
- validationSplit: this.validateRate,
223
+ validationData: validationData,
190
224
  callbacks: customCallback
191
225
  });
192
- console.log('Model training completed', history);
226
+ const jointModel = tf.sequential();
227
+ jointModel.add(this.mobilenetModule);
228
+ jointModel.add(trainModel);
229
+ this.model = jointModel;
230
+ optimizer.dispose();
193
231
  return history;
194
232
  }
195
233
  catch (error) {
@@ -206,36 +244,18 @@ class LearningMobilenet {
206
244
  return Promise.reject(new Error('Model is Null'));
207
245
  }
208
246
  try {
209
- const imgTensor = tf.tidy(() => {
210
- if (this.mobilenetModule !== null) {
211
- console.log('predict before', data);
212
- const t = tf.browser.fromPixels(data);
213
- // Resize the image to match the model's input shape
214
- const resizedTensor = tf.image.resizeBilinear(t, [224, 224]);
215
- const normalizedTensor = resizedTensor.div(255.0);
216
- // Add an extra dimension to the tensor
217
- const expandedTensor = normalizedTensor.expandDims(0);
218
- console.log('predict extend', expandedTensor);
219
- const predict = this.mobilenetModule.predict(expandedTensor);
220
- console.log('predict', predict);
221
- return predict;
222
- }
223
- else {
224
- throw new Error('mobilenetModule is null');
225
- }
247
+ const classProbabilities = new Map();
248
+ const croppedImage = (0, canvas_1.cropTo)(data, 224, false);
249
+ const logits = tf.tidy(() => {
250
+ const captured = (0, tf_1.capture)(croppedImage, false);
251
+ return this.model.predict(captured);
226
252
  });
227
- const predictions = this.model.predict(imgTensor);
228
- const predictedClass = predictions.as1D().argMax();
229
- const classId = (yield predictedClass.data())[0];
230
- console.log('classId', classId, this.labels[classId]);
231
- const predictionsData = yield predictions.data(); // 예측 텐서의 데이터를 비동기로 가져옴
232
- const classProbabilities = new Map(); // 클래스별 확률 누적값을 저장할 맵
233
- console.log('predictionsData', predictionsData);
253
+ const values = yield logits.data();
234
254
  const EPSILON = 1e-6; // 매우 작은 값을 표현하기 위한 엡실론
235
- for (let i = 0; i < predictionsData.length; i++) {
236
- let probability = Math.max(0, Math.min(1, predictionsData[i])); // 확률 값을 0과 1 사이로 조정
255
+ for (let i = 0; i < values.length; i++) {
256
+ let probability = Math.max(0, Math.min(1, values[i])); // 확률 값을 0과 1 사이로 조정
237
257
  probability = probability < EPSILON ? 0 : probability; // 매우 작은 확률 값을 0으로 간주
238
- const className = this.labels[i]; // 클래스 이름
258
+ const className = this.classNumber[i]; // 클래스 이름
239
259
  const existingProbability = classProbabilities.get(className);
240
260
  if (existingProbability !== undefined) {
241
261
  classProbabilities.set(className, existingProbability + probability);
@@ -244,9 +264,7 @@ class LearningMobilenet {
244
264
  classProbabilities.set(className, probability);
245
265
  }
246
266
  }
247
- console.log('Class Probabilities:', classProbabilities);
248
- predictedClass.dispose();
249
- predictions.dispose();
267
+ (0, tfjs_1.dispose)(logits);
250
268
  return classProbabilities;
251
269
  }
252
270
  catch (error) {
@@ -258,7 +276,6 @@ class LearningMobilenet {
258
276
  saveModel(handlerOrURL, config) {
259
277
  var _a;
260
278
  return __awaiter(this, void 0, void 0, function* () {
261
- console.log('saved model');
262
279
  if (!this.isTrainedDone) {
263
280
  return Promise.reject(new Error('Train is not done status'));
264
281
  }
@@ -272,71 +289,36 @@ class LearningMobilenet {
272
289
  ready() {
273
290
  return this.isReady;
274
291
  }
275
- // target 라벨 데이타
276
- _preprocessedTargetData() {
277
- // 라벨 unique 처리 & 배열 리턴
278
- console.log('uniqueLabels.length', this.labels, this.labels.length);
279
- const labelIndices = this.labels.map((label) => this.labels.indexOf(label));
280
- console.log('labelIndices', labelIndices);
281
- const oneHotEncode = tf.oneHot(tf.tensor1d(labelIndices, 'int32'), this.labels.length);
282
- console.log('oneHotEncode', oneHotEncode);
283
- return oneHotEncode;
284
- }
285
- // 입력 이미지 데이타
286
- _preprocessedInputData(model) {
287
- // 이미지 배열을 배치로 변환 - [null, 224, 224, 3]
288
- const inputShape = model.inputs[0].shape;
289
- console.log('inputShape', inputShape);
290
- // inputShape를 이와 같이 포멧 맞춘다. for reshape to [224, 224, 3]
291
- const inputShapeArray = inputShape.slice(1);
292
- console.log('inputShapeArray', inputShapeArray);
293
- const inputBatch = tf.stack(this.imageTensors.map((image) => {
294
- return tf.reshape(image, inputShapeArray);
295
- }));
296
- return inputBatch;
297
- }
298
- loadModel() {
299
- return __awaiter(this, void 0, void 0, function* () {
300
- const load_model = yield tf.loadLayersModel(this.modelURL);
301
- load_model.summary();
302
- const layer = load_model.getLayer('conv_pw_13_relu');
303
- const truncatedModel = tf.model({
304
- inputs: load_model.inputs,
305
- outputs: layer.output
306
- });
307
- return truncatedModel;
308
- });
309
- }
310
292
  // 모델 저장
311
- _createModel(numClasses) {
293
+ _createModel(optimizer) {
312
294
  return __awaiter(this, void 0, void 0, function* () {
313
295
  try {
314
- const truncatedModel = yield this.loadModel();
315
296
  // 입력 이미지 크기에 맞게 모델 구조 수정
316
- const model = tf.sequential();
317
- // 추가적인 합성곱 층
318
- model.add(tf.layers.conv2d({
319
- filters: 64,
320
- kernelSize: 3,
321
- activation: 'relu',
322
- padding: 'same',
323
- inputShape: truncatedModel.outputs[0].shape.slice(1)
324
- }));
325
- model.add(tf.layers.flatten());
326
- // 더 깊은 밀집층
327
- model.add(tf.layers.dense({ units: 100, activation: 'relu' }));
328
- model.add(tf.layers.dense({ units: 100, activation: 'relu' }));
329
- // 드롭아웃 층
330
- model.add(tf.layers.dropout({ rate: 0.5 }));
331
- model.add(tf.layers.dense({ units: this.labels.length, activation: 'softmax' }));
332
- const optimizer = tf.train.adam(this.learningRate); // Optimizer를 생성하고 학습률을 설정합니다.
333
- model.compile({
334
- loss: (numClasses === 2) ? 'binaryCrossentropy' : 'categoricalCrossentropy',
297
+ let varianceScaling;
298
+ varianceScaling = tf.initializers.varianceScaling({});
299
+ const trainModel = tf.sequential({
300
+ layers: [
301
+ tf.layers.dense({
302
+ inputShape: this.mobilenetModule.outputs[0].shape.slice(1),
303
+ units: 64,
304
+ activation: 'relu',
305
+ kernelInitializer: varianceScaling,
306
+ useBias: true
307
+ }),
308
+ tf.layers.dense({
309
+ kernelInitializer: varianceScaling,
310
+ useBias: false,
311
+ activation: 'softmax',
312
+ units: this.classNumber.length,
313
+ })
314
+ ]
315
+ });
316
+ trainModel.compile({
317
+ loss: 'categoricalCrossentropy',
335
318
  optimizer: optimizer,
336
- metrics: ['accuracy', 'acc']
319
+ metrics: ['accuracy']
337
320
  });
338
- model.summary();
339
- return model;
321
+ return trainModel;
340
322
  }
341
323
  catch (error) {
342
324
  console.error('Failed to load model', error);
@@ -2,7 +2,6 @@ import * as tf from '@tensorflow/tfjs';
2
2
  import { io } from '@tensorflow/tfjs-core';
3
3
  interface LearningInterface {
4
4
  model: tf.LayersModel | null;
5
- labels: string[];
6
5
  isRunning: boolean;
7
6
  isReady: boolean;
8
7
  onProgress(progress: number): void;
@@ -7,21 +7,18 @@ declare class LearningMobilenet implements LearningInterface {
7
7
  batchSize: number;
8
8
  learningRate: number;
9
9
  validateRate: number;
10
- labels: string[];
11
- modelURL: string;
12
10
  isRunning: boolean;
13
11
  isReady: boolean;
14
12
  isTrainedDone: boolean;
15
13
  limitSize: number;
16
14
  mobilenetModule: tf.LayersModel | null;
17
- imageTensors: tf.Tensor<tf.Rank>[];
15
+ imageExamples: Float32Array[][];
16
+ classNumber: string[];
18
17
  readonly MOBILE_NET_INPUT_WIDTH = 224;
19
18
  readonly MOBILE_NET_INPUT_HEIGHT = 224;
20
19
  readonly MOBILE_NET_INPUT_CHANNEL = 3;
21
20
  readonly IMAGE_NORMALIZATION_FACTOR = 255;
22
- constructor({ modelURL, // 디폴트 mobilenet 이미지
23
- epochs, batchSize, limitSize, learningRate, validateRate, }?: {
24
- modelURL?: string;
21
+ constructor({ epochs, batchSize, limitSize, learningRate, validateRate, }?: {
25
22
  epochs?: number;
26
23
  batchSize?: number;
27
24
  limitSize?: number;
@@ -38,6 +35,8 @@ declare class LearningMobilenet implements LearningInterface {
38
35
  jsonURL: string;
39
36
  labels: Array<string>;
40
37
  }): Promise<void>;
38
+ private registerClassNumber;
39
+ private _convertToTfDataset;
41
40
  addData(label: string, data: any): Promise<void>;
42
41
  init(): Promise<void>;
43
42
  train(): Promise<tf.History>;
@@ -45,9 +44,6 @@ declare class LearningMobilenet implements LearningInterface {
45
44
  saveModel(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<void>;
46
45
  running(): boolean;
47
46
  ready(): boolean;
48
- private _preprocessedTargetData;
49
- private _preprocessedInputData;
50
- private loadModel;
51
47
  private _createModel;
52
48
  }
53
49
  export default LearningMobilenet;
@@ -0,0 +1,3 @@
1
+ type Drawable = ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement;
2
+ export declare function cropTo(image: Drawable, size: number, flipped?: boolean, canvas?: HTMLCanvasElement): HTMLCanvasElement;
3
+ export {};
@@ -0,0 +1,6 @@
1
+ export interface Sample {
2
+ data: Float32Array;
3
+ label: number[];
4
+ }
5
+ export declare function flatOneHot(label: number, numClasses: number): number[];
6
+ export declare function fisherYates(array: Float32Array[] | Sample[]): Float32Array[] | Sample[];
@@ -0,0 +1,7 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ export declare function isTensor(c: any): c is tf.Tensor;
3
+ export declare function loadModel(): Promise<tf.LayersModel>;
4
+ export declare function mobileNetURL(version: number): string;
5
+ export declare function imageToTensor(data: any): tf.Tensor3D;
6
+ export declare function capture(rasterElement: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement, grayscale?: boolean): tf.Tensor<tf.Rank>;
7
+ export declare function cropTensor(img: tf.Tensor3D, grayscaleModel?: boolean, grayscaleInput?: boolean): tf.Tensor3D;
@@ -0,0 +1,3 @@
1
+ type Drawable = ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement;
2
+ export declare function cropTo(image: Drawable, size: number, flipped?: boolean, canvas?: HTMLCanvasElement): HTMLCanvasElement;
3
+ export {};
@@ -0,0 +1,47 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.cropTo = void 0;
4
+ const newCanvas = () => document.createElement('canvas');
5
+ function cropTo(image, size, flipped = false, canvas = newCanvas()) {
6
+ let width;
7
+ let height;
8
+ // If ImageData
9
+ if (image instanceof ImageData) {
10
+ width = image.width;
11
+ height = image.height;
12
+ }
13
+ // If image, bitmap, or canvas
14
+ else if (image instanceof HTMLImageElement || image instanceof HTMLCanvasElement) {
15
+ width = image.width;
16
+ height = image.height;
17
+ }
18
+ // If video element
19
+ else if (image instanceof HTMLVideoElement) {
20
+ width = image.videoWidth;
21
+ height = image.videoHeight;
22
+ }
23
+ else {
24
+ throw new Error("Unsupported Drawable type");
25
+ }
26
+ const min = Math.min(width, height);
27
+ const scale = size / min;
28
+ const scaledW = Math.ceil(width * scale);
29
+ const scaledH = Math.ceil(height * scale);
30
+ const dx = scaledW - size;
31
+ const dy = scaledH - size;
32
+ canvas.width = canvas.height = size;
33
+ const ctx = canvas.getContext('2d');
34
+ // Handle ImageData separately
35
+ if (image instanceof ImageData) {
36
+ ctx.putImageData(image, ~~(dx / 2) * -1, ~~(dy / 2) * -1); // Adjust this if needed
37
+ }
38
+ else {
39
+ ctx.drawImage(image, ~~(dx / 2) * -1, ~~(dy / 2) * -1, scaledW, scaledH);
40
+ }
41
+ if (flipped) {
42
+ ctx.scale(-1, 1);
43
+ ctx.drawImage(canvas, size * -1, 0);
44
+ }
45
+ return canvas;
46
+ }
47
+ exports.cropTo = cropTo;
@@ -0,0 +1,6 @@
1
+ export interface Sample {
2
+ data: Float32Array;
3
+ label: number[];
4
+ }
5
+ export declare function flatOneHot(label: number, numClasses: number): number[];
6
+ export declare function fisherYates(array: Float32Array[] | Sample[]): Float32Array[] | Sample[];
@@ -0,0 +1,21 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.fisherYates = exports.flatOneHot = void 0;
4
+ function flatOneHot(label, numClasses) {
5
+ const labelOneHot = new Array(numClasses).fill(0);
6
+ labelOneHot[label] = 1;
7
+ return labelOneHot;
8
+ }
9
+ exports.flatOneHot = flatOneHot;
10
+ function fisherYates(array) {
11
+ const length = array.length;
12
+ // need to clone array or we'd be editing original as we goo
13
+ const shuffled = array.slice();
14
+ for (let i = (length - 1); i > 0; i -= 1) {
15
+ let randomIndex;
16
+ randomIndex = Math.floor(Math.random() * (i + 1));
17
+ [shuffled[i], shuffled[randomIndex]] = [shuffled[randomIndex], shuffled[i]];
18
+ }
19
+ return shuffled;
20
+ }
21
+ exports.fisherYates = fisherYates;
@@ -0,0 +1,7 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ export declare function isTensor(c: any): c is tf.Tensor;
3
+ export declare function loadModel(): Promise<tf.LayersModel>;
4
+ export declare function mobileNetURL(version: number): string;
5
+ export declare function imageToTensor(data: any): tf.Tensor3D;
6
+ export declare function capture(rasterElement: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement, grayscale?: boolean): tf.Tensor<tf.Rank>;
7
+ export declare function cropTensor(img: tf.Tensor3D, grayscaleModel?: boolean, grayscaleInput?: boolean): tf.Tensor3D;