learning_model 1.0.38 → 1.0.39
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Makefile +5 -0
- package/dist/index.bundle.js +1 -1
- package/dist/learning/base.d.ts +0 -1
- package/dist/learning/mobilenet.d.ts +5 -9
- package/dist/learning/mobilenet.js +115 -133
- package/dist/lib/learning/base.d.ts +0 -1
- package/dist/lib/learning/mobilenet.d.ts +5 -9
- package/dist/lib/utils/canvas.d.ts +3 -0
- package/dist/lib/utils/dataset.d.ts +6 -0
- package/dist/lib/utils/tf.d.ts +7 -0
- package/dist/utils/canvas.d.ts +3 -0
- package/dist/utils/canvas.js +47 -0
- package/dist/utils/dataset.d.ts +6 -0
- package/dist/utils/dataset.js +21 -0
- package/dist/utils/tf.d.ts +7 -0
- package/dist/utils/tf.js +124 -0
- package/lib/learning/base.ts +0 -2
- package/lib/learning/mobilenet.ts +145 -160
- package/lib/utils/canvas.ts +49 -0
- package/lib/utils/dataset.ts +24 -0
- package/lib/utils/tf.ts +94 -0
- package/package.json +1 -1
- package/dist/learning/util.d.ts +0 -2
- package/dist/learning/util.js +0 -40
- package/dist/lib/learning/util.d.ts +0 -2
- package/lib/learning/util.ts +0 -16
package/dist/utils/tf.js
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
|
|
3
|
+
if (k2 === undefined) k2 = k;
|
|
4
|
+
var desc = Object.getOwnPropertyDescriptor(m, k);
|
|
5
|
+
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
|
|
6
|
+
desc = { enumerable: true, get: function() { return m[k]; } };
|
|
7
|
+
}
|
|
8
|
+
Object.defineProperty(o, k2, desc);
|
|
9
|
+
}) : (function(o, m, k, k2) {
|
|
10
|
+
if (k2 === undefined) k2 = k;
|
|
11
|
+
o[k2] = m[k];
|
|
12
|
+
}));
|
|
13
|
+
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
|
|
14
|
+
Object.defineProperty(o, "default", { enumerable: true, value: v });
|
|
15
|
+
}) : function(o, v) {
|
|
16
|
+
o["default"] = v;
|
|
17
|
+
});
|
|
18
|
+
var __importStar = (this && this.__importStar) || function (mod) {
|
|
19
|
+
if (mod && mod.__esModule) return mod;
|
|
20
|
+
var result = {};
|
|
21
|
+
if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
|
|
22
|
+
__setModuleDefault(result, mod);
|
|
23
|
+
return result;
|
|
24
|
+
};
|
|
25
|
+
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
|
|
26
|
+
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
|
|
27
|
+
return new (P || (P = Promise))(function (resolve, reject) {
|
|
28
|
+
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
|
|
29
|
+
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
|
|
30
|
+
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
|
|
31
|
+
step((generator = generator.apply(thisArg, _arguments || [])).next());
|
|
32
|
+
});
|
|
33
|
+
};
|
|
34
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
35
|
+
exports.cropTensor = exports.capture = exports.imageToTensor = exports.mobileNetURL = exports.loadModel = exports.isTensor = void 0;
|
|
36
|
+
const tf = __importStar(require("@tensorflow/tfjs"));
|
|
37
|
+
/**
 * Duck-typed check for a TensorFlow.js tensor: true when `c` exposes object-typed
 * `dataId` and `shape` properties. No `instanceof` is used, so the check does not depend on
 * which copy of the tf module created the value.
 *
 * @param c Any value.
 * @returns true if `c` looks like a tensor, false otherwise (including null/undefined).
 */
function isTensor(c) {
    // Guard against null/undefined: reading `.dataId` on them would throw a TypeError
    // instead of answering "not a tensor".
    return c != null && typeof c.dataId === 'object' && typeof c.shape === 'object';
}
exports.isTensor = isTensor;
|
|
41
|
+
/**
 * Loads a pretrained MobileNet and truncates it into a feature extractor.
 *
 * @param {number} [version=2] MobileNet version. `1` (loose equality, matching mobileNetURL)
 *   selects MobileNet v1 cut at 'conv_pw_13_relu' followed by a flatten layer; any other
 *   value selects the v2 weights cut at 'out_relu' followed by global average pooling.
 * @returns {Promise<tf.Sequential>} A sequential model: the truncated MobileNet plus the
 *   flatten/pooling layer that turns its feature map into a vector.
 */
function loadModel(version = 2) {
    return __awaiter(this, void 0, void 0, function* () {
        const useV1 = version == 1;
        const baseModel = yield tf.loadLayersModel(mobileNetURL(version));
        // Cut the network at the chosen activation layer; everything above it is dropped.
        const cutLayer = baseModel.getLayer(useV1 ? 'conv_pw_13_relu' : 'out_relu');
        const truncatedModel = tf.model({
            inputs: baseModel.inputs,
            outputs: cutLayer.output
        });
        const model = tf.sequential();
        model.add(truncatedModel);
        // v1: flatten the feature map; v2: average-pool from shape [7, 7, 1280] to [1280].
        model.add(useV1 ? tf.layers.flatten() : tf.layers.globalAveragePooling2d({}));
        return model;
    });
}
exports.loadModel = loadModel;
|
|
73
|
+
/**
 * Resolves the model.json URL of the pretrained MobileNet to download.
 *
 * @param {number} version `1` (loose equality) selects MobileNet v1 (1.0, 224);
 *   anything else selects the Teachable Machine MobileNet v2 weights without the top layers.
 * @returns {string} Absolute URL of the model's model.json.
 */
function mobileNetURL(version) {
    const wantsV1 = version == 1;
    return wantsV1
        ? "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/model.json"
        : "https://storage.googleapis.com/teachable-machine-models/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top/model.json";
}
exports.mobileNetURL = mobileNetURL;
|
|
80
|
+
/**
 * Normalizes an input into a tensor.
 *
 * @param data Either an existing tf.Tensor (returned unchanged) or a pixel source accepted
 *   by tf.browser.fromPixels.
 * @returns The given tensor, or a new tensor holding the pixels of `data`.
 */
function imageToTensor(data) {
    // Already a tensor: hand it back untouched.
    if (data instanceof tf.Tensor) {
        return data;
    }
    // Otherwise read the pixels of the image/canvas/video source into a tensor.
    return tf.browser.fromPixels(data);
}
exports.imageToTensor = imageToTensor;
|
|
92
|
+
/**
 * Captures the centre square of a pixel source as a normalized batch of one image.
 * All intermediate tensors are released by tf.tidy; only the result survives.
 *
 * @param rasterElement Pixel source accepted by tf.browser.fromPixels.
 * @param grayscale When truthy, the RGB crop is reduced to one luminance channel
 *   (see cropTensor).
 * @returns Float32 tensor of shape [1, size, size, 3] ([1, size, size, 1] when grayscale),
 *   where size = min(height, width), holding pixel / 127 - 1.
 */
function capture(rasterElement, grayscale) {
    return tf.tidy(() => {
        const pixels = tf.browser.fromPixels(rasterElement);
        // Crop the image so we're using the centre square.
        const cropped = cropTensor(pixels, grayscale);
        // Expand the outermost dimension so we have a batch size of 1.
        const batchedImage = cropped.expandDims(0);
        // Map 0-255 pixel values to roughly [-1, 1]: divide by 127 and subtract 1.
        // NOTE(review): MobileNet's reference preprocessing divides by 127.5; 127 is kept so
        // features stay identical for classifiers already trained with this function.
        // cast('float32') replaces the deprecated Tensor.toFloat(), which is the same cast.
        return batchedImage.cast('float32').div(tf.scalar(127)).sub(tf.scalar(1));
    });
}
exports.capture = capture;
|
|
105
|
+
/**
 * Crops the largest centred square from an image tensor, optionally converting RGB to a
 * single luminance channel.
 *
 * @param img Image tensor of shape [height, width, channels].
 * @param grayscaleModel When truthy, produce a single-channel crop.
 * @param grayscaleInput When truthy, `img` already has a single channel.
 * @returns A [size, size, C] crop with size = min(height, width). C is 1 for a grayscale
 *   model or a grayscale input, and 3 otherwise (only the first three channels of an RGB(A)
 *   input are kept).
 */
function cropTensor(img, grayscaleModel, grayscaleInput) {
    const size = Math.min(img.shape[0], img.shape[1]);
    // Floor the offsets: when height and width differ by an odd amount the centre falls on
    // a half pixel, and slice() needs integer begin indices.
    const beginHeight = Math.floor((img.shape[0] - size) / 2);
    const beginWidth = Math.floor((img.shape[1] - size) / 2);
    if (grayscaleModel && !grayscaleInput) {
        // tidy() frees the intermediates even when the caller is not inside a tidy scope.
        return tf.tidy(() => {
            const rgbCrop = img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
            // Weighted sum of R, G and B (standard luma coefficients), broadcast over the
            // channel axis; keepDims leaves a trailing channel axis of size 1.
            const weighted = tf.mul(rgbCrop, [0.2989, 0.5870, 0.1140]);
            return tf.sum(weighted, -1, true);
        });
    }
    // A single-channel input has only one channel to slice; asking for 3 would throw.
    const channels = grayscaleInput ? 1 : 3;
    return img.slice([beginHeight, beginWidth, 0], [size, size, channels]);
}
exports.cropTensor = cropTensor;
|
package/lib/learning/base.ts
CHANGED
|
@@ -5,8 +5,14 @@
|
|
|
5
5
|
///////////////////////////////////////////////////////////////////////////
|
|
6
6
|
|
|
7
7
|
import * as tf from '@tensorflow/tfjs';
|
|
8
|
+
import { dispose } from '@tensorflow/tfjs';
|
|
8
9
|
import { io } from '@tensorflow/tfjs-core';
|
|
9
10
|
import LearningInterface from './base';
|
|
11
|
+
import { capture, loadModel, isTensor } from '../utils/tf';
|
|
12
|
+
import { cropTo } from '../utils/canvas';
|
|
13
|
+
import { flatOneHot, fisherYates, Sample } from '../utils/dataset';
|
|
14
|
+
import { Initializer } from '@tensorflow/tfjs-layers/dist/initializers';
|
|
15
|
+
|
|
10
16
|
|
|
11
17
|
class LearningMobilenet implements LearningInterface {
|
|
12
18
|
model: tf.LayersModel | null;
|
|
@@ -14,41 +20,39 @@ class LearningMobilenet implements LearningInterface {
|
|
|
14
20
|
batchSize: number;
|
|
15
21
|
learningRate: number;
|
|
16
22
|
validateRate: number;
|
|
17
|
-
labels: string[];
|
|
18
|
-
modelURL: string;
|
|
19
23
|
isRunning: boolean;
|
|
20
24
|
isReady: boolean;
|
|
21
25
|
isTrainedDone: boolean;
|
|
22
26
|
limitSize: number;
|
|
23
27
|
mobilenetModule: tf.LayersModel | null;
|
|
24
|
-
|
|
28
|
+
imageExamples: Float32Array[][] = [];
|
|
29
|
+
classNumber: string[] = [];
|
|
25
30
|
|
|
26
31
|
readonly MOBILE_NET_INPUT_WIDTH = 224;
|
|
27
32
|
readonly MOBILE_NET_INPUT_HEIGHT = 224;
|
|
28
33
|
readonly MOBILE_NET_INPUT_CHANNEL = 3;
|
|
29
34
|
readonly IMAGE_NORMALIZATION_FACTOR = 255.0;
|
|
30
35
|
|
|
36
|
+
|
|
31
37
|
constructor({
|
|
32
|
-
|
|
33
|
-
epochs = 10,
|
|
38
|
+
epochs = 50,
|
|
34
39
|
batchSize = 16,
|
|
35
40
|
limitSize = 2,
|
|
36
41
|
learningRate = 0.001,
|
|
37
|
-
validateRate = 0.
|
|
38
|
-
}: {
|
|
42
|
+
validateRate = 0.15,
|
|
43
|
+
}: { epochs?: number, batchSize?: number, limitSize?: number, learningRate?: number, validateRate?: number} = {}) {
|
|
39
44
|
this.model = null;
|
|
40
45
|
this.epochs = epochs;
|
|
41
46
|
this.batchSize = batchSize;
|
|
42
47
|
this.learningRate = learningRate;
|
|
43
48
|
this.validateRate = validateRate;
|
|
44
|
-
this.labels = [];
|
|
45
|
-
this.modelURL = modelURL;
|
|
46
49
|
this.isRunning = false;
|
|
47
50
|
this.isReady = false;
|
|
48
51
|
this.isTrainedDone = false;
|
|
49
52
|
this.limitSize = limitSize;
|
|
50
53
|
this.mobilenetModule = null;
|
|
51
|
-
this.
|
|
54
|
+
this.classNumber = [];
|
|
55
|
+
|
|
52
56
|
}
|
|
53
57
|
|
|
54
58
|
// 진행 상태를 나타내는 이벤트를 정의합니다.
|
|
@@ -73,44 +77,92 @@ class LearningMobilenet implements LearningInterface {
|
|
|
73
77
|
|
|
74
78
|
try {
|
|
75
79
|
this.model = await tf.loadLayersModel(jsonURL);
|
|
76
|
-
|
|
80
|
+
for(var i = 0; i < labels.length; i++) {
|
|
81
|
+
this.registerClassNumber(labels[i]);
|
|
82
|
+
}
|
|
77
83
|
this.isReady = true;
|
|
78
84
|
this.model.summary();
|
|
79
85
|
} catch (error) {
|
|
80
86
|
console.error('Model load failed', error);
|
|
81
87
|
throw error;
|
|
82
88
|
}
|
|
83
|
-
|
|
84
89
|
}
|
|
85
90
|
|
|
86
91
|
|
|
92
|
+
/**
 * Maps a class label to its numeric index, registering the label the first time it is seen.
 *
 * @param value Class label.
 * @returns The label's index in `classNumber`; unseen labels are appended at the end.
 */
private registerClassNumber(value: string): number {
  let index = this.classNumber.indexOf(value);
  if (index < 0) {
    // Unseen label: append it. push() returns the new length, so the new index is length - 1.
    index = this.classNumber.push(value) - 1;
  }
  return index;
}
|
|
105
|
+
|
|
106
|
+
/**
 * Builds shuffled training and validation datasets from the stored activations.
 *
 * Each class in `imageExamples` is shuffled (the shuffled order is written back), then split
 * so that about `validateRate` of its examples go to validation while at least one example
 * stays in training. Labels are one-hot vectors over `classNumber.length` classes.
 *
 * @returns `tf.data` datasets of `{ xs, ys }` elements for training and validation.
 */
private _convertToTfDataset() {
  const trainDataset: Sample[] = [];
  const validationDataset: Sample[] = [];

  for (let i = 0; i < this.imageExamples.length; i++) {
    const examples = this.imageExamples[i];
    // A class can be registered (e.g. from the labels passed to load()) before any example
    // was added for it, which leaves a hole in imageExamples. Skip it instead of crashing.
    if (!examples || examples.length === 0) {
      continue;
    }

    // Shuffle within the class and keep that order.
    const shuffled = fisherYates(examples) as Float32Array[];
    this.imageExamples[i] = shuffled;

    const classLength = shuffled.length;
    // About validateRate of the class goes to validation, but never the whole class:
    // Math.ceil alone would send the only example of a one-example class to validation.
    const numValidation = Math.max(0, Math.min(Math.ceil(this.validateRate * classLength), classLength - 1));
    const numTrain = classLength - numValidation;

    // One-hot label over all registered classes.
    const y = flatOneHot(i, this.classNumber.length);

    for (const dataArray of shuffled.slice(0, numTrain)) {
      trainDataset.push({ data: dataArray, label: y });
    }
    for (const dataArray of shuffled.slice(numTrain)) {
      validationDataset.push({ data: dataArray, label: y });
    }
  }

  // Shuffle entire datasets so classes are interleaved.
  const shuffledTrainDataset = fisherYates(trainDataset) as Sample[];
  const shuffledValidationDataset = fisherYates(validationDataset) as Sample[];

  // Convert to tf.data.Dataset
  const trainX = tf.data.array(shuffledTrainDataset.map(sample => sample.data));
  const validationX = tf.data.array(shuffledValidationDataset.map(sample => sample.data));
  const trainY = tf.data.array(shuffledTrainDataset.map(sample => sample.label));
  const validationY = tf.data.array(shuffledValidationDataset.map(sample => sample.label));

  return {
    trainDataset: tf.data.zip({ xs: trainX, ys: trainY }),
    validationDataset: tf.data.zip({ xs: validationX, ys: validationY })
  };
}
|
|
147
|
+
|
|
148
|
+
|
|
87
149
|
// 학습 데이타 등록
|
|
88
150
|
public async addData(label: string, data: any): Promise<void> {
|
|
89
151
|
try {
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
// Add an extra dimension to the tensor
|
|
99
|
-
const expandedTensor = normalizedTensor.expandDims(0);
|
|
100
|
-
|
|
101
|
-
console.log('predict extend', expandedTensor);
|
|
102
|
-
const predict = this.mobilenetModule.predict(expandedTensor);
|
|
103
|
-
console.log('predict', predict);
|
|
104
|
-
return predict;
|
|
105
|
-
} else {
|
|
106
|
-
throw new Error('mobilenetModule is null');
|
|
152
|
+
if (this.mobilenetModule !== null) {
|
|
153
|
+
const cap = isTensor(data) ? data : capture(data, false);
|
|
154
|
+
const predict = this.mobilenetModule.predict(cap) as tf.Tensor;
|
|
155
|
+
const activation = predict.dataSync() as Float32Array;
|
|
156
|
+
const classIndex = this.registerClassNumber(label);
|
|
157
|
+
if (!this.imageExamples[classIndex]) {
|
|
158
|
+
this.imageExamples[classIndex] = [];
|
|
107
159
|
}
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
160
|
+
this.imageExamples[classIndex].push(activation);
|
|
161
|
+
if(this.classNumber.length >= this.limitSize) {
|
|
162
|
+
this.isReady = true;
|
|
163
|
+
}
|
|
164
|
+
} else {
|
|
165
|
+
throw new Error('mobilenetModule is null');
|
|
114
166
|
}
|
|
115
167
|
return Promise.resolve();
|
|
116
168
|
} catch (error) {
|
|
@@ -121,13 +173,14 @@ class LearningMobilenet implements LearningInterface {
|
|
|
121
173
|
|
|
122
174
|
public async init() {
|
|
123
175
|
try {
|
|
124
|
-
this.mobilenetModule = await
|
|
176
|
+
this.mobilenetModule = await loadModel();
|
|
125
177
|
} catch(error) {
|
|
126
178
|
console.log('init Error', error);
|
|
127
179
|
throw error;
|
|
128
180
|
}
|
|
129
181
|
}
|
|
130
182
|
|
|
183
|
+
|
|
131
184
|
// 모델 학습 처리
|
|
132
185
|
public async train(): Promise<tf.History> {
|
|
133
186
|
if (this.isRunning) {
|
|
@@ -139,12 +192,10 @@ class LearningMobilenet implements LearningInterface {
|
|
|
139
192
|
onTrainBegin: (log: any) => {
|
|
140
193
|
this.isTrainedDone = false;
|
|
141
194
|
this.onTrainBegin(log);
|
|
142
|
-
console.log('Training has started.');
|
|
143
195
|
},
|
|
144
196
|
onTrainEnd: (log: any) => {
|
|
145
197
|
this.isTrainedDone = true;
|
|
146
198
|
this.onTrainEnd(log);
|
|
147
|
-
console.log('Training has ended.');
|
|
148
199
|
this.isRunning = false;
|
|
149
200
|
},
|
|
150
201
|
onBatchBegin: (batch: any, logs: any) => {
|
|
@@ -158,8 +209,6 @@ class LearningMobilenet implements LearningInterface {
|
|
|
158
209
|
},
|
|
159
210
|
onEpochEnd: (epoch: number, logs: any) => {
|
|
160
211
|
this.onEpochEnd(epoch, logs);
|
|
161
|
-
console.log(`Epoch ${epoch+1} has ended.`);
|
|
162
|
-
console.log('Loss:', logs);
|
|
163
212
|
this.onLoss(logs.loss);
|
|
164
213
|
this.onProgress(epoch+1);
|
|
165
214
|
this.onEvents({
|
|
@@ -171,22 +220,24 @@ class LearningMobilenet implements LearningInterface {
|
|
|
171
220
|
|
|
172
221
|
try {
|
|
173
222
|
this.isRunning = true;
|
|
174
|
-
if (this.
|
|
223
|
+
if (this.classNumber.length < this.limitSize) {
|
|
175
224
|
return Promise.reject(new Error('Please train Data need over 2 data length'));
|
|
176
225
|
}
|
|
177
|
-
|
|
178
|
-
const
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
const
|
|
182
|
-
|
|
183
|
-
const history = await this.model.fit(inputData, targetData, {
|
|
226
|
+
const datasets = this._convertToTfDataset();
|
|
227
|
+
const trainData = datasets.trainDataset.batch(this.batchSize);
|
|
228
|
+
const validationData = datasets.validationDataset.batch(this.batchSize);
|
|
229
|
+
const optimizer = tf.train.adam(this.learningRate);
|
|
230
|
+
const trainModel = await this._createModel(optimizer);
|
|
231
|
+
const history = await trainModel.fitDataset(trainData, {
|
|
184
232
|
epochs: this.epochs,
|
|
185
|
-
|
|
186
|
-
validationSplit: this.validateRate, // 검증 데이터의 비율 설정
|
|
233
|
+
validationData: validationData,
|
|
187
234
|
callbacks: customCallback
|
|
188
235
|
});
|
|
189
|
-
|
|
236
|
+
const jointModel = tf.sequential();
|
|
237
|
+
jointModel.add(this.mobilenetModule!);
|
|
238
|
+
jointModel.add(trainModel);
|
|
239
|
+
this.model = jointModel;
|
|
240
|
+
optimizer.dispose();
|
|
190
241
|
return history;
|
|
191
242
|
} catch (error) {
|
|
192
243
|
this.isRunning = false;
|
|
@@ -202,38 +253,19 @@ class LearningMobilenet implements LearningInterface {
|
|
|
202
253
|
}
|
|
203
254
|
|
|
204
255
|
try {
|
|
205
|
-
const
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
const normalizedTensor = resizedTensor.div(255.0);
|
|
212
|
-
|
|
213
|
-
// Add an extra dimension to the tensor
|
|
214
|
-
const expandedTensor = normalizedTensor.expandDims(0);
|
|
215
|
-
|
|
216
|
-
console.log('predict extend', expandedTensor);
|
|
217
|
-
const predict = this.mobilenetModule.predict(expandedTensor);
|
|
218
|
-
console.log('predict', predict);
|
|
219
|
-
return predict;
|
|
220
|
-
} else {
|
|
221
|
-
throw new Error('mobilenetModule is null');
|
|
222
|
-
}
|
|
256
|
+
const classProbabilities = new Map<string, number>();
|
|
257
|
+
const croppedImage = cropTo(data, 224, false);
|
|
258
|
+
|
|
259
|
+
const logits = tf.tidy(() => {
|
|
260
|
+
const captured = capture(croppedImage, false);
|
|
261
|
+
return this.model!.predict(captured);
|
|
223
262
|
});
|
|
224
|
-
const
|
|
225
|
-
const predictedClass = predictions.as1D().argMax();
|
|
226
|
-
const classId = (await predictedClass.data())[0];
|
|
227
|
-
console.log('classId', classId, this.labels[classId]);
|
|
228
|
-
const predictionsData = await predictions.data(); // 예측 텐서의 데이터를 비동기로 가져옴
|
|
229
|
-
const classProbabilities = new Map<string, number>(); // 클래스별 확률 누적값을 저장할 맵
|
|
230
|
-
|
|
231
|
-
console.log('predictionsData', predictionsData);
|
|
263
|
+
const values = await (logits as tf.Tensor<tf.Rank>).data();
|
|
232
264
|
const EPSILON = 1e-6; // 매우 작은 값을 표현하기 위한 엡실론
|
|
233
|
-
for (let i = 0; i <
|
|
234
|
-
let probability = Math.max(0, Math.min(1,
|
|
265
|
+
for (let i = 0; i < values.length; i++) {
|
|
266
|
+
let probability = Math.max(0, Math.min(1, values[i])); // 확률 값을 0과 1 사이로 조정
|
|
235
267
|
probability = probability < EPSILON ? 0 : probability; // 매우 작은 확률 값을 0으로 간주
|
|
236
|
-
const className = this.
|
|
268
|
+
const className = this.classNumber[i]; // 클래스 이름
|
|
237
269
|
const existingProbability = classProbabilities.get(className);
|
|
238
270
|
if (existingProbability !== undefined) {
|
|
239
271
|
classProbabilities.set(className, existingProbability + probability);
|
|
@@ -241,9 +273,7 @@ class LearningMobilenet implements LearningInterface {
|
|
|
241
273
|
classProbabilities.set(className, probability);
|
|
242
274
|
}
|
|
243
275
|
}
|
|
244
|
-
|
|
245
|
-
predictedClass.dispose();
|
|
246
|
-
predictions.dispose();
|
|
276
|
+
dispose(logits);
|
|
247
277
|
return classProbabilities;
|
|
248
278
|
} catch (error) {
|
|
249
279
|
throw error;
|
|
@@ -253,15 +283,12 @@ class LearningMobilenet implements LearningInterface {
|
|
|
253
283
|
|
|
254
284
|
// 모델 저장
|
|
255
285
|
public async saveModel(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<void> {
|
|
256
|
-
console.log('saved model');
|
|
257
286
|
if (!this.isTrainedDone) {
|
|
258
287
|
return Promise.reject(new Error('Train is not done status'));
|
|
259
288
|
}
|
|
260
289
|
await this.model?.save(handlerOrURL, config);
|
|
261
290
|
}
|
|
262
291
|
|
|
263
|
-
|
|
264
|
-
|
|
265
292
|
// 진행중 여부
|
|
266
293
|
public running(): boolean {
|
|
267
294
|
return this.isRunning;
|
|
@@ -272,83 +299,41 @@ class LearningMobilenet implements LearningInterface {
|
|
|
272
299
|
return this.isReady;
|
|
273
300
|
}
|
|
274
301
|
|
|
275
|
-
|
|
276
|
-
// target 라벨 데이타
|
|
277
|
-
private _preprocessedTargetData(): tf.Tensor<tf.Rank> {
|
|
278
|
-
// 라벨 unique 처리 & 배열 리턴
|
|
279
|
-
console.log('uniqueLabels.length', this.labels, this.labels.length);
|
|
280
|
-
const labelIndices = this.labels.map((label) => this.labels.indexOf(label));
|
|
281
|
-
console.log('labelIndices', labelIndices);
|
|
282
|
-
const oneHotEncode = tf.oneHot(tf.tensor1d(labelIndices, 'int32'),this.labels.length);
|
|
283
|
-
console.log('oneHotEncode', oneHotEncode);
|
|
284
|
-
return oneHotEncode
|
|
285
|
-
}
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
// 입력 이미지 데이타
|
|
289
|
-
private _preprocessedInputData(model: tf.LayersModel): tf.Tensor<tf.Rank> {
|
|
290
|
-
// 이미지 배열을 배치로 변환 - [null, 224, 224, 3]
|
|
291
|
-
const inputShape = model.inputs[0].shape;
|
|
292
|
-
console.log('inputShape', inputShape);
|
|
293
|
-
// inputShape를 이와 같이 포멧 맞춘다. for reshape to [224, 224, 3]
|
|
294
|
-
const inputShapeArray = inputShape.slice(1) as number[];
|
|
295
|
-
console.log('inputShapeArray', inputShapeArray);
|
|
296
|
-
const inputBatch = tf.stack(this.imageTensors.map((image) => {
|
|
297
|
-
return tf.reshape(image, inputShapeArray);
|
|
298
|
-
}));
|
|
299
|
-
return inputBatch
|
|
300
|
-
}
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
private async loadModel(): Promise<tf.LayersModel> {
|
|
304
|
-
const load_model = await tf.loadLayersModel(this.modelURL);
|
|
305
|
-
load_model.summary();
|
|
306
|
-
const layer = load_model.getLayer('conv_pw_13_relu');
|
|
307
|
-
const truncatedModel = tf.model({
|
|
308
|
-
inputs: load_model.inputs,
|
|
309
|
-
outputs: layer.output
|
|
310
|
-
})
|
|
311
|
-
return truncatedModel
|
|
312
|
-
}
|
|
313
|
-
|
|
314
302
|
// 모델 저장
|
|
315
|
-
private async _createModel(
|
|
303
|
+
private async _createModel(optimizer: tf.AdamOptimizer): Promise<tf.Sequential> {
|
|
316
304
|
try {
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
});
|
|
347
|
-
model.summary();
|
|
348
|
-
return model;
|
|
305
|
+
// 입력 이미지 크기에 맞게 모델 구조 수정
|
|
306
|
+
let varianceScaling: Initializer;
|
|
307
|
+
varianceScaling = tf.initializers.varianceScaling({});
|
|
308
|
+
|
|
309
|
+
const trainModel = tf.sequential({
|
|
310
|
+
layers: [
|
|
311
|
+
tf.layers.dense({
|
|
312
|
+
inputShape: this.mobilenetModule!.outputs[0].shape.slice(1),
|
|
313
|
+
units: 64,
|
|
314
|
+
activation: 'relu',
|
|
315
|
+
kernelInitializer: varianceScaling, // 'varianceScaling'
|
|
316
|
+
useBias: true
|
|
317
|
+
}),
|
|
318
|
+
tf.layers.dense({
|
|
319
|
+
kernelInitializer: varianceScaling, // 'varianceScaling'
|
|
320
|
+
useBias: false,
|
|
321
|
+
activation: 'softmax',
|
|
322
|
+
units: this.classNumber.length,
|
|
323
|
+
})
|
|
324
|
+
]
|
|
325
|
+
});
|
|
326
|
+
|
|
327
|
+
trainModel.compile({
|
|
328
|
+
loss: 'categoricalCrossentropy',
|
|
329
|
+
optimizer: optimizer,
|
|
330
|
+
metrics: ['accuracy']
|
|
331
|
+
});
|
|
332
|
+
|
|
333
|
+
return trainModel;
|
|
349
334
|
} catch (error) {
|
|
350
|
-
|
|
351
|
-
|
|
335
|
+
console.error('Failed to load model', error);
|
|
336
|
+
throw error;
|
|
352
337
|
}
|
|
353
338
|
}
|
|
354
339
|
}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
/** Pixel sources accepted by {@link cropTo}. */
type Drawable = ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement;

/** Creates a new, detached `<canvas>` element. */
const newCanvas = () => document.createElement('canvas');

/**
 * Scales `image` so its shorter side equals `size`, then draws the centred `size` x `size`
 * square onto `canvas`.
 *
 * @param image Source pixels. Videos are measured with `videoWidth`/`videoHeight`, all other
 *   sources with `width`/`height`.
 * @param size Side length of the square output, in pixels.
 * @param flipped When true, the result is mirrored horizontally.
 * @param canvas Target canvas (resized to `size` x `size`); a new one is created by default.
 * @returns The target canvas.
 * @throws Error if `image` is not one of the supported {@link Drawable} types.
 */
export function cropTo(image: Drawable, size: number, flipped = false, canvas: HTMLCanvasElement = newCanvas()) {
  let width: number;
  let height: number;

  if (image instanceof ImageData) {
    width = image.width;
    height = image.height;
  } else if (image instanceof HTMLImageElement || image instanceof HTMLCanvasElement) {
    width = image.width;
    height = image.height;
  } else if (image instanceof HTMLVideoElement) {
    width = image.videoWidth;
    height = image.videoHeight;
  } else {
    throw new Error("Unsupported Drawable type");
  }

  const scale = size / Math.min(width, height);
  const scaledW = Math.ceil(width * scale);
  const scaledH = Math.ceil(height * scale);
  // Non-positive offsets that put the centre of the scaled image at the canvas centre.
  const offsetX = ~~((scaledW - size) / 2) * -1;
  const offsetY = ~~((scaledH - size) / 2) * -1;

  canvas.width = canvas.height = size;
  const ctx = canvas.getContext('2d') as CanvasRenderingContext2D;

  // putImageData() cannot scale, while the offsets above assume the scaled size. So ImageData
  // is first staged on its own canvas and then drawn, scaled, like the other sources.
  let source: CanvasImageSource;
  if (image instanceof ImageData) {
    const staging = newCanvas();
    staging.width = width;
    staging.height = height;
    (staging.getContext('2d') as CanvasRenderingContext2D).putImageData(image, 0, 0);
    source = staging;
  } else {
    source = image;
  }
  ctx.drawImage(source, offsetX, offsetY, scaledW, scaledH);

  if (flipped) {
    // Mirror in place: flip the x axis and redraw the canvas onto itself.
    ctx.scale(-1, 1);
    ctx.drawImage(canvas, size * -1, 0);
  }

  return canvas;
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
/** One labelled training example. */
export interface Sample {
  /** Feature vector for the example. */
  data: Float32Array;
  /** One-hot encoded class label. */
  label: number[];
}

/**
 * Builds a one-hot vector of length `numClasses` with a 1 at index `label`.
 *
 * @param label Zero-based class index.
 * @param numClasses Number of classes (length of the returned vector).
 * @returns An array of `numClasses` zeros with position `label` set to 1.
 * @throws RangeError if `label` is not an integer in [0, numClasses).
 */
export function flatOneHot(label: number, numClasses: number): number[] {
  // Without this check an out-of-range label silently grows the array (or sets a non-index
  // property), yielding a label vector of the wrong length.
  if (!Number.isInteger(label) || label < 0 || label >= numClasses) {
    throw new RangeError(`label ${label} is out of range for ${numClasses} classes`);
  }
  const labelOneHot: number[] = new Array<number>(numClasses).fill(0);
  labelOneHot[label] = 1;
  return labelOneHot;
}
|
|
12
|
+
|
|
13
|
+
/**
 * Returns a shuffled copy of `array` using the Fisher–Yates (Durstenfeld) algorithm driven by
 * Math.random. The input array is not modified.
 *
 * @param array Elements to shuffle.
 * @returns A new array with the same elements in random order.
 */
export function fisherYates<T>(array: readonly T[]): T[] {
  // Work on a copy so the caller's array is left untouched.
  const shuffled = array.slice();
  for (let i = shuffled.length - 1; i > 0; i--) {
    // Pick a position in [0, i] and swap it into slot i.
    const j = Math.floor(Math.random() * (i + 1));
    [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
  }
  return shuffled;
}
|