learning_model 1.0.38 → 1.0.40

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,124 @@
1
+ "use strict";
2
+ var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
3
+ if (k2 === undefined) k2 = k;
4
+ var desc = Object.getOwnPropertyDescriptor(m, k);
5
+ if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
6
+ desc = { enumerable: true, get: function() { return m[k]; } };
7
+ }
8
+ Object.defineProperty(o, k2, desc);
9
+ }) : (function(o, m, k, k2) {
10
+ if (k2 === undefined) k2 = k;
11
+ o[k2] = m[k];
12
+ }));
13
+ var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
14
+ Object.defineProperty(o, "default", { enumerable: true, value: v });
15
+ }) : function(o, v) {
16
+ o["default"] = v;
17
+ });
18
+ var __importStar = (this && this.__importStar) || function (mod) {
19
+ if (mod && mod.__esModule) return mod;
20
+ var result = {};
21
+ if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
22
+ __setModuleDefault(result, mod);
23
+ return result;
24
+ };
25
+ var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
26
+ function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
27
+ return new (P || (P = Promise))(function (resolve, reject) {
28
+ function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
29
+ function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
30
+ function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
31
+ step((generator = generator.apply(thisArg, _arguments || [])).next());
32
+ });
33
+ };
34
+ Object.defineProperty(exports, "__esModule", { value: true });
35
+ exports.cropTensor = exports.capture = exports.imageToTensor = exports.mobileNetURL = exports.loadModel = exports.isTensor = void 0;
36
+ const tf = __importStar(require("@tensorflow/tfjs"));
37
+ function isTensor(c) {
38
+ return typeof c.dataId === 'object' && typeof c.shape === 'object';
39
+ }
40
+ exports.isTensor = isTensor;
41
+ function loadModel() {
42
+ return __awaiter(this, void 0, void 0, function* () {
43
+ const trainLayerV1 = 'conv_pw_13_relu';
44
+ const trainLayerV2 = 'out_relu';
45
+ var mobileNetVersion = 2;
46
+ const modelURL = mobileNetURL(mobileNetVersion);
47
+ const load_model = yield tf.loadLayersModel(modelURL);
48
+ if (mobileNetVersion == 1) {
49
+ const layer = load_model.getLayer(trainLayerV1);
50
+ const truncatedModel = tf.model({
51
+ inputs: load_model.inputs,
52
+ outputs: layer.output
53
+ });
54
+ const model = tf.sequential();
55
+ model.add(truncatedModel);
56
+ model.add(tf.layers.flatten());
57
+ return model;
58
+ }
59
+ else {
60
+ const layer = load_model.getLayer(trainLayerV2);
61
+ const truncatedModel = tf.model({
62
+ inputs: load_model.inputs,
63
+ outputs: layer.output
64
+ });
65
+ const model = tf.sequential();
66
+ model.add(truncatedModel);
67
+ model.add(tf.layers.globalAveragePooling2d({})); // go from shape [7, 7, 1280] to [1280]
68
+ return model;
69
+ }
70
+ });
71
+ }
72
+ exports.loadModel = loadModel;
73
+ function mobileNetURL(version) {
74
+ if (version == 1) {
75
+ return "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/model.json";
76
+ }
77
+ return "https://storage.googleapis.com/teachable-machine-models/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top/model.json";
78
+ }
79
+ exports.mobileNetURL = mobileNetURL;
80
+ function imageToTensor(data) {
81
+ let tensor;
82
+ if (data instanceof tf.Tensor) {
83
+ tensor = data;
84
+ }
85
+ else {
86
+ // MobileNet 모델 로드
87
+ tensor = tf.browser.fromPixels(data);
88
+ }
89
+ return tensor;
90
+ }
91
+ exports.imageToTensor = imageToTensor;
92
+ function capture(rasterElement, grayscale) {
93
+ return tf.tidy(() => {
94
+ const pixels = tf.browser.fromPixels(rasterElement);
95
+ // crop the image so we're using the center square
96
+ const cropped = cropTensor(pixels, grayscale);
97
+ // Expand the outer most dimension so we have a batch size of 1
98
+ const batchedImage = cropped.expandDims(0);
99
+ // Normalize the image between -1 and 1. The image comes in between 0-255
100
+ // so we divide by 127 and subtract 1.
101
+ return batchedImage.toFloat().div(tf.scalar(127)).sub(tf.scalar(1));
102
+ });
103
+ }
104
+ exports.capture = capture;
105
+ function cropTensor(img, grayscaleModel, grayscaleInput) {
106
+ const size = Math.min(img.shape[0], img.shape[1]);
107
+ const centerHeight = img.shape[0] / 2;
108
+ const beginHeight = centerHeight - (size / 2);
109
+ const centerWidth = img.shape[1] / 2;
110
+ const beginWidth = centerWidth - (size / 2);
111
+ if (grayscaleModel && !grayscaleInput) {
112
+ //cropped rgb data
113
+ let grayscale_cropped = img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
114
+ grayscale_cropped = grayscale_cropped.reshape([size * size, 1, 3]);
115
+ const rgb_weights = [0.2989, 0.5870, 0.1140];
116
+ grayscale_cropped = tf.mul(grayscale_cropped, rgb_weights);
117
+ grayscale_cropped = grayscale_cropped.reshape([size, size, 3]);
118
+ grayscale_cropped = tf.sum(grayscale_cropped, -1);
119
+ grayscale_cropped = tf.expandDims(grayscale_cropped, -1);
120
+ return grayscale_cropped;
121
+ }
122
+ return img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
123
+ }
124
+ exports.cropTensor = cropTensor;
@@ -5,8 +5,6 @@ import { io } from '@tensorflow/tfjs-core';
5
5
  interface LearningInterface {
6
6
  // tf 모델
7
7
  model: tf.LayersModel | null;
8
- // 라벨 array
9
- labels: string[];
10
8
  // 이미 train중인지 여부 - 이미 진행중인경우 취소
11
9
  isRunning: boolean;
12
10
  // 준비
@@ -5,8 +5,14 @@
5
5
  ///////////////////////////////////////////////////////////////////////////
6
6
 
7
7
  import * as tf from '@tensorflow/tfjs';
8
+ import { dispose } from '@tensorflow/tfjs';
8
9
  import { io } from '@tensorflow/tfjs-core';
9
10
  import LearningInterface from './base';
11
+ import { capture, loadModel, isTensor } from '../utils/tf';
12
+ import { cropTo } from '../utils/canvas';
13
+ import { flatOneHot, fisherYates, Sample } from '../utils/dataset';
14
+ import { Initializer } from '@tensorflow/tfjs-layers/dist/initializers';
15
+
10
16
 
11
17
  class LearningMobilenet implements LearningInterface {
12
18
  model: tf.LayersModel | null;
@@ -14,41 +20,39 @@ class LearningMobilenet implements LearningInterface {
14
20
  batchSize: number;
15
21
  learningRate: number;
16
22
  validateRate: number;
17
- labels: string[];
18
- modelURL: string;
19
23
  isRunning: boolean;
20
24
  isReady: boolean;
21
25
  isTrainedDone: boolean;
22
26
  limitSize: number;
23
27
  mobilenetModule: tf.LayersModel | null;
24
- imageTensors: tf.Tensor<tf.Rank>[] = [];
28
+ imageExamples: Float32Array[][] = [];
29
+ classNumber: string[] = [];
25
30
 
26
31
  readonly MOBILE_NET_INPUT_WIDTH = 224;
27
32
  readonly MOBILE_NET_INPUT_HEIGHT = 224;
28
33
  readonly MOBILE_NET_INPUT_CHANNEL = 3;
29
34
  readonly IMAGE_NORMALIZATION_FACTOR = 255.0;
30
35
 
36
+
31
37
  constructor({
32
- modelURL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json', // 디폴트 mobilenet 이미지
33
- epochs = 10,
38
+ epochs = 50,
34
39
  batchSize = 16,
35
40
  limitSize = 2,
36
41
  learningRate = 0.001,
37
- validateRate = 0.2,
38
- }: { modelURL?: string, epochs?: number, batchSize?: number, limitSize?: number, learningRate?: number, validateRate?: number} = {}) {
42
+ validateRate = 0.15,
43
+ }: { epochs?: number, batchSize?: number, limitSize?: number, learningRate?: number, validateRate?: number} = {}) {
39
44
  this.model = null;
40
45
  this.epochs = epochs;
41
46
  this.batchSize = batchSize;
42
47
  this.learningRate = learningRate;
43
48
  this.validateRate = validateRate;
44
- this.labels = [];
45
- this.modelURL = modelURL;
46
49
  this.isRunning = false;
47
50
  this.isReady = false;
48
51
  this.isTrainedDone = false;
49
52
  this.limitSize = limitSize;
50
53
  this.mobilenetModule = null;
51
- this.imageTensors = [];
54
+ this.classNumber = [];
55
+
52
56
  }
53
57
 
54
58
  // 진행 상태를 나타내는 이벤트를 정의합니다.
@@ -73,44 +77,93 @@ class LearningMobilenet implements LearningInterface {
73
77
 
74
78
  try {
75
79
  this.model = await tf.loadLayersModel(jsonURL);
76
- this.labels = labels;
80
+ for(var i = 0; i < labels.length; i++) {
81
+ this.registerClassNumber(labels[i]);
82
+ }
77
83
  this.isReady = true;
78
84
  this.model.summary();
79
85
  } catch (error) {
80
86
  console.error('Model load failed', error);
81
87
  throw error;
82
88
  }
83
-
84
89
  }
85
90
 
86
91
 
92
+ private registerClassNumber(value: string): number {
93
+ // 중복 값의 인덱스를 찾습니다.
94
+ const existingIndex = this.classNumber.indexOf(value);
95
+
96
+ // 중복 값이 있다면 해당 인덱스를 반환합니다.
97
+ if (existingIndex !== -1) {
98
+ return existingIndex;
99
+ }
100
+
101
+ // 중복 값이 없다면 새로운 항목을 추가하고 그 인덱스를 반환합니다.
102
+ this.classNumber.push(value);
103
+ return this.classNumber.length - 1;
104
+ }
105
+
106
+ private _convertToTfDataset() {
107
+ for (let i = 0; i < this.imageExamples.length; i++) {
108
+ this.imageExamples[i] = fisherYates(this.imageExamples[i]) as Float32Array[];
109
+ }
110
+
111
+ const trainDataset: Sample[] = [];
112
+ const validationDataset: Sample[] = [];
113
+
114
+ for (let i = 0; i < this.imageExamples.length; i++) {
115
+ const classLength = this.imageExamples[i].length;
116
+
117
+ // 클래스의 전체 데이터 수를 사용하여 학습 및 검증 데이터 수 계산
118
+ const numValidation = Math.ceil(this.validateRate * classLength);
119
+ const numTrain = classLength - numValidation;
120
+
121
+ // One-Hot 인코딩을 사용하여 라벨 생성
122
+ const y = flatOneHot(i, this.classNumber.length);
123
+
124
+ // numTrain과 numValidation에 따라 데이터를 학습 및 검증 데이터로 분할
125
+ const classTrain: Sample[] = this.imageExamples[i].slice(0, numTrain).map(dataArray => ({ data: dataArray, label: y }));
126
+ trainDataset.push(...classTrain);
127
+
128
+ const classValidation: Sample[] = this.imageExamples[i].slice(numTrain, numTrain + numValidation).map(dataArray => ({ data: dataArray, label: y }));
129
+ validationDataset.push(...classValidation);
130
+ }
131
+
132
+ // Shuffle entire datasets
133
+ const shuffledTrainDataset = fisherYates(trainDataset) as Sample[];
134
+ const shuffledValidationDataset = fisherYates(validationDataset) as Sample[];
135
+
136
+ // Convert to tf.data.Dataset
137
+ const trainX = tf.data.array(shuffledTrainDataset.map(sample => sample.data));
138
+ const validationX = tf.data.array(shuffledValidationDataset.map(sample => sample.data));
139
+ const trainY = tf.data.array(shuffledTrainDataset.map(sample => sample.label));
140
+ const validationY = tf.data.array(shuffledValidationDataset.map(sample => sample.label));
141
+
142
+ return {
143
+ trainDataset: tf.data.zip({ xs: trainX, ys: trainY }),
144
+ validationDataset: tf.data.zip({ xs: validationX, ys: validationY })
145
+ };
146
+ }
147
+
148
+
87
149
  // 학습 데이타 등록
88
150
  public async addData(label: string, data: any): Promise<void> {
89
151
  try {
90
- const imgTensor = tf.tidy(() => {
91
- if (this.mobilenetModule !== null) {
92
- console.log('predict before', data);
93
- const t = tf.browser.fromPixels(data)
94
- // Resize the image to match the model's input shape
95
- const resizedTensor = tf.image.resizeBilinear(t, [224, 224]);
96
- const normalizedTensor = resizedTensor.div(255.0);
97
-
98
- // Add an extra dimension to the tensor
99
- const expandedTensor = normalizedTensor.expandDims(0);
100
-
101
- console.log('predict extend', expandedTensor);
102
- const predict = this.mobilenetModule.predict(expandedTensor);
103
- console.log('predict', predict);
104
- return predict;
105
- } else {
106
- throw new Error('mobilenetModule is null');
152
+ if (this.mobilenetModule !== null) {
153
+ const croppedImage = cropTo(data, 224, false);
154
+ const cap = isTensor(data) ? data : capture(croppedImage, false);
155
+ const predict = this.mobilenetModule.predict(cap) as tf.Tensor;
156
+ const activation = predict.dataSync() as Float32Array;
157
+ const classIndex = this.registerClassNumber(label);
158
+ if (!this.imageExamples[classIndex]) {
159
+ this.imageExamples[classIndex] = [];
107
160
  }
108
- });
109
- console.log('imgTensor', imgTensor);
110
- this.imageTensors.push(imgTensor as tf.Tensor<tf.Rank>);
111
- this.labels.push(label);
112
- if(this.labels.length >= this.limitSize) {
113
- this.isReady = true;
161
+ this.imageExamples[classIndex].push(activation);
162
+ if(this.classNumber.length >= this.limitSize) {
163
+ this.isReady = true;
164
+ }
165
+ } else {
166
+ throw new Error('mobilenetModule is null');
114
167
  }
115
168
  return Promise.resolve();
116
169
  } catch (error) {
@@ -121,13 +174,14 @@ class LearningMobilenet implements LearningInterface {
121
174
 
122
175
  public async init() {
123
176
  try {
124
- this.mobilenetModule = await this.loadModel();
177
+ this.mobilenetModule = await loadModel();
125
178
  } catch(error) {
126
179
  console.log('init Error', error);
127
180
  throw error;
128
181
  }
129
182
  }
130
183
 
184
+
131
185
  // 모델 학습 처리
132
186
  public async train(): Promise<tf.History> {
133
187
  if (this.isRunning) {
@@ -139,27 +193,23 @@ class LearningMobilenet implements LearningInterface {
139
193
  onTrainBegin: (log: any) => {
140
194
  this.isTrainedDone = false;
141
195
  this.onTrainBegin(log);
142
- console.log('Training has started.');
143
196
  },
144
197
  onTrainEnd: (log: any) => {
145
198
  this.isTrainedDone = true;
146
199
  this.onTrainEnd(log);
147
- console.log('Training has ended.');
148
200
  this.isRunning = false;
149
201
  },
150
202
  onBatchBegin: (batch: any, logs: any) => {
151
- console.log(`Batch ${batch} is starting.`);
203
+ //console.log(`Batch ${batch} is starting.`);
152
204
  },
153
205
  onBatchEnd: (batch: any, logs: any) => {
154
- console.log(`Batch ${batch} has ended.`);
206
+ //console.log(`Batch ${batch} has ended.`);
155
207
  },
156
208
  onEpochBegin: (epoch: number, logs: any) => {
157
- console.log(`Epoch ${epoch+1} is starting.`, logs);
209
+ //console.log(`Epoch ${epoch+1} is starting.`, logs);
158
210
  },
159
211
  onEpochEnd: (epoch: number, logs: any) => {
160
212
  this.onEpochEnd(epoch, logs);
161
- console.log(`Epoch ${epoch+1} has ended.`);
162
- console.log('Loss:', logs);
163
213
  this.onLoss(logs.loss);
164
214
  this.onProgress(epoch+1);
165
215
  this.onEvents({
@@ -171,22 +221,24 @@ class LearningMobilenet implements LearningInterface {
171
221
 
172
222
  try {
173
223
  this.isRunning = true;
174
- if (this.labels.length < this.limitSize) {
224
+ if (this.classNumber.length < this.limitSize) {
175
225
  return Promise.reject(new Error('Please train Data need over 2 data length'));
176
226
  }
177
- this.model = await this._createModel(this.labels.length);
178
- const inputData = this._preprocessedInputData(this.model);
179
-
180
- console.log('this.imageTensors', this.imageTensors, inputData);
181
- const targetData = this._preprocessedTargetData();
182
- console.log('targetData', targetData);
183
- const history = await this.model.fit(inputData, targetData, {
227
+ const datasets = this._convertToTfDataset();
228
+ const trainData = datasets.trainDataset.batch(this.batchSize);
229
+ const validationData = datasets.validationDataset.batch(this.batchSize);
230
+ const optimizer = tf.train.adam(this.learningRate);
231
+ const trainModel = await this._createModel(optimizer);
232
+ const history = await trainModel.fitDataset(trainData, {
184
233
  epochs: this.epochs,
185
- batchSize: this.batchSize,
186
- validationSplit: this.validateRate, // 검증 데이터의 비율 설정
234
+ validationData: validationData,
187
235
  callbacks: customCallback
188
236
  });
189
- console.log('Model training completed', history);
237
+ const jointModel = tf.sequential();
238
+ jointModel.add(this.mobilenetModule!);
239
+ jointModel.add(trainModel);
240
+ this.model = jointModel;
241
+ optimizer.dispose();
190
242
  return history;
191
243
  } catch (error) {
192
244
  this.isRunning = false;
@@ -202,38 +254,19 @@ class LearningMobilenet implements LearningInterface {
202
254
  }
203
255
 
204
256
  try {
205
- const imgTensor = tf.tidy(() => {
206
- if (this.mobilenetModule !== null) {
207
- console.log('predict before', data);
208
- const t = tf.browser.fromPixels(data)
209
- // Resize the image to match the model's input shape
210
- const resizedTensor = tf.image.resizeBilinear(t, [224, 224]);
211
- const normalizedTensor = resizedTensor.div(255.0);
212
-
213
- // Add an extra dimension to the tensor
214
- const expandedTensor = normalizedTensor.expandDims(0);
215
-
216
- console.log('predict extend', expandedTensor);
217
- const predict = this.mobilenetModule.predict(expandedTensor);
218
- console.log('predict', predict);
219
- return predict;
220
- } else {
221
- throw new Error('mobilenetModule is null');
222
- }
257
+ const classProbabilities = new Map<string, number>();
258
+ const croppedImage = cropTo(data, 224, false);
259
+
260
+ const logits = tf.tidy(() => {
261
+ const captured = capture(croppedImage, false);
262
+ return this.model!.predict(captured);
223
263
  });
224
- const predictions = this.model.predict(imgTensor) as tf.Tensor;
225
- const predictedClass = predictions.as1D().argMax();
226
- const classId = (await predictedClass.data())[0];
227
- console.log('classId', classId, this.labels[classId]);
228
- const predictionsData = await predictions.data(); // 예측 텐서의 데이터를 비동기로 가져옴
229
- const classProbabilities = new Map<string, number>(); // 클래스별 확률 누적값을 저장할 맵
230
-
231
- console.log('predictionsData', predictionsData);
264
+ const values = await (logits as tf.Tensor<tf.Rank>).data();
232
265
  const EPSILON = 1e-6; // 매우 작은 값을 표현하기 위한 엡실론
233
- for (let i = 0; i < predictionsData.length; i++) {
234
- let probability = Math.max(0, Math.min(1, predictionsData[i])); // 확률 값을 0과 1 사이로 조정
266
+ for (let i = 0; i < values.length; i++) {
267
+ let probability = Math.max(0, Math.min(1, values[i])); // 확률 값을 0과 1 사이로 조정
235
268
  probability = probability < EPSILON ? 0 : probability; // 매우 작은 확률 값을 0으로 간주
236
- const className = this.labels[i]; // 클래스 이름
269
+ const className = this.classNumber[i]; // 클래스 이름
237
270
  const existingProbability = classProbabilities.get(className);
238
271
  if (existingProbability !== undefined) {
239
272
  classProbabilities.set(className, existingProbability + probability);
@@ -241,9 +274,8 @@ class LearningMobilenet implements LearningInterface {
241
274
  classProbabilities.set(className, probability);
242
275
  }
243
276
  }
244
- console.log('Class Probabilities:', classProbabilities);
245
- predictedClass.dispose();
246
- predictions.dispose();
277
+ console.log('classProbabilities', classProbabilities);
278
+ dispose(logits);
247
279
  return classProbabilities;
248
280
  } catch (error) {
249
281
  throw error;
@@ -253,15 +285,12 @@ class LearningMobilenet implements LearningInterface {
253
285
 
254
286
  // 모델 저장
255
287
  public async saveModel(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<void> {
256
- console.log('saved model');
257
288
  if (!this.isTrainedDone) {
258
289
  return Promise.reject(new Error('Train is not done status'));
259
290
  }
260
291
  await this.model?.save(handlerOrURL, config);
261
292
  }
262
293
 
263
-
264
-
265
294
  // 진행중 여부
266
295
  public running(): boolean {
267
296
  return this.isRunning;
@@ -272,83 +301,41 @@ class LearningMobilenet implements LearningInterface {
272
301
  return this.isReady;
273
302
  }
274
303
 
275
-
276
- // target 라벨 데이타
277
- private _preprocessedTargetData(): tf.Tensor<tf.Rank> {
278
- // 라벨 unique 처리 & 배열 리턴
279
- console.log('uniqueLabels.length', this.labels, this.labels.length);
280
- const labelIndices = this.labels.map((label) => this.labels.indexOf(label));
281
- console.log('labelIndices', labelIndices);
282
- const oneHotEncode = tf.oneHot(tf.tensor1d(labelIndices, 'int32'),this.labels.length);
283
- console.log('oneHotEncode', oneHotEncode);
284
- return oneHotEncode
285
- }
286
-
287
-
288
- // 입력 이미지 데이타
289
- private _preprocessedInputData(model: tf.LayersModel): tf.Tensor<tf.Rank> {
290
- // 이미지 배열을 배치로 변환 - [null, 224, 224, 3]
291
- const inputShape = model.inputs[0].shape;
292
- console.log('inputShape', inputShape);
293
- // inputShape를 이와 같이 포멧 맞춘다. for reshape to [224, 224, 3]
294
- const inputShapeArray = inputShape.slice(1) as number[];
295
- console.log('inputShapeArray', inputShapeArray);
296
- const inputBatch = tf.stack(this.imageTensors.map((image) => {
297
- return tf.reshape(image, inputShapeArray);
298
- }));
299
- return inputBatch
300
- }
301
-
302
-
303
- private async loadModel(): Promise<tf.LayersModel> {
304
- const load_model = await tf.loadLayersModel(this.modelURL);
305
- load_model.summary();
306
- const layer = load_model.getLayer('conv_pw_13_relu');
307
- const truncatedModel = tf.model({
308
- inputs: load_model.inputs,
309
- outputs: layer.output
310
- })
311
- return truncatedModel
312
- }
313
-
314
304
  // 모델 저장
315
- private async _createModel(numClasses: number): Promise<tf.Sequential> {
305
+ private async _createModel(optimizer: tf.AdamOptimizer): Promise<tf.Sequential> {
316
306
  try {
317
- const truncatedModel = await this.loadModel()
318
- // 입력 이미지 크기에 맞게 모델 구조 수정
319
- const model = tf.sequential();
320
- // 추가적인 합성곱 층
321
- model.add(tf.layers.conv2d({
322
- filters: 64,
323
- kernelSize: 3,
324
- activation: 'relu',
325
- padding: 'same',
326
- inputShape: truncatedModel.outputs[0].shape.slice(1)
327
- }));
328
-
329
- model.add(tf.layers.flatten());
330
-
331
- // 깊은 밀집층
332
- model.add(tf.layers.dense({ units: 100, activation: 'relu' }));
333
- model.add(tf.layers.dense({ units: 100, activation: 'relu' }));
334
-
335
- // 드롭아웃 층
336
- model.add(tf.layers.dropout({ rate: 0.5 }));
337
-
338
- model.add(tf.layers.dense({ units: this.labels.length, activation: 'softmax' }));
339
-
340
-
341
- const optimizer = tf.train.adam(this.learningRate); // Optimizer를 생성하고 학습률을 설정합니다.
342
- model.compile({
343
- loss: (numClasses === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy',
344
- optimizer: optimizer,
345
- metrics: ['accuracy', 'acc']
346
- });
347
- model.summary();
348
- return model;
307
+ // 입력 이미지 크기에 맞게 모델 구조 수정
308
+ let varianceScaling: Initializer;
309
+ varianceScaling = tf.initializers.varianceScaling({});
310
+
311
+ const trainModel = tf.sequential({
312
+ layers: [
313
+ tf.layers.dense({
314
+ inputShape: this.mobilenetModule!.outputs[0].shape.slice(1),
315
+ units: 128,
316
+ activation: 'relu',
317
+ kernelInitializer: varianceScaling, // 'varianceScaling'
318
+ useBias: true
319
+ }),
320
+ tf.layers.dense({
321
+ kernelInitializer: varianceScaling, // 'varianceScaling'
322
+ useBias: false,
323
+ activation: 'softmax',
324
+ units: this.classNumber.length,
325
+ })
326
+ ]
327
+ });
328
+
329
+ trainModel.compile({
330
+ loss: 'categoricalCrossentropy',
331
+ optimizer: optimizer,
332
+ metrics: ['accuracy']
333
+ });
334
+
335
+ return trainModel;
349
336
  } catch (error) {
350
- console.error('Failed to load model', error);
351
- throw error;
337
+ console.error('Failed to load model', error);
338
+ throw error;
352
339
  }
353
340
  }
354
341
  }
@@ -0,0 +1,49 @@
1
+ type Drawable = ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement;
2
+
3
+ const newCanvas = () => document.createElement('canvas');
4
+
5
+ export function cropTo(image: Drawable, size: number, flipped = false, canvas: HTMLCanvasElement = newCanvas()) {
6
+ let width: number;
7
+ let height: number;
8
+
9
+ // If ImageData
10
+ if (image instanceof ImageData) {
11
+ width = image.width;
12
+ height = image.height;
13
+ }
14
+ // If image, bitmap, or canvas
15
+ else if (image instanceof HTMLImageElement || image instanceof HTMLCanvasElement) {
16
+ width = image.width;
17
+ height = image.height;
18
+ }
19
+ // If video element
20
+ else if (image instanceof HTMLVideoElement) {
21
+ width = image.videoWidth;
22
+ height = image.videoHeight;
23
+ } else {
24
+ throw new Error("Unsupported Drawable type");
25
+ }
26
+
27
+ const min = Math.min(width, height);
28
+ const scale = size / min;
29
+ const scaledW = Math.ceil(width * scale);
30
+ const scaledH = Math.ceil(height * scale);
31
+ const dx = scaledW - size;
32
+ const dy = scaledH - size;
33
+ canvas.width = canvas.height = size;
34
+ const ctx = canvas.getContext('2d') as CanvasRenderingContext2D;
35
+
36
+ // Handle ImageData separately
37
+ if (image instanceof ImageData) {
38
+ ctx.putImageData(image, ~~(dx / 2) * -1, ~~(dy / 2) * -1); // Adjust this if needed
39
+ } else {
40
+ ctx.drawImage(image, ~~(dx / 2) * -1, ~~(dy / 2) * -1, scaledW, scaledH);
41
+ }
42
+
43
+ if (flipped) {
44
+ ctx.scale(-1, 1);
45
+ ctx.drawImage(canvas, size * -1, 0);
46
+ }
47
+
48
+ return canvas;
49
+ }
@@ -0,0 +1,24 @@
1
+ export interface Sample {
2
+ data: Float32Array;
3
+ label: number[];
4
+ }
5
+
6
+ export function flatOneHot(label: number, numClasses: number) {
7
+ const labelOneHot = new Array(numClasses).fill(0) as number[];
8
+ labelOneHot[label] = 1;
9
+
10
+ return labelOneHot;
11
+ }
12
+
13
+ export function fisherYates(array: Float32Array[] | Sample[]) {
14
+ const length = array.length;
15
+
16
+ // need to clone array or we'd be editing original as we go
17
+ const shuffled = array.slice();
18
+ for (let i = (length - 1); i > 0; i -= 1) {
19
+ let randomIndex ;
20
+ randomIndex = Math.floor(Math.random() * (i + 1));
21
+ [shuffled[i], shuffled[randomIndex]] = [shuffled[randomIndex],shuffled[i]];
22
+ }
23
+ return shuffled;
24
+ }