learning_model 1.0.38 → 1.0.40
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Makefile +5 -0
- package/dist/index.bundle.js +1 -1
- package/dist/learning/base.d.ts +0 -1
- package/dist/learning/mobilenet.d.ts +5 -9
- package/dist/learning/mobilenet.js +120 -136
- package/dist/lib/learning/base.d.ts +0 -1
- package/dist/lib/learning/mobilenet.d.ts +5 -9
- package/dist/lib/utils/canvas.d.ts +3 -0
- package/dist/lib/utils/dataset.d.ts +6 -0
- package/dist/lib/utils/tf.d.ts +7 -0
- package/dist/utils/canvas.d.ts +3 -0
- package/dist/utils/canvas.js +47 -0
- package/dist/utils/dataset.d.ts +6 -0
- package/dist/utils/dataset.js +21 -0
- package/dist/utils/tf.d.ts +7 -0
- package/dist/utils/tf.js +124 -0
- package/lib/learning/base.ts +0 -2
- package/lib/learning/mobilenet.ts +150 -163
- package/lib/utils/canvas.ts +49 -0
- package/lib/utils/dataset.ts +24 -0
- package/lib/utils/tf.ts +94 -0
- package/package.json +1 -1
- package/dist/learning/util.d.ts +0 -2
- package/dist/learning/util.js +0 -40
- package/dist/lib/learning/util.d.ts +0 -2
- package/lib/learning/util.ts +0 -16
package/dist/utils/tf.js
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
|
|
3
|
+
if (k2 === undefined) k2 = k;
|
|
4
|
+
var desc = Object.getOwnPropertyDescriptor(m, k);
|
|
5
|
+
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
|
|
6
|
+
desc = { enumerable: true, get: function() { return m[k]; } };
|
|
7
|
+
}
|
|
8
|
+
Object.defineProperty(o, k2, desc);
|
|
9
|
+
}) : (function(o, m, k, k2) {
|
|
10
|
+
if (k2 === undefined) k2 = k;
|
|
11
|
+
o[k2] = m[k];
|
|
12
|
+
}));
|
|
13
|
+
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
|
|
14
|
+
Object.defineProperty(o, "default", { enumerable: true, value: v });
|
|
15
|
+
}) : function(o, v) {
|
|
16
|
+
o["default"] = v;
|
|
17
|
+
});
|
|
18
|
+
var __importStar = (this && this.__importStar) || function (mod) {
|
|
19
|
+
if (mod && mod.__esModule) return mod;
|
|
20
|
+
var result = {};
|
|
21
|
+
if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
|
|
22
|
+
__setModuleDefault(result, mod);
|
|
23
|
+
return result;
|
|
24
|
+
};
|
|
25
|
+
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
|
|
26
|
+
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
|
|
27
|
+
return new (P || (P = Promise))(function (resolve, reject) {
|
|
28
|
+
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
|
|
29
|
+
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
|
|
30
|
+
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
|
|
31
|
+
step((generator = generator.apply(thisArg, _arguments || [])).next());
|
|
32
|
+
});
|
|
33
|
+
};
|
|
34
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
35
|
+
exports.cropTensor = exports.capture = exports.imageToTensor = exports.mobileNetURL = exports.loadModel = exports.isTensor = void 0;
|
|
36
|
+
const tf = __importStar(require("@tensorflow/tfjs"));
|
|
37
|
+
/**
 * Duck-type check for a tensor-like value.
 *
 * A value passes when it is neither null nor undefined and both its `dataId`
 * and `shape` properties report `typeof ... === 'object'`. Because `typeof`
 * is used, a property holding `null` also counts as an object here.
 *
 * @param c Any value.
 * @returns `true` when `c` looks like a tensor, `false` otherwise.
 */
function isTensor(c) {
    // Reading a property of null/undefined throws a TypeError, so reject those first.
    if (c === null || c === undefined) {
        return false;
    }
    return typeof c.dataId === 'object' && typeof c.shape === 'object';
}
|
|
40
|
+
exports.isTensor = isTensor;
|
|
41
|
+
/**
 * Loads a pre-trained MobileNet from the URL returned by `mobileNetURL` and
 * wraps it as a feature extractor.
 *
 * The loaded network is truncated at a fixed layer ('conv_pw_13_relu' for
 * version 1, 'out_relu' otherwise) and placed in a sequential model followed
 * by a single head layer: `flatten` for version 1, `globalAveragePooling2d`
 * otherwise.
 *
 * @param mobileNetVersion Which MobileNet to load. Defaults to 2, which is
 *   the value the function previously hard-coded, so existing zero-argument
 *   calls behave as before. Compared with `== 1`, matching `mobileNetURL`.
 * @returns A promise resolving to the assembled `tf.sequential` model.
 *   Rejects if loading the model or looking up the layer fails.
 */
function loadModel(mobileNetVersion = 2) {
    return __awaiter(this, void 0, void 0, function* () {
        const isV1 = mobileNetVersion == 1;
        // Name of the layer whose output is used as the feature vector.
        const featureLayerName = isV1 ? 'conv_pw_13_relu' : 'out_relu';
        const baseModel = yield tf.loadLayersModel(mobileNetURL(mobileNetVersion));
        const truncatedModel = tf.model({
            inputs: baseModel.inputs,
            outputs: baseModel.getLayer(featureLayerName).output
        });
        const model = tf.sequential();
        model.add(truncatedModel);
        model.add(isV1
            ? tf.layers.flatten()
            : tf.layers.globalAveragePooling2d({})); // go from shape [7, 7, 1280] to [1280]
        return model;
    });
}
|
|
72
|
+
exports.loadModel = loadModel;
|
|
73
|
+
/**
 * Returns the URL of the MobileNet `model.json` for the requested version.
 *
 * @param version MobileNet version. A value loosely equal to 1 selects the
 *   v1 model; every other value selects the v2 model.
 * @returns The model URL as a string.
 */
function mobileNetURL(version) {
    const v1ModelURL = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/model.json";
    const v2ModelURL = "https://storage.googleapis.com/teachable-machine-models/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top/model.json";
    // Loose equality is kept on purpose so that "1" behaves like 1.
    return version == 1 ? v1ModelURL : v2ModelURL;
}
|
|
79
|
+
exports.mobileNetURL = mobileNetURL;
|
|
80
|
+
/**
 * Returns `data` as a tensor.
 *
 * @param data Either an existing `tf.Tensor` or a pixel source accepted by
 *   `tf.browser.fromPixels`.
 * @returns `data` itself when it is already a `tf.Tensor`; otherwise the
 *   tensor produced by `tf.browser.fromPixels(data)`.
 */
function imageToTensor(data) {
    // Tensors pass through untouched; anything else is read as pixel data.
    return data instanceof tf.Tensor ? data : tf.browser.fromPixels(data);
}
|
|
91
|
+
exports.imageToTensor = imageToTensor;
|
|
92
|
+
/**
 * Converts a raster element into a normalized, batched input tensor.
 *
 * The whole computation is wrapped in `tf.tidy`.
 *
 * @param rasterElement Pixel source passed to `tf.browser.fromPixels`.
 * @param grayscale Forwarded to `cropTensor` as its `grayscaleModel` flag.
 * @returns The center-cropped image with a leading batch dimension of 1,
 *   converted to float, divided by 127 and shifted down by 1.
 */
function capture(rasterElement, grayscale) {
    return tf.tidy(() => {
        // Keep only the centered square of the source image.
        const centerSquare = cropTensor(tf.browser.fromPixels(rasterElement), grayscale);
        // Add an outermost dimension so the result is a batch of size 1.
        const batch = centerSquare.expandDims(0);
        // Pixel values arrive in 0-255; dividing by 127 and subtracting 1
        // maps them to roughly the range -1 to 1.
        const asFloat = batch.toFloat();
        const scaled = asFloat.div(tf.scalar(127));
        return scaled.sub(tf.scalar(1));
    });
}
|
|
104
|
+
exports.capture = capture;
|
|
105
|
+
function cropTensor(img, grayscaleModel, grayscaleInput) {
|
|
106
|
+
const size = Math.min(img.shape[0], img.shape[1]);
|
|
107
|
+
const centerHeight = img.shape[0] / 2;
|
|
108
|
+
const beginHeight = centerHeight - (size / 2);
|
|
109
|
+
const centerWidth = img.shape[1] / 2;
|
|
110
|
+
const beginWidth = centerWidth - (size / 2);
|
|
111
|
+
if (grayscaleModel && !grayscaleInput) {
|
|
112
|
+
//cropped rgb data
|
|
113
|
+
let grayscale_cropped = img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
|
|
114
|
+
grayscale_cropped = grayscale_cropped.reshape([size * size, 1, 3]);
|
|
115
|
+
const rgb_weights = [0.2989, 0.5870, 0.1140];
|
|
116
|
+
grayscale_cropped = tf.mul(grayscale_cropped, rgb_weights);
|
|
117
|
+
grayscale_cropped = grayscale_cropped.reshape([size, size, 3]);
|
|
118
|
+
grayscale_cropped = tf.sum(grayscale_cropped, -1);
|
|
119
|
+
grayscale_cropped = tf.expandDims(grayscale_cropped, -1);
|
|
120
|
+
return grayscale_cropped;
|
|
121
|
+
}
|
|
122
|
+
return img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
|
|
123
|
+
}
|
|
124
|
+
exports.cropTensor = cropTensor;
|
package/lib/learning/base.ts
CHANGED
|
@@ -5,8 +5,14 @@
|
|
|
5
5
|
///////////////////////////////////////////////////////////////////////////
|
|
6
6
|
|
|
7
7
|
import * as tf from '@tensorflow/tfjs';
|
|
8
|
+
import { dispose } from '@tensorflow/tfjs';
|
|
8
9
|
import { io } from '@tensorflow/tfjs-core';
|
|
9
10
|
import LearningInterface from './base';
|
|
11
|
+
import { capture, loadModel, isTensor } from '../utils/tf';
|
|
12
|
+
import { cropTo } from '../utils/canvas';
|
|
13
|
+
import { flatOneHot, fisherYates, Sample } from '../utils/dataset';
|
|
14
|
+
import { Initializer } from '@tensorflow/tfjs-layers/dist/initializers';
|
|
15
|
+
|
|
10
16
|
|
|
11
17
|
class LearningMobilenet implements LearningInterface {
|
|
12
18
|
model: tf.LayersModel | null;
|
|
@@ -14,41 +20,39 @@ class LearningMobilenet implements LearningInterface {
|
|
|
14
20
|
batchSize: number;
|
|
15
21
|
learningRate: number;
|
|
16
22
|
validateRate: number;
|
|
17
|
-
labels: string[];
|
|
18
|
-
modelURL: string;
|
|
19
23
|
isRunning: boolean;
|
|
20
24
|
isReady: boolean;
|
|
21
25
|
isTrainedDone: boolean;
|
|
22
26
|
limitSize: number;
|
|
23
27
|
mobilenetModule: tf.LayersModel | null;
|
|
24
|
-
|
|
28
|
+
imageExamples: Float32Array[][] = [];
|
|
29
|
+
classNumber: string[] = [];
|
|
25
30
|
|
|
26
31
|
readonly MOBILE_NET_INPUT_WIDTH = 224;
|
|
27
32
|
readonly MOBILE_NET_INPUT_HEIGHT = 224;
|
|
28
33
|
readonly MOBILE_NET_INPUT_CHANNEL = 3;
|
|
29
34
|
readonly IMAGE_NORMALIZATION_FACTOR = 255.0;
|
|
30
35
|
|
|
36
|
+
|
|
31
37
|
constructor({
|
|
32
|
-
|
|
33
|
-
epochs = 10,
|
|
38
|
+
epochs = 50,
|
|
34
39
|
batchSize = 16,
|
|
35
40
|
limitSize = 2,
|
|
36
41
|
learningRate = 0.001,
|
|
37
|
-
validateRate = 0.
|
|
38
|
-
}: {
|
|
42
|
+
validateRate = 0.15,
|
|
43
|
+
}: { epochs?: number, batchSize?: number, limitSize?: number, learningRate?: number, validateRate?: number} = {}) {
|
|
39
44
|
this.model = null;
|
|
40
45
|
this.epochs = epochs;
|
|
41
46
|
this.batchSize = batchSize;
|
|
42
47
|
this.learningRate = learningRate;
|
|
43
48
|
this.validateRate = validateRate;
|
|
44
|
-
this.labels = [];
|
|
45
|
-
this.modelURL = modelURL;
|
|
46
49
|
this.isRunning = false;
|
|
47
50
|
this.isReady = false;
|
|
48
51
|
this.isTrainedDone = false;
|
|
49
52
|
this.limitSize = limitSize;
|
|
50
53
|
this.mobilenetModule = null;
|
|
51
|
-
this.
|
|
54
|
+
this.classNumber = [];
|
|
55
|
+
|
|
52
56
|
}
|
|
53
57
|
|
|
54
58
|
// 진행 상태를 나타내는 이벤트를 정의합니다.
|
|
@@ -73,44 +77,93 @@ class LearningMobilenet implements LearningInterface {
|
|
|
73
77
|
|
|
74
78
|
try {
|
|
75
79
|
this.model = await tf.loadLayersModel(jsonURL);
|
|
76
|
-
|
|
80
|
+
for(var i = 0; i < labels.length; i++) {
|
|
81
|
+
this.registerClassNumber(labels[i]);
|
|
82
|
+
}
|
|
77
83
|
this.isReady = true;
|
|
78
84
|
this.model.summary();
|
|
79
85
|
} catch (error) {
|
|
80
86
|
console.error('Model load failed', error);
|
|
81
87
|
throw error;
|
|
82
88
|
}
|
|
83
|
-
|
|
84
89
|
}
|
|
85
90
|
|
|
86
91
|
|
|
92
|
+
/**
 * Returns the numeric class index for a label, registering the label first
 * if it has not been seen before.
 *
 * @param value The class label.
 * @returns The position of `value` in `this.classNumber`. A new label is
 *   appended, so it receives the next free index.
 */
private registerClassNumber(value: string): number {
  const knownIndex = this.classNumber.indexOf(value);
  if (knownIndex === -1) {
    // Unknown label: append it. `push` returns the new length, so the index
    // of the appended entry is that length minus one.
    return this.classNumber.push(value) - 1;
  }
  // Already registered: reuse the existing index.
  return knownIndex;
}
|
|
105
|
+
|
|
106
|
+
/**
 * Builds shuffled training and validation datasets from the collected
 * per-class examples in `this.imageExamples`.
 *
 * Each class is shuffled and split on its own, so every class contributes
 * `Math.ceil(validateRate * n)` validation samples and the rest as training
 * samples. Because of the `ceil`, any non-empty class yields at least one
 * validation sample when `validateRate > 0`; a class with a single example
 * therefore contributes no training sample.
 *
 * Side effect: `this.imageExamples[i]` is replaced by its shuffled copy, and
 * missing class slots are filled with an empty array.
 *
 * @returns An object with `trainDataset` and `validationDataset`, each a
 *   zipped `tf.data` dataset of `{ xs, ys }` pairs (not yet batched).
 */
private _convertToTfDataset() {
  const trainDataset: Sample[] = [];
  const validationDataset: Sample[] = [];
  const numClasses = this.classNumber.length;

  for (let i = 0; i < this.imageExamples.length; i++) {
    // A class index can exist without any examples (for instance when a
    // label was registered but no data was added for it). Treat the missing
    // slot as an empty class instead of failing on `undefined`.
    const examples = fisherYates(this.imageExamples[i] ?? []) as Float32Array[];
    this.imageExamples[i] = examples;

    // Derive the per-class split sizes from the class's total sample count.
    const numValidation = Math.ceil(this.validateRate * examples.length);
    const numTrain = examples.length - numValidation;

    // One-hot label shared by every sample of this class.
    const y = flatOneHot(i, numClasses);
    const toSample = (dataArray: Float32Array): Sample => ({ data: dataArray, label: y });

    // The first numTrain samples train the model; the remainder validate it.
    trainDataset.push(...examples.slice(0, numTrain).map(toSample));
    validationDataset.push(...examples.slice(numTrain, numTrain + numValidation).map(toSample));
  }

  // Shuffle across classes so batches are not grouped by label.
  const shuffledTrainDataset = fisherYates(trainDataset) as Sample[];
  const shuffledValidationDataset = fisherYates(validationDataset) as Sample[];

  // Convert features and labels to tf.data datasets and pair them up.
  const trainX = tf.data.array(shuffledTrainDataset.map(sample => sample.data));
  const validationX = tf.data.array(shuffledValidationDataset.map(sample => sample.data));
  const trainY = tf.data.array(shuffledTrainDataset.map(sample => sample.label));
  const validationY = tf.data.array(shuffledValidationDataset.map(sample => sample.label));

  return {
    trainDataset: tf.data.zip({ xs: trainX, ys: trainY }),
    validationDataset: tf.data.zip({ xs: validationX, ys: validationY })
  };
}
|
|
147
|
+
|
|
148
|
+
|
|
87
149
|
// 학습 데이타 등록
|
|
88
150
|
public async addData(label: string, data: any): Promise<void> {
|
|
89
151
|
try {
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
// Add an extra dimension to the tensor
|
|
99
|
-
const expandedTensor = normalizedTensor.expandDims(0);
|
|
100
|
-
|
|
101
|
-
console.log('predict extend', expandedTensor);
|
|
102
|
-
const predict = this.mobilenetModule.predict(expandedTensor);
|
|
103
|
-
console.log('predict', predict);
|
|
104
|
-
return predict;
|
|
105
|
-
} else {
|
|
106
|
-
throw new Error('mobilenetModule is null');
|
|
152
|
+
if (this.mobilenetModule !== null) {
|
|
153
|
+
const croppedImage = cropTo(data, 224, false);
|
|
154
|
+
const cap = isTensor(data) ? data : capture(croppedImage, false);
|
|
155
|
+
const predict = this.mobilenetModule.predict(cap) as tf.Tensor;
|
|
156
|
+
const activation = predict.dataSync() as Float32Array;
|
|
157
|
+
const classIndex = this.registerClassNumber(label);
|
|
158
|
+
if (!this.imageExamples[classIndex]) {
|
|
159
|
+
this.imageExamples[classIndex] = [];
|
|
107
160
|
}
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
161
|
+
this.imageExamples[classIndex].push(activation);
|
|
162
|
+
if(this.classNumber.length >= this.limitSize) {
|
|
163
|
+
this.isReady = true;
|
|
164
|
+
}
|
|
165
|
+
} else {
|
|
166
|
+
throw new Error('mobilenetModule is null');
|
|
114
167
|
}
|
|
115
168
|
return Promise.resolve();
|
|
116
169
|
} catch (error) {
|
|
@@ -121,13 +174,14 @@ class LearningMobilenet implements LearningInterface {
|
|
|
121
174
|
|
|
122
175
|
public async init() {
|
|
123
176
|
try {
|
|
124
|
-
this.mobilenetModule = await
|
|
177
|
+
this.mobilenetModule = await loadModel();
|
|
125
178
|
} catch(error) {
|
|
126
179
|
console.log('init Error', error);
|
|
127
180
|
throw error;
|
|
128
181
|
}
|
|
129
182
|
}
|
|
130
183
|
|
|
184
|
+
|
|
131
185
|
// 모델 학습 처리
|
|
132
186
|
public async train(): Promise<tf.History> {
|
|
133
187
|
if (this.isRunning) {
|
|
@@ -139,27 +193,23 @@ class LearningMobilenet implements LearningInterface {
|
|
|
139
193
|
onTrainBegin: (log: any) => {
|
|
140
194
|
this.isTrainedDone = false;
|
|
141
195
|
this.onTrainBegin(log);
|
|
142
|
-
console.log('Training has started.');
|
|
143
196
|
},
|
|
144
197
|
onTrainEnd: (log: any) => {
|
|
145
198
|
this.isTrainedDone = true;
|
|
146
199
|
this.onTrainEnd(log);
|
|
147
|
-
console.log('Training has ended.');
|
|
148
200
|
this.isRunning = false;
|
|
149
201
|
},
|
|
150
202
|
onBatchBegin: (batch: any, logs: any) => {
|
|
151
|
-
console.log(`Batch ${batch} is starting.`);
|
|
203
|
+
//console.log(`Batch ${batch} is starting.`);
|
|
152
204
|
},
|
|
153
205
|
onBatchEnd: (batch: any, logs: any) => {
|
|
154
|
-
console.log(`Batch ${batch} has ended.`);
|
|
206
|
+
//console.log(`Batch ${batch} has ended.`);
|
|
155
207
|
},
|
|
156
208
|
onEpochBegin: (epoch: number, logs: any) => {
|
|
157
|
-
console.log(`Epoch ${epoch+1} is starting.`, logs);
|
|
209
|
+
//console.log(`Epoch ${epoch+1} is starting.`, logs);
|
|
158
210
|
},
|
|
159
211
|
onEpochEnd: (epoch: number, logs: any) => {
|
|
160
212
|
this.onEpochEnd(epoch, logs);
|
|
161
|
-
console.log(`Epoch ${epoch+1} has ended.`);
|
|
162
|
-
console.log('Loss:', logs);
|
|
163
213
|
this.onLoss(logs.loss);
|
|
164
214
|
this.onProgress(epoch+1);
|
|
165
215
|
this.onEvents({
|
|
@@ -171,22 +221,24 @@ class LearningMobilenet implements LearningInterface {
|
|
|
171
221
|
|
|
172
222
|
try {
|
|
173
223
|
this.isRunning = true;
|
|
174
|
-
if (this.
|
|
224
|
+
if (this.classNumber.length < this.limitSize) {
|
|
175
225
|
return Promise.reject(new Error('Please train Data need over 2 data length'));
|
|
176
226
|
}
|
|
177
|
-
|
|
178
|
-
const
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
const
|
|
182
|
-
|
|
183
|
-
const history = await this.model.fit(inputData, targetData, {
|
|
227
|
+
const datasets = this._convertToTfDataset();
|
|
228
|
+
const trainData = datasets.trainDataset.batch(this.batchSize);
|
|
229
|
+
const validationData = datasets.validationDataset.batch(this.batchSize);
|
|
230
|
+
const optimizer = tf.train.adam(this.learningRate);
|
|
231
|
+
const trainModel = await this._createModel(optimizer);
|
|
232
|
+
const history = await trainModel.fitDataset(trainData, {
|
|
184
233
|
epochs: this.epochs,
|
|
185
|
-
|
|
186
|
-
validationSplit: this.validateRate, // 검증 데이터의 비율 설정
|
|
234
|
+
validationData: validationData,
|
|
187
235
|
callbacks: customCallback
|
|
188
236
|
});
|
|
189
|
-
|
|
237
|
+
const jointModel = tf.sequential();
|
|
238
|
+
jointModel.add(this.mobilenetModule!);
|
|
239
|
+
jointModel.add(trainModel);
|
|
240
|
+
this.model = jointModel;
|
|
241
|
+
optimizer.dispose();
|
|
190
242
|
return history;
|
|
191
243
|
} catch (error) {
|
|
192
244
|
this.isRunning = false;
|
|
@@ -202,38 +254,19 @@ class LearningMobilenet implements LearningInterface {
|
|
|
202
254
|
}
|
|
203
255
|
|
|
204
256
|
try {
|
|
205
|
-
const
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
const normalizedTensor = resizedTensor.div(255.0);
|
|
212
|
-
|
|
213
|
-
// Add an extra dimension to the tensor
|
|
214
|
-
const expandedTensor = normalizedTensor.expandDims(0);
|
|
215
|
-
|
|
216
|
-
console.log('predict extend', expandedTensor);
|
|
217
|
-
const predict = this.mobilenetModule.predict(expandedTensor);
|
|
218
|
-
console.log('predict', predict);
|
|
219
|
-
return predict;
|
|
220
|
-
} else {
|
|
221
|
-
throw new Error('mobilenetModule is null');
|
|
222
|
-
}
|
|
257
|
+
const classProbabilities = new Map<string, number>();
|
|
258
|
+
const croppedImage = cropTo(data, 224, false);
|
|
259
|
+
|
|
260
|
+
const logits = tf.tidy(() => {
|
|
261
|
+
const captured = capture(croppedImage, false);
|
|
262
|
+
return this.model!.predict(captured);
|
|
223
263
|
});
|
|
224
|
-
const
|
|
225
|
-
const predictedClass = predictions.as1D().argMax();
|
|
226
|
-
const classId = (await predictedClass.data())[0];
|
|
227
|
-
console.log('classId', classId, this.labels[classId]);
|
|
228
|
-
const predictionsData = await predictions.data(); // 예측 텐서의 데이터를 비동기로 가져옴
|
|
229
|
-
const classProbabilities = new Map<string, number>(); // 클래스별 확률 누적값을 저장할 맵
|
|
230
|
-
|
|
231
|
-
console.log('predictionsData', predictionsData);
|
|
264
|
+
const values = await (logits as tf.Tensor<tf.Rank>).data();
|
|
232
265
|
const EPSILON = 1e-6; // 매우 작은 값을 표현하기 위한 엡실론
|
|
233
|
-
for (let i = 0; i <
|
|
234
|
-
let probability = Math.max(0, Math.min(1,
|
|
266
|
+
for (let i = 0; i < values.length; i++) {
|
|
267
|
+
let probability = Math.max(0, Math.min(1, values[i])); // 확률 값을 0과 1 사이로 조정
|
|
235
268
|
probability = probability < EPSILON ? 0 : probability; // 매우 작은 확률 값을 0으로 간주
|
|
236
|
-
const className = this.
|
|
269
|
+
const className = this.classNumber[i]; // 클래스 이름
|
|
237
270
|
const existingProbability = classProbabilities.get(className);
|
|
238
271
|
if (existingProbability !== undefined) {
|
|
239
272
|
classProbabilities.set(className, existingProbability + probability);
|
|
@@ -241,9 +274,8 @@ class LearningMobilenet implements LearningInterface {
|
|
|
241
274
|
classProbabilities.set(className, probability);
|
|
242
275
|
}
|
|
243
276
|
}
|
|
244
|
-
console.log('
|
|
245
|
-
|
|
246
|
-
predictions.dispose();
|
|
277
|
+
console.log('classProbabilities', classProbabilities);
|
|
278
|
+
dispose(logits);
|
|
247
279
|
return classProbabilities;
|
|
248
280
|
} catch (error) {
|
|
249
281
|
throw error;
|
|
@@ -253,15 +285,12 @@ class LearningMobilenet implements LearningInterface {
|
|
|
253
285
|
|
|
254
286
|
// 모델 저장
|
|
255
287
|
public async saveModel(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<void> {
|
|
256
|
-
console.log('saved model');
|
|
257
288
|
if (!this.isTrainedDone) {
|
|
258
289
|
return Promise.reject(new Error('Train is not done status'));
|
|
259
290
|
}
|
|
260
291
|
await this.model?.save(handlerOrURL, config);
|
|
261
292
|
}
|
|
262
293
|
|
|
263
|
-
|
|
264
|
-
|
|
265
294
|
// 진행중 여부
|
|
266
295
|
public running(): boolean {
|
|
267
296
|
return this.isRunning;
|
|
@@ -272,83 +301,41 @@ class LearningMobilenet implements LearningInterface {
|
|
|
272
301
|
return this.isReady;
|
|
273
302
|
}
|
|
274
303
|
|
|
275
|
-
|
|
276
|
-
// target 라벨 데이타
|
|
277
|
-
private _preprocessedTargetData(): tf.Tensor<tf.Rank> {
|
|
278
|
-
// 라벨 unique 처리 & 배열 리턴
|
|
279
|
-
console.log('uniqueLabels.length', this.labels, this.labels.length);
|
|
280
|
-
const labelIndices = this.labels.map((label) => this.labels.indexOf(label));
|
|
281
|
-
console.log('labelIndices', labelIndices);
|
|
282
|
-
const oneHotEncode = tf.oneHot(tf.tensor1d(labelIndices, 'int32'),this.labels.length);
|
|
283
|
-
console.log('oneHotEncode', oneHotEncode);
|
|
284
|
-
return oneHotEncode
|
|
285
|
-
}
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
// 입력 이미지 데이타
|
|
289
|
-
private _preprocessedInputData(model: tf.LayersModel): tf.Tensor<tf.Rank> {
|
|
290
|
-
// 이미지 배열을 배치로 변환 - [null, 224, 224, 3]
|
|
291
|
-
const inputShape = model.inputs[0].shape;
|
|
292
|
-
console.log('inputShape', inputShape);
|
|
293
|
-
// inputShape를 이와 같이 포멧 맞춘다. for reshape to [224, 224, 3]
|
|
294
|
-
const inputShapeArray = inputShape.slice(1) as number[];
|
|
295
|
-
console.log('inputShapeArray', inputShapeArray);
|
|
296
|
-
const inputBatch = tf.stack(this.imageTensors.map((image) => {
|
|
297
|
-
return tf.reshape(image, inputShapeArray);
|
|
298
|
-
}));
|
|
299
|
-
return inputBatch
|
|
300
|
-
}
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
private async loadModel(): Promise<tf.LayersModel> {
|
|
304
|
-
const load_model = await tf.loadLayersModel(this.modelURL);
|
|
305
|
-
load_model.summary();
|
|
306
|
-
const layer = load_model.getLayer('conv_pw_13_relu');
|
|
307
|
-
const truncatedModel = tf.model({
|
|
308
|
-
inputs: load_model.inputs,
|
|
309
|
-
outputs: layer.output
|
|
310
|
-
})
|
|
311
|
-
return truncatedModel
|
|
312
|
-
}
|
|
313
|
-
|
|
314
304
|
// 모델 저장
|
|
315
|
-
private async _createModel(
|
|
305
|
+
private async _createModel(optimizer: tf.AdamOptimizer): Promise<tf.Sequential> {
|
|
316
306
|
try {
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
});
|
|
347
|
-
model.summary();
|
|
348
|
-
return model;
|
|
307
|
+
// 입력 이미지 크기에 맞게 모델 구조 수정
|
|
308
|
+
let varianceScaling: Initializer;
|
|
309
|
+
varianceScaling = tf.initializers.varianceScaling({});
|
|
310
|
+
|
|
311
|
+
const trainModel = tf.sequential({
|
|
312
|
+
layers: [
|
|
313
|
+
tf.layers.dense({
|
|
314
|
+
inputShape: this.mobilenetModule!.outputs[0].shape.slice(1),
|
|
315
|
+
units: 128,
|
|
316
|
+
activation: 'relu',
|
|
317
|
+
kernelInitializer: varianceScaling, // 'varianceScaling'
|
|
318
|
+
useBias: true
|
|
319
|
+
}),
|
|
320
|
+
tf.layers.dense({
|
|
321
|
+
kernelInitializer: varianceScaling, // 'varianceScaling'
|
|
322
|
+
useBias: false,
|
|
323
|
+
activation: 'softmax',
|
|
324
|
+
units: this.classNumber.length,
|
|
325
|
+
})
|
|
326
|
+
]
|
|
327
|
+
});
|
|
328
|
+
|
|
329
|
+
trainModel.compile({
|
|
330
|
+
loss: 'categoricalCrossentropy',
|
|
331
|
+
optimizer: optimizer,
|
|
332
|
+
metrics: ['accuracy']
|
|
333
|
+
});
|
|
334
|
+
|
|
335
|
+
return trainModel;
|
|
349
336
|
} catch (error) {
|
|
350
|
-
|
|
351
|
-
|
|
337
|
+
console.error('Failed to load model', error);
|
|
338
|
+
throw error;
|
|
352
339
|
}
|
|
353
340
|
}
|
|
354
341
|
}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
/** Image sources accepted by the canvas helpers in this file. */
type Drawable = HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | ImageData;

/** Creates a new `<canvas>` element through `document.createElement`. */
function newCanvas(): HTMLCanvasElement {
  return document.createElement('canvas');
}
|
|
4
|
+
|
|
5
|
+
/**
 * Scales `image` so that its shorter side equals `size` and draws the
 * centered `size` x `size` square onto `canvas`.
 *
 * @param image Source to draw: ImageData, an image, a canvas or a video.
 *   For video the intrinsic `videoWidth` / `videoHeight` are used.
 * @param size Side length, in pixels, of the square output.
 * @param flipped When true, the result is mirrored horizontally.
 * @param canvas Canvas to draw into; it is resized to `size` x `size`.
 *   A new canvas is created when omitted.
 * @returns The canvas that was drawn into.
 * @throws Error when `image` is none of the supported types, or when a 2D
 *   rendering context cannot be obtained.
 */
export function cropTo(image: Drawable, size: number, flipped = false, canvas: HTMLCanvasElement = newCanvas()) {
  let source: HTMLImageElement | HTMLCanvasElement | HTMLVideoElement;
  let width: number;
  let height: number;

  if (image instanceof ImageData) {
    // putImageData copies pixels 1:1 and cannot resize them, so ImageData is
    // first staged on a scratch canvas; drawImage then scales it exactly like
    // every other source type.
    const staging = newCanvas();
    staging.width = image.width;
    staging.height = image.height;
    const stagingCtx = staging.getContext('2d');
    if (stagingCtx === null) {
      throw new Error("Unable to acquire a 2D canvas context");
    }
    stagingCtx.putImageData(image, 0, 0);
    source = staging;
    width = image.width;
    height = image.height;
  } else if (image instanceof HTMLImageElement || image instanceof HTMLCanvasElement) {
    source = image;
    width = image.width;
    height = image.height;
  } else if (image instanceof HTMLVideoElement) {
    source = image;
    width = image.videoWidth;
    height = image.videoHeight;
  } else {
    throw new Error("Unsupported Drawable type");
  }

  // Scale factor that brings the shorter side to exactly `size`.
  const scale = size / Math.min(width, height);
  const scaledW = Math.ceil(width * scale);
  const scaledH = Math.ceil(height * scale);

  // Negative offsets shift the scaled image so that its center lands in the
  // middle of the square canvas; `~~` truncates to an integer.
  const offsetX = ~~((scaledW - size) / 2) * -1;
  const offsetY = ~~((scaledH - size) / 2) * -1;

  canvas.width = canvas.height = size;
  const ctx = canvas.getContext('2d');
  if (ctx === null) {
    throw new Error("Unable to acquire a 2D canvas context");
  }

  ctx.drawImage(source, offsetX, offsetY, scaledW, scaledH);

  if (flipped) {
    // Mirror horizontally by drawing the canvas back onto itself under a
    // negative x scale.
    ctx.scale(-1, 1);
    ctx.drawImage(canvas, size * -1, 0);
  }

  return canvas;
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
/**
 * One labelled training example.
 */
export interface Sample {
  /** Feature values for the example, stored as a flat float array. */
  data: Float32Array;
  /** Label vector for the example, one number per class. */
  label: number[];
}
|
|
5
|
+
|
|
6
|
+
/**
 * Builds a one-hot vector as a plain number array.
 *
 * @param label Zero-based class index to mark with 1.
 * @param numClasses Total number of classes, i.e. the length of the result.
 * @returns An array of `numClasses` zeros with a 1 at position `label`.
 * @throws RangeError when `label` is not an integer in `[0, numClasses)`.
 *   Without this check an out-of-range index would silently produce a vector
 *   of the wrong length, or one with no 1 in it at all.
 */
export function flatOneHot(label: number, numClasses: number) {
  if (!Number.isInteger(label) || label < 0 || label >= numClasses) {
    throw new RangeError(`label ${label} is outside the valid range [0, ${numClasses})`);
  }

  const labelOneHot = new Array<number>(numClasses).fill(0);
  labelOneHot[label] = 1;

  return labelOneHot;
}
|
|
12
|
+
|
|
13
|
+
/**
 * Returns a shuffled copy of `array` using the Fisher-Yates algorithm.
 *
 * The function is generic over the element type, so the result has the same
 * element type as the input and callers do not need to cast it.
 *
 * @param array Items to shuffle. The input array is not modified.
 * @returns A new array holding the same elements in random order, driven by
 *   `Math.random`.
 */
export function fisherYates<T>(array: readonly T[]): T[] {
  // Work on a copy so the caller's array keeps its order.
  const shuffled = array.slice();

  // Walk from the end; each position swaps with a random index at or below it.
  for (let i = shuffled.length - 1; i > 0; i -= 1) {
    const randomIndex = Math.floor(Math.random() * (i + 1));
    [shuffled[i], shuffled[randomIndex]] = [shuffled[randomIndex], shuffled[i]];
  }

  return shuffled;
}
|