learning_model 1.0.37 → 1.0.39

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,124 @@
1
+ "use strict";
2
+ var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
3
+ if (k2 === undefined) k2 = k;
4
+ var desc = Object.getOwnPropertyDescriptor(m, k);
5
+ if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
6
+ desc = { enumerable: true, get: function() { return m[k]; } };
7
+ }
8
+ Object.defineProperty(o, k2, desc);
9
+ }) : (function(o, m, k, k2) {
10
+ if (k2 === undefined) k2 = k;
11
+ o[k2] = m[k];
12
+ }));
13
+ var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
14
+ Object.defineProperty(o, "default", { enumerable: true, value: v });
15
+ }) : function(o, v) {
16
+ o["default"] = v;
17
+ });
18
+ var __importStar = (this && this.__importStar) || function (mod) {
19
+ if (mod && mod.__esModule) return mod;
20
+ var result = {};
21
+ if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
22
+ __setModuleDefault(result, mod);
23
+ return result;
24
+ };
25
+ var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
26
+ function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
27
+ return new (P || (P = Promise))(function (resolve, reject) {
28
+ function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
29
+ function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
30
+ function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
31
+ step((generator = generator.apply(thisArg, _arguments || [])).next());
32
+ });
33
+ };
34
+ Object.defineProperty(exports, "__esModule", { value: true });
35
+ exports.cropTensor = exports.capture = exports.imageToTensor = exports.mobileNetURL = exports.loadModel = exports.isTensor = void 0;
36
+ const tf = __importStar(require("@tensorflow/tfjs"));
37
/**
 * Duck-type check for a tf.Tensor-like value: tensors expose an object
 * `dataId` handle and an object `shape` array.
 *
 * Robustness fix: returns false for null/undefined input instead of throwing
 * a TypeError when reading `.dataId` off a nullish value.
 *
 * @param c candidate value of any type
 * @returns true when `c` looks like a tf.Tensor
 */
function isTensor(c) {
    return c != null && typeof c.dataId === 'object' && typeof c.shape === 'object';
}
exports.isTensor = isTensor;
41
/**
 * Loads a MobileNet from its hosted URL and truncates it at its last feature
 * layer, returning a headless feature extractor.
 *
 * The MobileNet version is hard-coded to 2 here; the v1 path is kept for
 * completeness. v1 is cut at 'conv_pw_13_relu' and flattened; v2 is cut at
 * 'out_relu' and globally average-pooled.
 *
 * @returns Promise resolving to a tf.Sequential wrapping the truncated net.
 */
async function loadModel() {
    // Last feature layer name per supported MobileNet version.
    const trainLayerV1 = 'conv_pw_13_relu';
    const trainLayerV2 = 'out_relu';
    let mobileNetVersion = 2; // only v2 is ever selected in practice
    const modelURL = mobileNetURL(mobileNetVersion);
    const loadedModel = await tf.loadLayersModel(modelURL);
    // Cut the network at the chosen feature layer.
    const layerName = mobileNetVersion == 1 ? trainLayerV1 : trainLayerV2;
    const featureLayer = loadedModel.getLayer(layerName);
    const truncatedModel = tf.model({
        inputs: loadedModel.inputs,
        outputs: featureLayer.output
    });
    const model = tf.sequential();
    model.add(truncatedModel);
    if (mobileNetVersion == 1) {
        model.add(tf.layers.flatten());
    }
    else {
        model.add(tf.layers.globalAveragePooling2d({})); // go from shape [7, 7, 1280] to [1280]
    }
    return model;
}
exports.loadModel = loadModel;
73
/**
 * Returns the hosted model.json URL for the requested MobileNet version.
 * Version 1 maps to the tfjs-models MobileNet v1 (alpha 1.0, 224px) export;
 * any other value maps to the Teachable Machine MobileNet v2 headless export.
 *
 * @param version MobileNet major version selector (1 or 2)
 * @returns fully-qualified model.json URL
 */
function mobileNetURL(version) {
    const V1_URL = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/model.json";
    const V2_URL = "https://storage.googleapis.com/teachable-machine-models/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top/model.json";
    // Loose equality kept deliberately to match the original's tolerance
    // for a string-valued "1".
    return version == 1 ? V1_URL : V2_URL;
}
exports.mobileNetURL = mobileNetURL;
80
/**
 * Normalizes an input into a tf.Tensor: values that are already tensors pass
 * through untouched; anything else is treated as a pixel source and
 * rasterized with tf.browser.fromPixels.
 *
 * @param data a tf.Tensor or a fromPixels-compatible pixel source
 * @returns the corresponding tf.Tensor
 */
function imageToTensor(data) {
    return data instanceof tf.Tensor ? data : tf.browser.fromPixels(data);
}
exports.imageToTensor = imageToTensor;
92
/**
 * Rasterizes an element into a normalized, batched image tensor ready for
 * MobileNet inference. All intermediates are created inside tf.tidy so only
 * the returned tensor survives.
 *
 * @param rasterElement fromPixels-compatible source (canvas/img/video/...)
 * @param grayscale forwarded to cropTensor's grayscale-model flag
 * @returns float tensor of shape [1, H, W, C], values in roughly [-1, 1]
 */
function capture(rasterElement, grayscale) {
    return tf.tidy(() => {
        const pixels = tf.browser.fromPixels(rasterElement);
        // Keep only the centered square of the frame.
        const squared = cropTensor(pixels, grayscale);
        // Prepend a batch dimension of 1.
        const batched = squared.expandDims(0);
        // Pixels arrive in [0, 255]; dividing by 127 and subtracting 1
        // normalizes them to approximately [-1, 1].
        return batched.toFloat().div(tf.scalar(127)).sub(tf.scalar(1));
    });
}
exports.capture = capture;
105
/**
 * Crops a [height, width, 3] image tensor to its centered square and,
 * when `grayscaleModel` is set (and the input isn't already grayscale),
 * collapses RGB into a single luma channel.
 *
 * Fix: the begin offsets are floored. (dim - size) / 2 is fractional whenever
 * the height/width difference is odd, and tf.slice requires integer begin
 * coordinates.
 *
 * @param img rank-3 image tensor [height, width, 3]
 * @param grayscaleModel when truthy, convert the crop to grayscale
 * @param grayscaleInput when truthy, input is already grayscale — skip conversion
 * @returns [size, size, 3] crop, or [size, size, 1] when grayscaled
 */
function cropTensor(img, grayscaleModel, grayscaleInput) {
    const size = Math.min(img.shape[0], img.shape[1]);
    // Integer center offsets (see fix note above).
    const beginHeight = Math.floor((img.shape[0] - size) / 2);
    const beginWidth = Math.floor((img.shape[1] - size) / 2);
    if (grayscaleModel && !grayscaleInput) {
        // Crop the RGB data first...
        let grayscale_cropped = img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
        // ...then weight each channel (BT.601-style luma coefficients) and
        // sum across channels to produce one grayscale channel.
        grayscale_cropped = grayscale_cropped.reshape([size * size, 1, 3]);
        const rgb_weights = [0.2989, 0.5870, 0.1140];
        grayscale_cropped = tf.mul(grayscale_cropped, rgb_weights);
        grayscale_cropped = grayscale_cropped.reshape([size, size, 3]);
        grayscale_cropped = tf.sum(grayscale_cropped, -1);
        grayscale_cropped = tf.expandDims(grayscale_cropped, -1);
        return grayscale_cropped;
    }
    return img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
}
exports.cropTensor = cropTensor;
@@ -5,8 +5,6 @@ import { io } from '@tensorflow/tfjs-core';
5
5
  interface LearningInterface {
6
6
  // tf 모델
7
7
  model: tf.LayersModel | null;
8
- // 라벨 array
9
- labels: string[];
10
8
  // 이미 train중인지 여부 - 이미 진행중인경우 취소
11
9
  isRunning: boolean;
12
10
  // 준비
@@ -45,7 +43,7 @@ interface LearningInterface {
45
43
  ready(): boolean;
46
44
 
47
45
  // 모델 로드
48
- load({jsonURL, binFile, labels}: {jsonURL :string, binFile: io.LoadOptions | undefined, labels : Array<string>}): Promise<void>;
46
+ load({jsonURL, labels}: {jsonURL :string, labels : Array<string>}): Promise<void>;
49
47
  }
50
48
 
51
49
  export default LearningInterface;
@@ -5,8 +5,14 @@
5
5
  ///////////////////////////////////////////////////////////////////////////
6
6
 
7
7
  import * as tf from '@tensorflow/tfjs';
8
+ import { dispose } from '@tensorflow/tfjs';
8
9
  import { io } from '@tensorflow/tfjs-core';
9
10
  import LearningInterface from './base';
11
+ import { capture, loadModel, isTensor } from '../utils/tf';
12
+ import { cropTo } from '../utils/canvas';
13
+ import { flatOneHot, fisherYates, Sample } from '../utils/dataset';
14
+ import { Initializer } from '@tensorflow/tfjs-layers/dist/initializers';
15
+
10
16
 
11
17
  class LearningMobilenet implements LearningInterface {
12
18
  model: tf.LayersModel | null;
@@ -14,41 +20,39 @@ class LearningMobilenet implements LearningInterface {
14
20
  batchSize: number;
15
21
  learningRate: number;
16
22
  validateRate: number;
17
- labels: string[];
18
- modelURL: string;
19
23
  isRunning: boolean;
20
24
  isReady: boolean;
21
25
  isTrainedDone: boolean;
22
26
  limitSize: number;
23
27
  mobilenetModule: tf.LayersModel | null;
24
- imageTensors: tf.Tensor<tf.Rank>[] = [];
28
+ imageExamples: Float32Array[][] = [];
29
+ classNumber: string[] = [];
25
30
 
26
31
  readonly MOBILE_NET_INPUT_WIDTH = 224;
27
32
  readonly MOBILE_NET_INPUT_HEIGHT = 224;
28
33
  readonly MOBILE_NET_INPUT_CHANNEL = 3;
29
34
  readonly IMAGE_NORMALIZATION_FACTOR = 255.0;
30
35
 
36
+
31
37
  constructor({
32
- modelURL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json', // 디폴트 mobilenet 이미지
33
- epochs = 10,
38
+ epochs = 50,
34
39
  batchSize = 16,
35
40
  limitSize = 2,
36
41
  learningRate = 0.001,
37
- validateRate = 0.2,
38
- }: { modelURL?: string, epochs?: number, batchSize?: number, limitSize?: number, learningRate?: number, validateRate?: number} = {}) {
42
+ validateRate = 0.15,
43
+ }: { epochs?: number, batchSize?: number, limitSize?: number, learningRate?: number, validateRate?: number} = {}) {
39
44
  this.model = null;
40
45
  this.epochs = epochs;
41
46
  this.batchSize = batchSize;
42
47
  this.learningRate = learningRate;
43
48
  this.validateRate = validateRate;
44
- this.labels = [];
45
- this.modelURL = modelURL;
46
49
  this.isRunning = false;
47
50
  this.isReady = false;
48
51
  this.isTrainedDone = false;
49
52
  this.limitSize = limitSize;
50
53
  this.mobilenetModule = null;
51
- this.imageTensors = [];
54
+ this.classNumber = [];
55
+
52
56
  }
53
57
 
54
58
  // 진행 상태를 나타내는 이벤트를 정의합니다.
@@ -66,51 +70,99 @@ class LearningMobilenet implements LearningInterface {
66
70
 
67
71
  //
68
72
  // 기존의 모델 로드
69
- public async load({jsonURL, binFile, labels}: {jsonURL :string, binFile: io.LoadOptions | undefined, labels : Array<string>}): Promise<void> {
73
+ public async load({jsonURL, labels}: {jsonURL :string, labels : Array<string>}): Promise<void> {
70
74
  if (labels.length <= 0) {
71
75
  return Promise.reject(new Error('Labels length is 0'));
72
76
  }
73
77
 
74
78
  try {
75
79
  this.model = await tf.loadLayersModel(jsonURL);
76
- this.labels = labels;
80
+ for(var i = 0; i < labels.length; i++) {
81
+ this.registerClassNumber(labels[i]);
82
+ }
77
83
  this.isReady = true;
78
84
  this.model.summary();
79
85
  } catch (error) {
80
86
  console.error('Model load failed', error);
81
87
  throw error;
82
88
  }
83
-
84
89
  }
85
90
 
86
91
 
92
+ private registerClassNumber(value: string): number {
93
+ // 중복 값의 인덱스를 찾습니다.
94
+ const existingIndex = this.classNumber.indexOf(value);
95
+
96
+ // 중복 값이 있다면 해당 인덱스를 반환합니다.
97
+ if (existingIndex !== -1) {
98
+ return existingIndex;
99
+ }
100
+
101
+ // 중복 값이 없다면 새로운 항목을 추가하고 그 인덱스를 반환합니다.
102
+ this.classNumber.push(value);
103
+ return this.classNumber.length - 1;
104
+ }
105
+
106
    // Converts the accumulated per-class embeddings (`imageExamples`) into
    // shuffled, one-hot-labelled tf.data train/validation dataset pairs.
    // `validateRate` is the per-class validation fraction; because the count
    // is ceil'd, every non-empty class contributes at least one validation
    // sample when validateRate > 0.
    private _convertToTfDataset() {
        // Shuffle each class's examples in place so the train/validation
        // split below is not biased by insertion order.
        for (let i = 0; i < this.imageExamples.length; i++) {
            this.imageExamples[i] = fisherYates(this.imageExamples[i]) as Float32Array[];
        }

        const trainDataset: Sample[] = [];
        const validationDataset: Sample[] = [];

        for (let i = 0; i < this.imageExamples.length; i++) {
            const classLength = this.imageExamples[i].length;

            // Derive train/validation counts from this class's total example count.
            const numValidation = Math.ceil(this.validateRate * classLength);
            const numTrain = classLength - numValidation;

            // Build this class's label as a flat one-hot vector.
            const y = flatOneHot(i, this.classNumber.length);

            // Split the class's examples into train and validation slices
            // according to numTrain / numValidation.
            const classTrain: Sample[] = this.imageExamples[i].slice(0, numTrain).map(dataArray => ({ data: dataArray, label: y }));
            trainDataset.push(...classTrain);

            const classValidation: Sample[] = this.imageExamples[i].slice(numTrain, numTrain + numValidation).map(dataArray => ({ data: dataArray, label: y }));
            validationDataset.push(...classValidation);
        }

        // Shuffle entire datasets so classes are interleaved across batches.
        const shuffledTrainDataset = fisherYates(trainDataset) as Sample[];
        const shuffledValidationDataset = fisherYates(validationDataset) as Sample[];

        // Convert to tf.data.Dataset (features and labels zipped together).
        const trainX = tf.data.array(shuffledTrainDataset.map(sample => sample.data));
        const validationX = tf.data.array(shuffledValidationDataset.map(sample => sample.data));
        const trainY = tf.data.array(shuffledTrainDataset.map(sample => sample.label));
        const validationY = tf.data.array(shuffledValidationDataset.map(sample => sample.label));

        return {
            trainDataset: tf.data.zip({ xs: trainX, ys: trainY }),
            validationDataset: tf.data.zip({ xs: validationX, ys: validationY })
        };
    }
147
+
148
+
87
149
  // 학습 데이타 등록
88
150
  public async addData(label: string, data: any): Promise<void> {
89
151
  try {
90
- const imgTensor = tf.tidy(() => {
91
- if (this.mobilenetModule !== null) {
92
- console.log('predict before', data);
93
- const t = tf.browser.fromPixels(data)
94
- // Resize the image to match the model's input shape
95
- const resizedTensor = tf.image.resizeBilinear(t, [224, 224]);
96
- const normalizedTensor = resizedTensor.div(255.0);
97
-
98
- // Add an extra dimension to the tensor
99
- const expandedTensor = normalizedTensor.expandDims(0);
100
-
101
- console.log('predict extend', expandedTensor);
102
- const predict = this.mobilenetModule.predict(expandedTensor);
103
- console.log('predict', predict);
104
- return predict;
105
- } else {
106
- throw new Error('mobilenetModule is null');
152
+ if (this.mobilenetModule !== null) {
153
+ const cap = isTensor(data) ? data : capture(data, false);
154
+ const predict = this.mobilenetModule.predict(cap) as tf.Tensor;
155
+ const activation = predict.dataSync() as Float32Array;
156
+ const classIndex = this.registerClassNumber(label);
157
+ if (!this.imageExamples[classIndex]) {
158
+ this.imageExamples[classIndex] = [];
107
159
  }
108
- });
109
- console.log('imgTensor', imgTensor);
110
- this.imageTensors.push(imgTensor as tf.Tensor<tf.Rank>);
111
- this.labels.push(label);
112
- if(this.labels.length >= this.limitSize) {
113
- this.isReady = true;
160
+ this.imageExamples[classIndex].push(activation);
161
+ if(this.classNumber.length >= this.limitSize) {
162
+ this.isReady = true;
163
+ }
164
+ } else {
165
+ throw new Error('mobilenetModule is null');
114
166
  }
115
167
  return Promise.resolve();
116
168
  } catch (error) {
@@ -121,13 +173,14 @@ class LearningMobilenet implements LearningInterface {
121
173
 
122
174
  public async init() {
123
175
  try {
124
- this.mobilenetModule = await this.loadModel();
176
+ this.mobilenetModule = await loadModel();
125
177
  } catch(error) {
126
178
  console.log('init Error', error);
127
179
  throw error;
128
180
  }
129
181
  }
130
182
 
183
+
131
184
  // 모델 학습 처리
132
185
  public async train(): Promise<tf.History> {
133
186
  if (this.isRunning) {
@@ -139,12 +192,10 @@ class LearningMobilenet implements LearningInterface {
139
192
  onTrainBegin: (log: any) => {
140
193
  this.isTrainedDone = false;
141
194
  this.onTrainBegin(log);
142
- console.log('Training has started.');
143
195
  },
144
196
  onTrainEnd: (log: any) => {
145
197
  this.isTrainedDone = true;
146
198
  this.onTrainEnd(log);
147
- console.log('Training has ended.');
148
199
  this.isRunning = false;
149
200
  },
150
201
  onBatchBegin: (batch: any, logs: any) => {
@@ -158,8 +209,6 @@ class LearningMobilenet implements LearningInterface {
158
209
  },
159
210
  onEpochEnd: (epoch: number, logs: any) => {
160
211
  this.onEpochEnd(epoch, logs);
161
- console.log(`Epoch ${epoch+1} has ended.`);
162
- console.log('Loss:', logs);
163
212
  this.onLoss(logs.loss);
164
213
  this.onProgress(epoch+1);
165
214
  this.onEvents({
@@ -171,22 +220,24 @@ class LearningMobilenet implements LearningInterface {
171
220
 
172
221
  try {
173
222
  this.isRunning = true;
174
- if (this.labels.length < this.limitSize) {
223
+ if (this.classNumber.length < this.limitSize) {
175
224
  return Promise.reject(new Error('Please train Data need over 2 data length'));
176
225
  }
177
- this.model = await this._createModel(this.labels.length);
178
- const inputData = this._preprocessedInputData(this.model);
179
-
180
- console.log('this.imageTensors', this.imageTensors, inputData);
181
- const targetData = this._preprocessedTargetData();
182
- console.log('targetData', targetData);
183
- const history = await this.model.fit(inputData, targetData, {
226
+ const datasets = this._convertToTfDataset();
227
+ const trainData = datasets.trainDataset.batch(this.batchSize);
228
+ const validationData = datasets.validationDataset.batch(this.batchSize);
229
+ const optimizer = tf.train.adam(this.learningRate);
230
+ const trainModel = await this._createModel(optimizer);
231
+ const history = await trainModel.fitDataset(trainData, {
184
232
  epochs: this.epochs,
185
- batchSize: this.batchSize,
186
- validationSplit: this.validateRate, // 검증 데이터의 비율 설정
233
+ validationData: validationData,
187
234
  callbacks: customCallback
188
235
  });
189
- console.log('Model training completed', history);
236
+ const jointModel = tf.sequential();
237
+ jointModel.add(this.mobilenetModule!);
238
+ jointModel.add(trainModel);
239
+ this.model = jointModel;
240
+ optimizer.dispose();
190
241
  return history;
191
242
  } catch (error) {
192
243
  this.isRunning = false;
@@ -202,38 +253,19 @@ class LearningMobilenet implements LearningInterface {
202
253
  }
203
254
 
204
255
  try {
205
- const imgTensor = tf.tidy(() => {
206
- if (this.mobilenetModule !== null) {
207
- console.log('predict before', data);
208
- const t = tf.browser.fromPixels(data)
209
- // Resize the image to match the model's input shape
210
- const resizedTensor = tf.image.resizeBilinear(t, [224, 224]);
211
- const normalizedTensor = resizedTensor.div(255.0);
212
-
213
- // Add an extra dimension to the tensor
214
- const expandedTensor = normalizedTensor.expandDims(0);
215
-
216
- console.log('predict extend', expandedTensor);
217
- const predict = this.mobilenetModule.predict(expandedTensor);
218
- console.log('predict', predict);
219
- return predict;
220
- } else {
221
- throw new Error('mobilenetModule is null');
222
- }
256
+ const classProbabilities = new Map<string, number>();
257
+ const croppedImage = cropTo(data, 224, false);
258
+
259
+ const logits = tf.tidy(() => {
260
+ const captured = capture(croppedImage, false);
261
+ return this.model!.predict(captured);
223
262
  });
224
- const predictions = this.model.predict(imgTensor) as tf.Tensor;
225
- const predictedClass = predictions.as1D().argMax();
226
- const classId = (await predictedClass.data())[0];
227
- console.log('classId', classId, this.labels[classId]);
228
- const predictionsData = await predictions.data(); // 예측 텐서의 데이터를 비동기로 가져옴
229
- const classProbabilities = new Map<string, number>(); // 클래스별 확률 누적값을 저장할 맵
230
-
231
- console.log('predictionsData', predictionsData);
263
+ const values = await (logits as tf.Tensor<tf.Rank>).data();
232
264
  const EPSILON = 1e-6; // 매우 작은 값을 표현하기 위한 엡실론
233
- for (let i = 0; i < predictionsData.length; i++) {
234
- let probability = Math.max(0, Math.min(1, predictionsData[i])); // 확률 값을 0과 1 사이로 조정
265
+ for (let i = 0; i < values.length; i++) {
266
+ let probability = Math.max(0, Math.min(1, values[i])); // 확률 값을 0과 1 사이로 조정
235
267
  probability = probability < EPSILON ? 0 : probability; // 매우 작은 확률 값을 0으로 간주
236
- const className = this.labels[i]; // 클래스 이름
268
+ const className = this.classNumber[i]; // 클래스 이름
237
269
  const existingProbability = classProbabilities.get(className);
238
270
  if (existingProbability !== undefined) {
239
271
  classProbabilities.set(className, existingProbability + probability);
@@ -241,9 +273,7 @@ class LearningMobilenet implements LearningInterface {
241
273
  classProbabilities.set(className, probability);
242
274
  }
243
275
  }
244
- console.log('Class Probabilities:', classProbabilities);
245
- predictedClass.dispose();
246
- predictions.dispose();
276
+ dispose(logits);
247
277
  return classProbabilities;
248
278
  } catch (error) {
249
279
  throw error;
@@ -253,15 +283,12 @@ class LearningMobilenet implements LearningInterface {
253
283
 
254
284
  // 모델 저장
255
285
  public async saveModel(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<void> {
256
- console.log('saved model');
257
286
  if (!this.isTrainedDone) {
258
287
  return Promise.reject(new Error('Train is not done status'));
259
288
  }
260
289
  await this.model?.save(handlerOrURL, config);
261
290
  }
262
291
 
263
-
264
-
265
292
  // 진행중 여부
266
293
  public running(): boolean {
267
294
  return this.isRunning;
@@ -272,83 +299,41 @@ class LearningMobilenet implements LearningInterface {
272
299
  return this.isReady;
273
300
  }
274
301
 
275
-
276
- // target 라벨 데이타
277
- private _preprocessedTargetData(): tf.Tensor<tf.Rank> {
278
- // 라벨 unique 처리 & 배열 리턴
279
- console.log('uniqueLabels.length', this.labels, this.labels.length);
280
- const labelIndices = this.labels.map((label) => this.labels.indexOf(label));
281
- console.log('labelIndices', labelIndices);
282
- const oneHotEncode = tf.oneHot(tf.tensor1d(labelIndices, 'int32'),this.labels.length);
283
- console.log('oneHotEncode', oneHotEncode);
284
- return oneHotEncode
285
- }
286
-
287
-
288
- // 입력 이미지 데이타
289
- private _preprocessedInputData(model: tf.LayersModel): tf.Tensor<tf.Rank> {
290
- // 이미지 배열을 배치로 변환 - [null, 224, 224, 3]
291
- const inputShape = model.inputs[0].shape;
292
- console.log('inputShape', inputShape);
293
- // inputShape를 이와 같이 포멧 맞춘다. for reshape to [224, 224, 3]
294
- const inputShapeArray = inputShape.slice(1) as number[];
295
- console.log('inputShapeArray', inputShapeArray);
296
- const inputBatch = tf.stack(this.imageTensors.map((image) => {
297
- return tf.reshape(image, inputShapeArray);
298
- }));
299
- return inputBatch
300
- }
301
-
302
-
303
- private async loadModel(): Promise<tf.LayersModel> {
304
- const load_model = await tf.loadLayersModel(this.modelURL);
305
- load_model.summary();
306
- const layer = load_model.getLayer('conv_pw_13_relu');
307
- const truncatedModel = tf.model({
308
- inputs: load_model.inputs,
309
- outputs: layer.output
310
- })
311
- return truncatedModel
312
- }
313
-
314
302
  // 모델 저장
315
- private async _createModel(numClasses: number): Promise<tf.Sequential> {
303
+ private async _createModel(optimizer: tf.AdamOptimizer): Promise<tf.Sequential> {
316
304
  try {
317
- const truncatedModel = await this.loadModel()
318
- // 입력 이미지 크기에 맞게 모델 구조 수정
319
- const model = tf.sequential();
320
- // 추가적인 합성곱 층
321
- model.add(tf.layers.conv2d({
322
- filters: 64,
323
- kernelSize: 3,
324
- activation: 'relu',
325
- padding: 'same',
326
- inputShape: truncatedModel.outputs[0].shape.slice(1)
327
- }));
328
-
329
- model.add(tf.layers.flatten());
330
-
331
- // 깊은 밀집층
332
- model.add(tf.layers.dense({ units: 100, activation: 'relu' }));
333
- model.add(tf.layers.dense({ units: 100, activation: 'relu' }));
334
-
335
- // 드롭아웃 층
336
- model.add(tf.layers.dropout({ rate: 0.5 }));
337
-
338
- model.add(tf.layers.dense({ units: this.labels.length, activation: 'softmax' }));
339
-
340
-
341
- const optimizer = tf.train.adam(this.learningRate); // Optimizer를 생성하고 학습률을 설정합니다.
342
- model.compile({
343
- loss: (numClasses === 2) ? 'binaryCrossentropy': 'categoricalCrossentropy',
344
- optimizer: optimizer,
345
- metrics: ['accuracy', 'acc']
346
- });
347
- model.summary();
348
- return model;
305
+ // 입력 이미지 크기에 맞게 모델 구조 수정
306
+ let varianceScaling: Initializer;
307
+ varianceScaling = tf.initializers.varianceScaling({});
308
+
309
+ const trainModel = tf.sequential({
310
+ layers: [
311
+ tf.layers.dense({
312
+ inputShape: this.mobilenetModule!.outputs[0].shape.slice(1),
313
+ units: 64,
314
+ activation: 'relu',
315
+ kernelInitializer: varianceScaling, // 'varianceScaling'
316
+ useBias: true
317
+ }),
318
+ tf.layers.dense({
319
+ kernelInitializer: varianceScaling, // 'varianceScaling'
320
+ useBias: false,
321
+ activation: 'softmax',
322
+ units: this.classNumber.length,
323
+ })
324
+ ]
325
+ });
326
+
327
+ trainModel.compile({
328
+ loss: 'categoricalCrossentropy',
329
+ optimizer: optimizer,
330
+ metrics: ['accuracy']
331
+ });
332
+
333
+ return trainModel;
349
334
  } catch (error) {
350
- console.error('Failed to load model', error);
351
- throw error;
335
+ console.error('Failed to load model', error);
336
+ throw error;
352
337
  }
353
338
  }
354
339
  }
@@ -0,0 +1,49 @@
1
+ type Drawable = ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement;
2
+
3
+ const newCanvas = () => document.createElement('canvas');
4
+
5
+ export function cropTo(image: Drawable, size: number, flipped = false, canvas: HTMLCanvasElement = newCanvas()) {
6
+ let width: number;
7
+ let height: number;
8
+
9
+ // If ImageData
10
+ if (image instanceof ImageData) {
11
+ width = image.width;
12
+ height = image.height;
13
+ }
14
+ // If image, bitmap, or canvas
15
+ else if (image instanceof HTMLImageElement || image instanceof HTMLCanvasElement) {
16
+ width = image.width;
17
+ height = image.height;
18
+ }
19
+ // If video element
20
+ else if (image instanceof HTMLVideoElement) {
21
+ width = image.videoWidth;
22
+ height = image.videoHeight;
23
+ } else {
24
+ throw new Error("Unsupported Drawable type");
25
+ }
26
+
27
+ const min = Math.min(width, height);
28
+ const scale = size / min;
29
+ const scaledW = Math.ceil(width * scale);
30
+ const scaledH = Math.ceil(height * scale);
31
+ const dx = scaledW - size;
32
+ const dy = scaledH - size;
33
+ canvas.width = canvas.height = size;
34
+ const ctx = canvas.getContext('2d') as CanvasRenderingContext2D;
35
+
36
+ // Handle ImageData separately
37
+ if (image instanceof ImageData) {
38
+ ctx.putImageData(image, ~~(dx / 2) * -1, ~~(dy / 2) * -1); // Adjust this if needed
39
+ } else {
40
+ ctx.drawImage(image, ~~(dx / 2) * -1, ~~(dy / 2) * -1, scaledW, scaledH);
41
+ }
42
+
43
+ if (flipped) {
44
+ ctx.scale(-1, 1);
45
+ ctx.drawImage(canvas, size * -1, 0);
46
+ }
47
+
48
+ return canvas;
49
+ }
@@ -0,0 +1,24 @@
1
+ export interface Sample {
2
+ data: Float32Array;
3
+ label: number[];
4
+ }
5
+
6
+ export function flatOneHot(label: number, numClasses: number) {
7
+ const labelOneHot = new Array(numClasses).fill(0) as number[];
8
+ labelOneHot[label] = 1;
9
+
10
+ return labelOneHot;
11
+ }
12
+
13
+ export function fisherYates(array: Float32Array[] | Sample[]) {
14
+ const length = array.length;
15
+
16
+ // need to clone array or we'd be editing original as we goo
17
+ const shuffled = array.slice();
18
+ for (let i = (length - 1); i > 0; i -= 1) {
19
+ let randomIndex ;
20
+ randomIndex = Math.floor(Math.random() * (i + 1));
21
+ [shuffled[i], shuffled[randomIndex]] = [shuffled[randomIndex],shuffled[i]];
22
+ }
23
+ return shuffled;
24
+ }