learning_model 1.0.38 → 1.0.40

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,94 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+
3
/**
 * Structural (duck-typed) check for a tfjs tensor.
 *
 * Looks for the `dataId` and `shape` members instead of using
 * `instanceof tf.Tensor`. NOTE(review): presumably this is so tensors that fail
 * `instanceof` (e.g. created by another tfjs copy) are still recognised — confirm.
 *
 * @param c - Any value; `null`, `undefined` and primitives are safely rejected.
 * @returns `true` when `c` is a non-null object whose `dataId` and `shape`
 *   members are both non-null objects.
 */
export function isTensor(c: unknown): c is tf.Tensor {
  // The original read `c.dataId` first and therefore threw on null/undefined.
  if (typeof c !== 'object' || c === null) {
    return false;
  }
  const { dataId, shape } = c as { dataId?: unknown; shape?: unknown };
  // `typeof null === 'object'`, so null members must be excluded explicitly.
  return (
    typeof dataId === 'object' && dataId !== null &&
    typeof shape === 'object' && shape !== null
  );
}
6
+
7
/**
 * Loads a pretrained MobileNet and turns it into a feature extractor.
 *
 * The downloaded model is truncated at an intermediate activation layer. It is
 * then wrapped in a `tf.sequential` model whose last layer turns that
 * activation map into a flat feature vector:
 *  - version 1: truncated at `conv_pw_13_relu`, followed by `flatten`;
 *  - any other version (default 2): truncated at `out_relu`, followed by
 *    global average pooling (from shape [7, 7, 1280] to [1280]).
 *
 * Previously the version was hard-coded to 2, which left the v1 branch
 * unreachable. It is now a parameter whose default keeps the old behaviour.
 *
 * @param version - MobileNet version: `1` selects v1, anything else selects v2.
 * @returns The truncated feature-extraction model.
 */
export async function loadModel(version = 2): Promise<tf.LayersModel> {
  const useV1 = version === 1;
  // Derive URL and truncation layer from the same flag so they can never disagree.
  const truncationLayer = useV1 ? 'conv_pw_13_relu' : 'out_relu';
  const baseModel = await tf.loadLayersModel(mobileNetURL(useV1 ? 1 : 2));

  const truncatedModel = tf.model({
    inputs: baseModel.inputs,
    outputs: baseModel.getLayer(truncationLayer).output,
  });

  const model = tf.sequential();
  model.add(truncatedModel);
  // v1: flatten the activation map; v2: average over the spatial axes.
  model.add(useV1 ? tf.layers.flatten() : tf.layers.globalAveragePooling2d({}));
  return model;
}
37
+
38
/**
 * Resolves the `model.json` URL of the pretrained MobileNet weights.
 *
 * @param version - `1` selects MobileNet v1 (1.0, 224). Every other value falls
 *   back to the `no_top` MobileNet v2 (1.0, 224) model.
 * @returns The URL to pass to `tf.loadLayersModel`.
 */
export function mobileNetURL(version: number): string {
  const v1Url = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/model.json";
  const v2Url = "https://storage.googleapis.com/teachable-machine-models/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top/model.json";
  // Loose equality kept on purpose: untyped JS callers passing "1" still get v1.
  return version == 1 ? v1Url : v2Url;
}
44
+
45
/**
 * Normalises an input into a tfjs tensor.
 *
 * Tensors are returned as-is (no copy). Anything else is handed to
 * `tf.browser.fromPixels`, so it presumably has to be a browser pixel source
 * (image/video/canvas element or similar) — confirm against callers.
 *
 * NOTE(review): a tensor argument is not checked to be rank 3, despite the
 * `Tensor3D` return type.
 *
 * @param data - A tensor or a pixel source accepted by `tf.browser.fromPixels`.
 * @returns The input tensor, or a new tensor read from the pixel source.
 */
export function imageToTensor(data: any): tf.Tensor3D {
  // Use the file's duck-typed `isTensor` rather than `instanceof tf.Tensor`.
  // Otherwise a tensor that fails `instanceof` would be passed to fromPixels.
  if (isTensor(data)) {
    return data as tf.Tensor3D;
  }
  // Not a tensor: read its pixels (this does not load any model).
  return tf.browser.fromPixels(data);
}
56
+
57
/**
 * Captures a frame from an image, video or canvas element and prepares it as
 * a single-image batch.
 *
 * Steps:
 *  1. Read the pixels.
 *  2. Crop the centred square (converted to one grayscale channel when
 *     `grayscale` is true).
 *  3. Add a leading batch dimension of 1.
 *  4. Rescale the [0, 255] pixel values.
 *
 * Every intermediate tensor is released by `tf.tidy`; only the returned tensor
 * survives.
 *
 * @param rasterElement - Element to read pixels from.
 * @param grayscale - Forwarded to `cropTensor` as its `grayscaleModel` flag.
 * @returns A float32 tensor with a leading batch dimension of size 1.
 */
export function capture(rasterElement: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement, grayscale?: boolean) {
  return tf.tidy(() => {
    const pixels = tf.browser.fromPixels(rasterElement);

    // Keep only the centred square region of the frame.
    const cropped = cropTensor(pixels, grayscale);

    // Expand the outermost dimension so we have a batch size of 1.
    const batched = cropped.expandDims(0);

    // Pixels arrive in [0, 255], and x / 127 - 1 maps them to [-1, ~1.008].
    // (127.5 would give exactly [-1, 1].) The constant 127 is kept so inputs
    // stay identical to those produced by earlier versions of this function.
    // tf.cast is the documented op; the chained toFloat() is a legacy alias.
    return tf.cast(batched, 'float32').div(tf.scalar(127)).sub(tf.scalar(1));
  });
}
72
+
73
/**
 * Crops the largest centred square out of an image tensor and can convert
 * RGB pixels to a single grayscale channel.
 *
 * All intermediate tensors are created inside `tf.tidy`, so they no longer
 * leak when this is called outside another `tidy` scope.
 *
 * @param img - Image tensor. shape[0] is treated as height, shape[1] as width
 *   and shape[2] as the channel count.
 * @param grayscaleModel - The consuming model expects one grayscale channel.
 * @param grayscaleInput - `img` already holds a single grayscale channel, so
 *   no RGB-to-gray conversion is applied.
 * @returns A `[size, size, c]` tensor with `size = min(height, width)`:
 *   - `c` is 1 after an RGB-to-gray conversion;
 *   - otherwise `c` is `min(3, channels)`, so RGB/RGBA input yields 3 channels
 *     and single-channel input stays single-channel.
 */
export function cropTensor(img: tf.Tensor3D, grayscaleModel?: boolean, grayscaleInput?: boolean): tf.Tensor3D {
  const [height, width, channels] = img.shape;
  const size = Math.min(height, width);
  // Floor so slice offsets are integers. The former `dim / 2 - size / 2`
  // produced x.5 offsets when height and width differed by an odd amount.
  const beginHeight = Math.floor((height - size) / 2);
  const beginWidth = Math.floor((width - size) / 2);

  return tf.tidy<tf.Tensor3D>(() => {
    if (grayscaleModel && !grayscaleInput) {
      const rgb = img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
      // Gray = 0.2989 R + 0.5870 G + 0.1140 B. The [3] weights broadcast over
      // the last axis. keepDims=true leaves a trailing channel axis of size 1.
      const rgbWeights = [0.2989, 0.5870, 0.1140];
      return tf.sum<tf.Tensor3D>(tf.mul(rgb, rgbWeights), -1, true);
    }
    // Slicing a fixed 3 channels threw for single-channel (grayscale) input.
    return img.slice([beginHeight, beginWidth, 0], [size, size, Math.min(3, channels)]);
  });
}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "learning_model",
3
- "version": "1.0.38",
3
+ "version": "1.0.40",
4
4
  "description": "learning model develop",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -1,2 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- export declare function ImageToTensor(data: any): tf.Tensor3D;
@@ -1,40 +0,0 @@
1
- "use strict";
2
- var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
3
- if (k2 === undefined) k2 = k;
4
- var desc = Object.getOwnPropertyDescriptor(m, k);
5
- if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
6
- desc = { enumerable: true, get: function() { return m[k]; } };
7
- }
8
- Object.defineProperty(o, k2, desc);
9
- }) : (function(o, m, k, k2) {
10
- if (k2 === undefined) k2 = k;
11
- o[k2] = m[k];
12
- }));
13
- var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
14
- Object.defineProperty(o, "default", { enumerable: true, value: v });
15
- }) : function(o, v) {
16
- o["default"] = v;
17
- });
18
- var __importStar = (this && this.__importStar) || function (mod) {
19
- if (mod && mod.__esModule) return mod;
20
- var result = {};
21
- if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
22
- __setModuleDefault(result, mod);
23
- return result;
24
- };
25
- Object.defineProperty(exports, "__esModule", { value: true });
26
- exports.ImageToTensor = void 0;
27
- const tf = __importStar(require("@tensorflow/tfjs"));
28
- function ImageToTensor(data) {
29
- let tensor;
30
- console.log('data', data instanceof tf.Tensor, typeof data);
31
- if (data instanceof tf.Tensor) {
32
- tensor = data;
33
- }
34
- else {
35
- // MobileNet 모델 로드
36
- tensor = tf.browser.fromPixels(data);
37
- }
38
- return tensor;
39
- }
40
- exports.ImageToTensor = ImageToTensor;
@@ -1,2 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- export declare function ImageToTensor(data: any): tf.Tensor3D;
@@ -1,16 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
-
3
- export function ImageToTensor(data: any): tf.Tensor3D {
4
- let tensor: tf.Tensor3D;
5
-
6
- console.log('data', data instanceof tf.Tensor, typeof data);
7
- if (data instanceof tf.Tensor) {
8
- tensor = data;
9
- } else {
10
- // Non-tensor input: read its pixels with tf.browser.fromPixels (the original Korean comment said "load MobileNet model", which this line does not do)
11
- tensor = tf.browser.fromPixels(data);
12
- }
13
-
14
- return tensor;
15
- }
16
-