learning_model 1.0.18 → 1.0.22

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. package/README.md +5 -1
  2. package/jest.config.js +6 -0
  3. package/lib/learning/mobilenet_image.test.ts +44 -0
  4. package/{src → lib}/learning/mobilenet_image.ts +3 -1
  5. package/lib/learning/util.ts +15 -0
  6. package/package.json +5 -4
  7. package/tsconfig.json +4 -6
  8. package/dist/index.bundle.js +0 -2
  9. package/dist/index.bundle.js.LICENSE.txt +0 -335
  10. package/dist/index.js +0 -10
  11. package/dist/types/index.d.ts +0 -3
  12. package/dist/types/learning/base.d.ts +0 -19
  13. package/dist/types/learning/base.js +0 -2
  14. package/dist/types/learning/image.d.ts +0 -40
  15. package/dist/types/learning/image.js +0 -259
  16. package/dist/types/learning/mobilenet_image.d.ts +0 -42
  17. package/dist/types/learning/mobilenet_image.js +0 -262
  18. package/dist/types/learning/mobilenet_image.test.d.ts +0 -1
  19. package/dist/types/public/index.d.ts +0 -1
  20. package/dist/types/src/index.d.ts +0 -3
  21. package/dist/types/src/learning/base.d.ts +0 -19
  22. package/dist/types/src/learning/image.d.ts +0 -40
  23. package/dist/types/src/learning/mobilenet_image.d.ts +0 -42
  24. package/public/index.css +0 -7
  25. package/public/index.html +0 -15
  26. package/public/index.ts +0 -153
  27. package/src/learning/mobilenet_image.test.ts +0 -63
  28. package/types/index.d.ts +0 -3
  29. package/types/learning/base.d.ts +0 -19
  30. package/types/learning/image.d.ts +0 -40
  31. package/types/learning/mobilenet_image.d.ts +0 -42
  32. package/types/learning/mobilenet_image.test.d.ts +0 -1
  33. package/types/public/index.d.ts +0 -1
  34. package/types/src/index.d.ts +0 -3
  35. package/types/src/learning/base.d.ts +0 -19
  36. package/types/src/learning/image.d.ts +0 -40
  37. package/types/src/learning/mobilenet_image.d.ts +0 -42
  38. /package/{src → lib}/index.ts +0 -0
  39. /package/{src → lib}/learning/base.ts +0 -0
  40. /package/{src → lib}/learning/image.ts +0 -0
@@ -1,42 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import LearningInterface from './base';
3
- declare class LearningMobilenetImage implements LearningInterface {
4
- model: tf.LayersModel | null;
5
- epochs: number;
6
- batchSize: number;
7
- learningRate: number;
8
- labels: string[];
9
- modelURL: string;
10
- isRunning: boolean;
11
- isReady: boolean;
12
- limitSize: number;
13
- trainImages: tf.Tensor3D[];
14
- readonly MOBILE_NET_INPUT_WIDTH = 224;
15
- readonly MOBILE_NET_INPUT_HEIGHT = 224;
16
- readonly MOBILE_NET_INPUT_CHANNEL = 3;
17
- readonly IMAGE_NORMALIZATION_FACTOR = 255;
18
- constructor({ modelURL, // 디폴트 mobilenet 이미지
19
- epochs, batchSize, limitSize, learningRate, }?: {
20
- modelURL?: string;
21
- epochs?: number;
22
- batchSize?: number;
23
- limitSize?: number;
24
- learningRate?: number;
25
- });
26
- onProgress: (progress: number) => void;
27
- onLoss: (loss: number) => void;
28
- onEvents: (logs: any) => void;
29
- onTrainBegin: (log: any) => void;
30
- onTrainEnd: (log: any) => void;
31
- addData(label: string, data: any): Promise<void>;
32
- train(): Promise<tf.History>;
33
- infer(data: any): Promise<Map<string, number>>;
34
- saveModel(): void;
35
- running(): boolean;
36
- ready(): boolean;
37
- private _preprocessedTargetData;
38
- private _preprocessedInputData;
39
- private _preprocessData;
40
- private _createModel;
41
- }
42
- export default LearningMobilenetImage;
@@ -1,262 +0,0 @@
1
- "use strict";
2
- ///////////////////////////////////////////////////////////////////////////
3
- ///////////////////////////////////////////////////////////////////////////
4
- ///////////////////////////////////////////////////////////////////////////
5
- // mobilenet 모델을 이용한 전이학습 방법
6
- ///////////////////////////////////////////////////////////////////////////
7
- var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
8
- if (k2 === undefined) k2 = k;
9
- var desc = Object.getOwnPropertyDescriptor(m, k);
10
- if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
11
- desc = { enumerable: true, get: function() { return m[k]; } };
12
- }
13
- Object.defineProperty(o, k2, desc);
14
- }) : (function(o, m, k, k2) {
15
- if (k2 === undefined) k2 = k;
16
- o[k2] = m[k];
17
- }));
18
- var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
19
- Object.defineProperty(o, "default", { enumerable: true, value: v });
20
- }) : function(o, v) {
21
- o["default"] = v;
22
- });
23
- var __importStar = (this && this.__importStar) || function (mod) {
24
- if (mod && mod.__esModule) return mod;
25
- var result = {};
26
- if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
27
- __setModuleDefault(result, mod);
28
- return result;
29
- };
30
- var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
31
- function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
32
- return new (P || (P = Promise))(function (resolve, reject) {
33
- function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
34
- function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
35
- function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
36
- step((generator = generator.apply(thisArg, _arguments || [])).next());
37
- });
38
- };
39
- Object.defineProperty(exports, "__esModule", { value: true });
40
- const tf = __importStar(require("@tensorflow/tfjs"));
41
- class LearningMobilenetImage {
42
- constructor({ modelURL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json', // 디폴트 mobilenet 이미지
43
- epochs = 10, batchSize = 16, limitSize = 2, learningRate = 0.001, } = {}) {
44
- this.trainImages = [];
45
- this.MOBILE_NET_INPUT_WIDTH = 224;
46
- this.MOBILE_NET_INPUT_HEIGHT = 224;
47
- this.MOBILE_NET_INPUT_CHANNEL = 3;
48
- this.IMAGE_NORMALIZATION_FACTOR = 255.0;
49
- // 진행 상태를 나타내는 이벤트를 정의합니다.
50
- this.onProgress = () => { };
51
- this.onLoss = () => { };
52
- this.onEvents = () => { };
53
- this.onTrainBegin = () => { };
54
- this.onTrainEnd = () => { };
55
- this.model = null;
56
- this.epochs = epochs;
57
- this.batchSize = batchSize;
58
- this.learningRate = learningRate;
59
- this.labels = [];
60
- this.modelURL = modelURL;
61
- this.isRunning = false;
62
- this.isReady = false;
63
- this.limitSize = limitSize;
64
- }
65
- // 학습 데이타 등록
66
- addData(label, data) {
67
- return __awaiter(this, void 0, void 0, function* () {
68
- try {
69
- const tensor = tf.browser.fromPixels(data);
70
- console.log('addData', tensor);
71
- this.trainImages.push(tensor);
72
- this.labels.push(label);
73
- if (this.labels.length >= this.limitSize) {
74
- this.isReady = true;
75
- }
76
- return Promise.resolve();
77
- }
78
- catch (error) {
79
- console.error('Model training failed', error);
80
- throw error;
81
- }
82
- });
83
- }
84
- // 모델 학습 처리
85
- train() {
86
- return __awaiter(this, void 0, void 0, function* () {
87
- if (this.isRunning) {
88
- return Promise.reject(new Error('Training is already in progress.'));
89
- }
90
- // 콜백 정의
91
- const customCallback = {
92
- onTrainBegin: (log) => {
93
- this.onTrainBegin(log);
94
- console.log('Training has started.');
95
- },
96
- onTrainEnd: (log) => {
97
- this.onTrainEnd(log);
98
- console.log('Training has ended.');
99
- this.isRunning = false;
100
- },
101
- onBatchBegin: (batch, logs) => {
102
- console.log(`Batch ${batch} is starting.`);
103
- },
104
- onBatchEnd: (batch, logs) => {
105
- console.log(`Batch ${batch} has ended.`);
106
- },
107
- onEpochBegin: (epoch, logs) => {
108
- console.log(`Epoch ${epoch + 1} is starting.`, logs);
109
- },
110
- onEpochEnd: (epoch, logs) => {
111
- console.log(`Epoch ${epoch + 1} has ended.`);
112
- console.log('Loss:', logs);
113
- this.onLoss(logs.loss);
114
- this.onProgress(epoch + 1);
115
- this.onEvents(logs);
116
- }
117
- };
118
- try {
119
- this.isRunning = true;
120
- if (this.labels.length < this.limitSize) {
121
- return Promise.reject(new Error('Please train Data need over 2 data length'));
122
- }
123
- this.model = yield this._createModel(this.labels.length);
124
- const inputData = this._preprocessedInputData(this.model);
125
- const targetData = this._preprocessedTargetData();
126
- const history = yield this.model.fit(inputData, targetData, {
127
- epochs: this.epochs,
128
- batchSize: this.batchSize,
129
- callbacks: customCallback
130
- });
131
- console.log('Model training completed', history);
132
- return history;
133
- }
134
- catch (error) {
135
- this.isRunning = false;
136
- console.error('Model training failed', error);
137
- throw error;
138
- }
139
- });
140
- }
141
- // 추론하기
142
- infer(data) {
143
- return __awaiter(this, void 0, void 0, function* () {
144
- if (this.model === null) {
145
- return Promise.reject(new Error('Model is Null'));
146
- }
147
- try {
148
- const tensor = tf.browser.fromPixels(data);
149
- const resizedTensor = tf.image.resizeBilinear(tensor, [this.MOBILE_NET_INPUT_WIDTH, this.MOBILE_NET_INPUT_HEIGHT]);
150
- const reshapedTensor = resizedTensor.expandDims(0); // 배치 크기 1을 추가하여 4차원으로 변환
151
- const predictions = this.model.predict(reshapedTensor);
152
- const predictionsData = yield predictions.data(); // 예측 텐서의 데이터를 비동기로 가져옴
153
- const classProbabilities = new Map(); // 클래스별 확률 누적값을 저장할 맵
154
- for (let i = 0; i < predictionsData.length; i++) {
155
- const className = this.labels[i]; // 클래스 이름
156
- const probability = predictionsData[i];
157
- const existingProbability = classProbabilities.get(className);
158
- if (existingProbability !== undefined) {
159
- classProbabilities.set(className, existingProbability + probability);
160
- }
161
- else {
162
- classProbabilities.set(className, probability);
163
- }
164
- }
165
- console.log('Class Probabilities:', classProbabilities);
166
- return classProbabilities;
167
- }
168
- catch (error) {
169
- throw error;
170
- }
171
- });
172
- }
173
- // 모델 저장
174
- saveModel() {
175
- console.log('saved model');
176
- }
177
- // 진행중 여부
178
- running() {
179
- return this.isRunning;
180
- }
181
- ready() {
182
- return this.isReady;
183
- }
184
- // target 라벨 데이타
185
- _preprocessedTargetData() {
186
- // 라벨 unique 처리 & 배열 리턴
187
- console.log('uniqueLabels.length', this.labels, this.labels.length);
188
- const labelIndices = this.labels.map((label) => this.labels.indexOf(label));
189
- console.log('labelIndices', labelIndices);
190
- const oneHotEncode = tf.oneHot(tf.tensor1d(labelIndices, 'int32'), this.labels.length);
191
- console.log('oneHotEncode', oneHotEncode);
192
- return oneHotEncode;
193
- }
194
- // 입력 이미지 데이타
195
- _preprocessedInputData(model) {
196
- // 이미지 배열을 배치로 변환 - [null, 224, 224, 3]
197
- const inputShape = model.inputs[0].shape;
198
- console.log('inputShape', inputShape);
199
- // inputShape를 이와 같이 포멧 맞춘다. for reshape to [224, 224, 3]
200
- const inputShapeArray = inputShape.slice(1);
201
- console.log('inputShapeArray', inputShapeArray);
202
- const inputBatch = tf.stack(this.trainImages.map((image) => {
203
- // 이미지 전처리 및 크기 조정 등을 수행한 후에
204
- // 모델의 입력 형태로 변환하여 반환
205
- const xs = this._preprocessData(image); // 전처리 함수는 사용자 정의해야 함
206
- return tf.reshape(xs, inputShapeArray);
207
- }));
208
- return inputBatch;
209
- }
210
- // 모델 학습하기 위한 데이타 전처리 단계
211
- _preprocessData(tensor) {
212
- try {
213
- // mobilenet model summary를 하면 위와 같이 224,224 사이즈의 입력값 설정되어 있다. ex) input_1 (InputLayer) [null,224,224,3]
214
- const resizedImage = tf.image.resizeBilinear(tensor, [this.MOBILE_NET_INPUT_WIDTH, this.MOBILE_NET_INPUT_HEIGHT]);
215
- // 이미지를 [0,1] 범위로 정규화 255로 나뉜 픽셀값
216
- const normalizedImage = resizedImage.div(this.IMAGE_NORMALIZATION_FACTOR);
217
- // expandDims(0)을 하여 차원을 추가하여 4D텐서 반환
218
- return normalizedImage.expandDims(0);
219
- }
220
- catch (error) {
221
- console.error('Failed to _preprocessData data', error);
222
- throw error;
223
- }
224
- }
225
- // 모델 저장
226
- _createModel(numClasses) {
227
- return __awaiter(this, void 0, void 0, function* () {
228
- try {
229
- const load_model = yield tf.loadLayersModel(this.modelURL);
230
- // 기존 MobileNet 모델에서 마지막 레이어 제외
231
- const truncatedModel = tf.model({
232
- inputs: load_model.inputs,
233
- outputs: load_model.layers[load_model.layers.length - 2].output
234
- });
235
- // 모델을 학습 가능하게 설정하고 선택한 레이어까지 고정
236
- for (let layer of truncatedModel.layers) {
237
- layer.trainable = false;
238
- }
239
- const model = tf.sequential();
240
- model.add(truncatedModel);
241
- model.add(tf.layers.flatten()); // 필요한 경우 Flatten 레이어 추가
242
- model.add(tf.layers.dense({
243
- units: numClasses,
244
- activation: 'softmax'
245
- }));
246
- const optimizer = tf.train.adam(this.learningRate); // Optimizer를 생성하고 학습률을 설정합니다.
247
- model.compile({
248
- loss: (numClasses === 2) ? 'binaryCrossentropy' : 'categoricalCrossentropy',
249
- optimizer: optimizer,
250
- metrics: ['accuracy', 'acc']
251
- });
252
- model.summary();
253
- return model;
254
- }
255
- catch (error) {
256
- console.error('Failed to load model', error);
257
- throw error;
258
- }
259
- });
260
- }
261
- }
262
- exports.default = LearningMobilenetImage;
@@ -1 +0,0 @@
1
- export {};
@@ -1 +0,0 @@
1
- export {};
@@ -1,3 +0,0 @@
1
- import LearningImage from './learning/image';
2
- import LearningMobilenetImage from './learning/mobilenet_image';
3
- export { LearningImage, LearningMobilenetImage };
@@ -1,19 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- interface LearningInterface {
3
- model: tf.LayersModel | null;
4
- labels: string[];
5
- isRunning: boolean;
6
- isReady: boolean;
7
- onProgress(progress: number): void;
8
- onLoss(loss: number): void;
9
- onEvents(logs: any): void;
10
- onTrainBegin(log: any): void;
11
- onTrainEnd(log: any): void;
12
- addData(label: string, data: any): Promise<void>;
13
- train(): Promise<tf.History>;
14
- infer(data: any): Promise<any>;
15
- saveModel(): void;
16
- running(): boolean;
17
- ready(): boolean;
18
- }
19
- export default LearningInterface;
@@ -1,40 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import LearningInterface from './base';
3
- declare class LearningImage implements LearningInterface {
4
- model: tf.LayersModel | null;
5
- epochs: number;
6
- batchSize: number;
7
- learningRate: number;
8
- labels: string[];
9
- isRunning: boolean;
10
- isReady: boolean;
11
- limitSize: number;
12
- trainImages: tf.Tensor3D[];
13
- readonly MOBILE_NET_INPUT_WIDTH = 224;
14
- readonly MOBILE_NET_INPUT_HEIGHT = 224;
15
- readonly MOBILE_NET_INPUT_CHANNEL = 3;
16
- readonly IMAGE_NORMALIZATION_FACTOR = 255;
17
- constructor({ epochs, batchSize, limitSize, learningRate, }?: {
18
- modelURL?: string;
19
- epochs?: number;
20
- batchSize?: number;
21
- limitSize?: number;
22
- learningRate?: number;
23
- });
24
- onProgress: (progress: number) => void;
25
- onLoss: (loss: number) => void;
26
- onEvents: (logs: any) => void;
27
- onTrainBegin: (log: any) => void;
28
- onTrainEnd: (log: any) => void;
29
- addData(label: string, data: any): Promise<void>;
30
- train(): Promise<tf.History>;
31
- infer(data: any): Promise<Map<string, number>>;
32
- saveModel(): void;
33
- running(): boolean;
34
- ready(): boolean;
35
- private _preprocessedTargetData;
36
- private _preprocessedInputData;
37
- private _preprocessData;
38
- private _createModel;
39
- }
40
- export default LearningImage;
@@ -1,42 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import LearningInterface from './base';
3
- declare class LearningMobilenetImage implements LearningInterface {
4
- model: tf.LayersModel | null;
5
- epochs: number;
6
- batchSize: number;
7
- learningRate: number;
8
- labels: string[];
9
- modelURL: string;
10
- isRunning: boolean;
11
- isReady: boolean;
12
- limitSize: number;
13
- trainImages: tf.Tensor3D[];
14
- readonly MOBILE_NET_INPUT_WIDTH = 224;
15
- readonly MOBILE_NET_INPUT_HEIGHT = 224;
16
- readonly MOBILE_NET_INPUT_CHANNEL = 3;
17
- readonly IMAGE_NORMALIZATION_FACTOR = 255;
18
- constructor({ modelURL, // 디폴트 mobilenet 이미지
19
- epochs, batchSize, limitSize, learningRate, }?: {
20
- modelURL?: string;
21
- epochs?: number;
22
- batchSize?: number;
23
- limitSize?: number;
24
- learningRate?: number;
25
- });
26
- onProgress: (progress: number) => void;
27
- onLoss: (loss: number) => void;
28
- onEvents: (logs: any) => void;
29
- onTrainBegin: (log: any) => void;
30
- onTrainEnd: (log: any) => void;
31
- addData(label: string, data: any): Promise<void>;
32
- train(): Promise<tf.History>;
33
- infer(data: any): Promise<Map<string, number>>;
34
- saveModel(): void;
35
- running(): boolean;
36
- ready(): boolean;
37
- private _preprocessedTargetData;
38
- private _preprocessedInputData;
39
- private _preprocessData;
40
- private _createModel;
41
- }
42
- export default LearningMobilenetImage;
package/public/index.css DELETED
@@ -1,7 +0,0 @@
1
- #container {
2
- position: relative;
3
- }
4
-
5
- #canvas {
6
- margin-bottom: 20px; /* 간격을 조정할 값 */
7
- }
package/public/index.html DELETED
@@ -1,15 +0,0 @@
1
- <!DOCTYPE html>
2
- <html>
3
- <head>
4
- <link rel="stylesheet" href="index.css">
5
- <title>My CNN Module</title>
6
- </head>
7
- <body>
8
- <div id="loss-bar"></div>
9
- <div id="progress-bar"></div>
10
- <div id="root"></div>
11
- <div id="container">
12
- <canvas id="canvas"></canvas>
13
- </div>
14
- </body>
15
- </html>
package/public/index.ts DELETED
@@ -1,153 +0,0 @@
1
-
2
- import LearningMobilenetImage from '../src/learning/mobilenet_image';
3
-
4
- // 프로그레스 바를 표시하는 클래스
5
- class StatusBar {
6
- constructor(private container: HTMLElement) {}
7
- update(status: number, message: string) {
8
- this.container.innerText = `Status: ${status} ${message}`;
9
- }
10
- }
11
-
12
-
13
- async function appRun() {
14
- //////////////////////////////////////////////////////////////////////////////////////////
15
- const learningImage = new LearningMobilenetImage({
16
- epochs: 10,
17
- batchSize: 4
18
- });
19
- learningImage.onProgress = (progress: number) => {
20
- const element = document.getElementById('progress-bar');
21
- if (element !== null) {
22
- const bar = new StatusBar(element);
23
- bar.update(progress, '%');
24
- }
25
- }
26
- learningImage.onLoss = (loss: number) => {
27
- const element = document.getElementById('loss-bar');
28
- if (element !== null) {
29
- const bar = new StatusBar(element);
30
- bar.update(loss, 'loss');
31
- }
32
- }
33
- //////////////////////////////////////////////////////////////////////////////////////////
34
- //////////////////////////////////////////////////////////////////////////////////////////
35
- //////////////////////////////////////////////////////////////////////////////////////////
36
-
37
- // root UI
38
- const container = document.createElement('div');
39
- container.id = 'root';
40
-
41
- // learning 준비 체크
42
- const learningReadyCheck = () => {
43
- if (!learningImage.ready()) {
44
- window.alert('준비된 데이터가 없습니다.');
45
- return false;
46
- }
47
- if (learningImage.running()) {
48
- window.alert('이미 진행 중입니다.');
49
- return false;
50
- }
51
- return true;
52
- };
53
-
54
- // 이미지 저장 버튼 클릭
55
- async function handleImageButtonClick(label: string) {
56
- // Canvas 생성
57
- const canvas = document.createElement('canvas');
58
- const video = document.createElement('video');
59
-
60
- // 웹캠 활성화
61
- if (navigator.mediaDevices.getUserMedia) {
62
- const stream = await navigator.mediaDevices.getUserMedia({ video: true });
63
- video.srcObject = stream;
64
- }
65
-
66
- // 비디오가 메타데이터 로딩되면 캔버스에 그리고 이미지 캡처
67
- video.addEventListener('loadedmetadata', () => {
68
- video.play(); // 비디오 플레이 시작
69
- });
70
-
71
- // 비디오가 메타데이터 로딩되면 캔버스에 그리고 이미지 캡처
72
- video.addEventListener('play', () => {
73
- canvas.width = 128;
74
- canvas.height = 128;
75
- const context = canvas.getContext('2d');
76
- if (context) {
77
- context.drawImage(video, 0, 0, canvas.width, canvas.height);
78
- learningImage.addData(label, context.getImageData(0, 0, canvas.width, canvas.height));
79
- }
80
- });
81
- container.appendChild(canvas);
82
- }
83
-
84
- // 추론하기 버튼
85
- async function handleInferButtonClick() {
86
- if(!learningReadyCheck()) {
87
- return;
88
- }
89
-
90
- // Canvas 생성
91
- const canvas = document.createElement('canvas');
92
- const video = document.createElement('video');
93
-
94
- // 웹캠 활성화
95
- if (navigator.mediaDevices.getUserMedia) {
96
- const stream = await navigator.mediaDevices.getUserMedia({ video: true });
97
- video.srcObject = stream;
98
- }
99
-
100
- // 비디오가 메타데이터 로딩되면 캔버스에 그리고 이미지 캡처
101
- video.addEventListener('loadedmetadata', () => {
102
- video.play(); // 비디오 플레이 시작
103
- });
104
-
105
- // 비디오가 메타데이터 로딩되면 캔버스에 그리고 이미지 캡처
106
- video.addEventListener('play', () => {
107
- canvas.width = 128;
108
- canvas.height = 128;
109
- const context = canvas.getContext('2d');
110
- if (context) {
111
- context.drawImage(video, 0, 0, canvas.width, canvas.height);
112
- learningImage.infer(context.getImageData(0, 0, canvas.width, canvas.height));
113
- }
114
- });
115
- container.appendChild(canvas);
116
- }
117
-
118
- // Train 학습하기 버튼
119
- async function handleTrainButtonClick() {
120
- if(learningReadyCheck()) {
121
- await learningImage.train();
122
- }
123
- }
124
-
125
- // image button UI
126
- const image1Button = document.createElement('button');
127
- image1Button.textContent = '라벨1 이미지';
128
- image1Button.addEventListener('click', () => handleImageButtonClick('라벨1 이미지'));
129
- container.appendChild(image1Button);
130
-
131
- // image button UI
132
- const image2Button = document.createElement('button');
133
- image2Button.textContent = '라벨2 이미지';
134
- image2Button.addEventListener('click', () => handleImageButtonClick('라벨2 이미지'));
135
- container.appendChild(image2Button);
136
-
137
- // save button UI
138
- const trainButton = document.createElement('button');
139
- trainButton.textContent = '모델 Train';
140
- trainButton.addEventListener('click', handleTrainButtonClick);
141
- container.appendChild(trainButton);
142
- document.body.appendChild(container);
143
-
144
- // image button UI
145
- const inferButton = document.createElement('button');
146
- inferButton.textContent = '예측하기';
147
- inferButton.addEventListener('click', () => handleInferButtonClick());
148
- container.appendChild(inferButton);
149
-
150
-
151
- }
152
-
153
- appRun();
@@ -1,63 +0,0 @@
1
- import * as fs from 'fs';
2
- import * as path from 'path';
3
- import * as tf from '@tensorflow/tfjs';
4
- import LearningMobilenetImage from './mobilenet_image';
5
-
6
- describe('LearningMobilenetImage', () => {
7
- let learning: LearningMobilenetImage;
8
- const imagePath = path.join(__dirname, 'data/images/image.jpg'); // 이미지 파일 경로
9
-
10
- beforeEach(() => {
11
- learning = new LearningMobilenetImage({});
12
- });
13
-
14
- afterEach(() => {
15
- // 테스트 후에 모델 저장
16
- learning.saveModel();
17
- });
18
-
19
- it('should add data', () => {
20
- const label = 'label1';
21
- const image = fs.readFileSync(imagePath); // 이미지 파일 읽어오기
22
-
23
- learning.addData(label, image);
24
-
25
- expect(learning.labels).toContain(label);
26
- expect(learning.trainImages.length).toBe(1);
27
- });
28
-
29
- it('should not train if data is insufficient', async () => {
30
- spyOn(console, 'error'); // 콘솔 에러 로그를 가로채기 위해 spyOn 사용
31
-
32
- const result = await learning.train().catch((error: any) => error);
33
-
34
- expect(result).toBeInstanceOf(Error);
35
- expect(result.message).toBe('Please train Data need over 2 data length');
36
- expect(console.error).toHaveBeenCalledWith('Model training failed', result);
37
- expect(learning.model).toBeNull();
38
- });
39
-
40
- it('should train and infer', async () => {
41
- const label1 = 'label1';
42
- const image1 = fs.readFileSync(imagePath); // 이미지 파일 읽어오기
43
- const label2 = 'label2';
44
- const image2 = fs.readFileSync(imagePath); // 이미지 파일 읽어오기
45
- const testData = fs.readFileSync(imagePath); // 테스트 데이터 읽어오기
46
-
47
- learning.addData(label1, image1);
48
- learning.addData(label2, image2);
49
-
50
- const history = await learning.train();
51
-
52
- expect(history).toBeDefined();
53
- expect(learning.model).toBeInstanceOf(tf.LayersModel);
54
- expect(learning.isRunning).toBe(false);
55
-
56
- const predictions = await learning.infer(testData);
57
-
58
- expect(predictions).toBeInstanceOf(Map);
59
- expect(predictions.size).toBe(2);
60
- // expect(predictions.get(label1)).toBeCloseTo(/* 예측값 */);
61
- // expect(predictions.get(label2)).toBeCloseTo(/* 예측값 */);
62
- });
63
- });
package/types/index.d.ts DELETED
@@ -1,3 +0,0 @@
1
- import LearningImage from './learning/image';
2
- import LearningMobilenetImage from './learning/mobilenet_image';
3
- export { LearningImage, LearningMobilenetImage };