learning_model 1.0.51 → 1.0.52
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/learning/mobilenet.ts +3 -3
- package/package.json +2 -3
- package/dist/index.bundle.js +0 -2
- package/dist/index.bundle.js.LICENSE.txt +0 -352
- package/dist/index.d.ts +0 -3
- package/dist/index.html +0 -1
- package/dist/index.js +0 -10
- package/dist/learning/base.d.ts +0 -23
- package/dist/learning/base.js +0 -2
- package/dist/learning/data_model.d.ts +0 -41
- package/dist/learning/data_model.js +0 -205
- package/dist/learning/data_model.test.d.ts +0 -1
- package/dist/learning/data_model.test.js +0 -56
- package/dist/learning/mobilenet.d.ts +0 -52
- package/dist/learning/mobilenet.js +0 -376
- package/dist/learning/mobilenet.test.d.ts +0 -1
- package/dist/learning/mobilenet.test.js +0 -79
- package/dist/lib/index.d.ts +0 -3
- package/dist/lib/learning/base.d.ts +0 -23
- package/dist/lib/learning/data_model.d.ts +0 -41
- package/dist/lib/learning/data_model.test.d.ts +0 -1
- package/dist/lib/learning/mobilenet.d.ts +0 -52
- package/dist/lib/learning/mobilenet.test.d.ts +0 -1
- package/dist/lib/utils/canvas.d.ts +0 -3
- package/dist/lib/utils/data_manager.d.ts +0 -15
- package/dist/lib/utils/dataset.d.ts +0 -6
- package/dist/lib/utils/tf.d.ts +0 -7
- package/dist/utils/canvas.d.ts +0 -3
- package/dist/utils/canvas.js +0 -47
- package/dist/utils/data_manager.d.ts +0 -15
- package/dist/utils/data_manager.js +0 -62
- package/dist/utils/dataset.d.ts +0 -6
- package/dist/utils/dataset.js +0 -21
- package/dist/utils/tf.d.ts +0 -7
- package/dist/utils/tf.js +0 -126
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import '@tensorflow/tfjs-backend-wasm';
|
|
3
|
-
import { io } from '@tensorflow/tfjs-core';
|
|
4
|
-
import LearningInterface from './base';
|
|
5
|
-
declare class LearningMobilenet implements LearningInterface {
|
|
6
|
-
model: tf.LayersModel | null;
|
|
7
|
-
epochs: number;
|
|
8
|
-
batchSize: number;
|
|
9
|
-
learningRate: number;
|
|
10
|
-
validateRate: number;
|
|
11
|
-
isRunning: boolean;
|
|
12
|
-
isReady: boolean;
|
|
13
|
-
isTrainedDone: boolean;
|
|
14
|
-
limitSize: number;
|
|
15
|
-
mobilenetModule: tf.LayersModel | null;
|
|
16
|
-
imageExamples: Float32Array[][];
|
|
17
|
-
classNumber: string[];
|
|
18
|
-
readonly MOBILE_NET_INPUT_WIDTH = 224;
|
|
19
|
-
readonly MOBILE_NET_INPUT_HEIGHT = 224;
|
|
20
|
-
readonly MOBILE_NET_INPUT_CHANNEL = 3;
|
|
21
|
-
readonly IMAGE_NORMALIZATION_FACTOR = 255;
|
|
22
|
-
constructor({ epochs, batchSize, limitSize, learningRate, validateRate, }?: {
|
|
23
|
-
epochs?: number;
|
|
24
|
-
batchSize?: number;
|
|
25
|
-
limitSize?: number;
|
|
26
|
-
learningRate?: number;
|
|
27
|
-
validateRate?: number;
|
|
28
|
-
});
|
|
29
|
-
onProgress: (progress: number) => void;
|
|
30
|
-
onLoss: (loss: number) => void;
|
|
31
|
-
onEvents: (logs: any) => void;
|
|
32
|
-
onTrainBegin: (log: any) => void;
|
|
33
|
-
onTrainEnd: (log: any) => void;
|
|
34
|
-
onEpochEnd: (epoch: number, logs: any) => void;
|
|
35
|
-
load({ jsonURL, labels }: {
|
|
36
|
-
jsonURL: string;
|
|
37
|
-
labels: Array<string>;
|
|
38
|
-
}): Promise<void>;
|
|
39
|
-
private registerClassNumber;
|
|
40
|
-
private _convertToTfDataset;
|
|
41
|
-
addData(label: string, data: any): Promise<void>;
|
|
42
|
-
init(): Promise<void>;
|
|
43
|
-
private setupBackend;
|
|
44
|
-
private checkWasmSupport;
|
|
45
|
-
train(): Promise<tf.History>;
|
|
46
|
-
infer(data: any): Promise<Map<string, number>>;
|
|
47
|
-
saveModel(handlerOrURL: io.IOHandler | string, config?: io.SaveConfig): Promise<void>;
|
|
48
|
-
running(): boolean;
|
|
49
|
-
ready(): boolean;
|
|
50
|
-
private _createModel;
|
|
51
|
-
}
|
|
52
|
-
export default LearningMobilenet;
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
export {};
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
declare class DataManager {
|
|
3
|
-
private labelMap;
|
|
4
|
-
private labelIndex;
|
|
5
|
-
private data;
|
|
6
|
-
addData(label: string, values: number[]): Promise<void>;
|
|
7
|
-
convertToTensors(): {
|
|
8
|
-
xs: tf.Tensor2D;
|
|
9
|
-
ys: tf.Tensor1D;
|
|
10
|
-
};
|
|
11
|
-
getLabelMap(): {
|
|
12
|
-
[key: string]: number;
|
|
13
|
-
};
|
|
14
|
-
}
|
|
15
|
-
export default DataManager;
|
package/dist/lib/utils/tf.d.ts
DELETED
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
export declare function isTensor(c: any): c is tf.Tensor;
|
|
3
|
-
export declare function loadModel(): Promise<tf.LayersModel>;
|
|
4
|
-
export declare function mobileNetURL(version: number): string;
|
|
5
|
-
export declare function imageToTensor(data: any): tf.Tensor3D;
|
|
6
|
-
export declare function capture(rasterElement: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement, grayscale?: boolean): Promise<tf.Tensor<tf.Rank>>;
|
|
7
|
-
export declare function cropTensor(img: tf.Tensor3D, grayscaleModel?: boolean, grayscaleInput?: boolean): tf.Tensor3D;
|
package/dist/utils/canvas.d.ts
DELETED
package/dist/utils/canvas.js
DELETED
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.cropTo = void 0;
|
|
4
|
-
const newCanvas = () => document.createElement('canvas');
|
|
5
|
-
function cropTo(image, size, flipped = false, canvas = newCanvas()) {
|
|
6
|
-
let width;
|
|
7
|
-
let height;
|
|
8
|
-
// If ImageData
|
|
9
|
-
if (image instanceof ImageData) {
|
|
10
|
-
width = image.width;
|
|
11
|
-
height = image.height;
|
|
12
|
-
}
|
|
13
|
-
// If image, bitmap, or canvas
|
|
14
|
-
else if (image instanceof HTMLImageElement || image instanceof HTMLCanvasElement) {
|
|
15
|
-
width = image.width;
|
|
16
|
-
height = image.height;
|
|
17
|
-
}
|
|
18
|
-
// If video element
|
|
19
|
-
else if (image instanceof HTMLVideoElement) {
|
|
20
|
-
width = image.videoWidth;
|
|
21
|
-
height = image.videoHeight;
|
|
22
|
-
}
|
|
23
|
-
else {
|
|
24
|
-
throw new Error("Unsupported Drawable type");
|
|
25
|
-
}
|
|
26
|
-
const min = Math.min(width, height);
|
|
27
|
-
const scale = size / min;
|
|
28
|
-
const scaledW = Math.ceil(width * scale);
|
|
29
|
-
const scaledH = Math.ceil(height * scale);
|
|
30
|
-
const dx = scaledW - size;
|
|
31
|
-
const dy = scaledH - size;
|
|
32
|
-
canvas.width = canvas.height = size;
|
|
33
|
-
const ctx = canvas.getContext('2d');
|
|
34
|
-
// Handle ImageData separately
|
|
35
|
-
if (image instanceof ImageData) {
|
|
36
|
-
ctx.putImageData(image, ~~(dx / 2) * -1, ~~(dy / 2) * -1); // Adjust this if needed
|
|
37
|
-
}
|
|
38
|
-
else {
|
|
39
|
-
ctx.drawImage(image, ~~(dx / 2) * -1, ~~(dy / 2) * -1, scaledW, scaledH);
|
|
40
|
-
}
|
|
41
|
-
if (flipped) {
|
|
42
|
-
ctx.scale(-1, 1);
|
|
43
|
-
ctx.drawImage(canvas, size * -1, 0);
|
|
44
|
-
}
|
|
45
|
-
return canvas;
|
|
46
|
-
}
|
|
47
|
-
exports.cropTo = cropTo;
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
declare class DataManager {
|
|
3
|
-
private labelMap;
|
|
4
|
-
private labelIndex;
|
|
5
|
-
private data;
|
|
6
|
-
addData(label: string, values: number[]): Promise<void>;
|
|
7
|
-
convertToTensors(): {
|
|
8
|
-
xs: tf.Tensor2D;
|
|
9
|
-
ys: tf.Tensor1D;
|
|
10
|
-
};
|
|
11
|
-
getLabelMap(): {
|
|
12
|
-
[key: string]: number;
|
|
13
|
-
};
|
|
14
|
-
}
|
|
15
|
-
export default DataManager;
|
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
|
|
3
|
-
if (k2 === undefined) k2 = k;
|
|
4
|
-
var desc = Object.getOwnPropertyDescriptor(m, k);
|
|
5
|
-
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
|
|
6
|
-
desc = { enumerable: true, get: function() { return m[k]; } };
|
|
7
|
-
}
|
|
8
|
-
Object.defineProperty(o, k2, desc);
|
|
9
|
-
}) : (function(o, m, k, k2) {
|
|
10
|
-
if (k2 === undefined) k2 = k;
|
|
11
|
-
o[k2] = m[k];
|
|
12
|
-
}));
|
|
13
|
-
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
|
|
14
|
-
Object.defineProperty(o, "default", { enumerable: true, value: v });
|
|
15
|
-
}) : function(o, v) {
|
|
16
|
-
o["default"] = v;
|
|
17
|
-
});
|
|
18
|
-
var __importStar = (this && this.__importStar) || function (mod) {
|
|
19
|
-
if (mod && mod.__esModule) return mod;
|
|
20
|
-
var result = {};
|
|
21
|
-
if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
|
|
22
|
-
__setModuleDefault(result, mod);
|
|
23
|
-
return result;
|
|
24
|
-
};
|
|
25
|
-
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
|
|
26
|
-
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
|
|
27
|
-
return new (P || (P = Promise))(function (resolve, reject) {
|
|
28
|
-
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
|
|
29
|
-
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
|
|
30
|
-
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
|
|
31
|
-
step((generator = generator.apply(thisArg, _arguments || [])).next());
|
|
32
|
-
});
|
|
33
|
-
};
|
|
34
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
35
|
-
const tf = __importStar(require("@tensorflow/tfjs"));
|
|
36
|
-
class DataManager {
|
|
37
|
-
constructor() {
|
|
38
|
-
this.labelMap = {};
|
|
39
|
-
this.labelIndex = 0;
|
|
40
|
-
this.data = [];
|
|
41
|
-
}
|
|
42
|
-
addData(label, values) {
|
|
43
|
-
return __awaiter(this, void 0, void 0, function* () {
|
|
44
|
-
if (!(label in this.labelMap)) {
|
|
45
|
-
this.labelMap[label] = this.labelIndex++;
|
|
46
|
-
}
|
|
47
|
-
const numericLabel = this.labelMap[label];
|
|
48
|
-
this.data.push({ label: numericLabel, values });
|
|
49
|
-
});
|
|
50
|
-
}
|
|
51
|
-
convertToTensors() {
|
|
52
|
-
const xs = this.data.map(d => d.values);
|
|
53
|
-
const ys = this.data.map(d => d.label);
|
|
54
|
-
const xsTensor = tf.tensor2d(xs, [xs.length, xs[0].length], 'float32');
|
|
55
|
-
const ysTensor = tf.tensor1d(ys, 'float32');
|
|
56
|
-
return { xs: xsTensor, ys: ysTensor };
|
|
57
|
-
}
|
|
58
|
-
getLabelMap() {
|
|
59
|
-
return this.labelMap;
|
|
60
|
-
}
|
|
61
|
-
}
|
|
62
|
-
exports.default = DataManager;
|
package/dist/utils/dataset.d.ts
DELETED
package/dist/utils/dataset.js
DELETED
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.fisherYates = exports.flatOneHot = void 0;
|
|
4
|
-
/**
 * Builds a one-hot vector as a plain array.
 *
 * @param label       Index of the entry that is set to 1.
 * @param numClasses  Length of the resulting array.
 * @returns An array of `numClasses` zeros with a 1 at position `label`.
 */
function flatOneHot(label, numClasses) {
    const zeros = Array(numClasses).fill(0);
    return Object.assign(zeros, { [label]: 1 });
}
|
|
9
|
-
exports.flatOneHot = flatOneHot;
|
|
10
|
-
/**
 * Returns a shuffled copy of `array` using the Fisher-Yates algorithm.
 *
 * @param array  The values to shuffle; it is not modified.
 * @returns A new array holding the same elements in random order.
 */
function fisherYates(array) {
    let cursor = array.length;
    // Work on a copy so the caller's array keeps its order.
    const result = array.slice();
    while (cursor > 1) {
        cursor -= 1;
        // Pick a partner anywhere in [0, cursor] and exchange the two slots.
        const pick = Math.floor(Math.random() * (cursor + 1));
        const held = result[cursor];
        result[cursor] = result[pick];
        result[pick] = held;
    }
    return result;
}
|
|
21
|
-
exports.fisherYates = fisherYates;
|
package/dist/utils/tf.d.ts
DELETED
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
export declare function isTensor(c: any): c is tf.Tensor;
|
|
3
|
-
export declare function loadModel(): Promise<tf.LayersModel>;
|
|
4
|
-
export declare function mobileNetURL(version: number): string;
|
|
5
|
-
export declare function imageToTensor(data: any): tf.Tensor3D;
|
|
6
|
-
export declare function capture(rasterElement: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement, grayscale?: boolean): Promise<tf.Tensor<tf.Rank>>;
|
|
7
|
-
export declare function cropTensor(img: tf.Tensor3D, grayscaleModel?: boolean, grayscaleInput?: boolean): tf.Tensor3D;
|
package/dist/utils/tf.js
DELETED
|
@@ -1,126 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
var __createBinding = (this && this.__createBinding) || (Object.create ? (function(o, m, k, k2) {
|
|
3
|
-
if (k2 === undefined) k2 = k;
|
|
4
|
-
var desc = Object.getOwnPropertyDescriptor(m, k);
|
|
5
|
-
if (!desc || ("get" in desc ? !m.__esModule : desc.writable || desc.configurable)) {
|
|
6
|
-
desc = { enumerable: true, get: function() { return m[k]; } };
|
|
7
|
-
}
|
|
8
|
-
Object.defineProperty(o, k2, desc);
|
|
9
|
-
}) : (function(o, m, k, k2) {
|
|
10
|
-
if (k2 === undefined) k2 = k;
|
|
11
|
-
o[k2] = m[k];
|
|
12
|
-
}));
|
|
13
|
-
var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (function(o, v) {
|
|
14
|
-
Object.defineProperty(o, "default", { enumerable: true, value: v });
|
|
15
|
-
}) : function(o, v) {
|
|
16
|
-
o["default"] = v;
|
|
17
|
-
});
|
|
18
|
-
var __importStar = (this && this.__importStar) || function (mod) {
|
|
19
|
-
if (mod && mod.__esModule) return mod;
|
|
20
|
-
var result = {};
|
|
21
|
-
if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
|
|
22
|
-
__setModuleDefault(result, mod);
|
|
23
|
-
return result;
|
|
24
|
-
};
|
|
25
|
-
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
|
|
26
|
-
function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
|
|
27
|
-
return new (P || (P = Promise))(function (resolve, reject) {
|
|
28
|
-
function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
|
|
29
|
-
function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
|
|
30
|
-
function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
|
|
31
|
-
step((generator = generator.apply(thisArg, _arguments || [])).next());
|
|
32
|
-
});
|
|
33
|
-
};
|
|
34
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
35
|
-
exports.cropTensor = exports.capture = exports.imageToTensor = exports.mobileNetURL = exports.loadModel = exports.isTensor = void 0;
|
|
36
|
-
const tf = __importStar(require("@tensorflow/tfjs"));
|
|
37
|
-
/**
 * Duck-type check for a tensor-like value: it must expose non-null object
 * values for both `dataId` and `shape`.
 *
 * @param c  Any value, including `null`, `undefined` and primitives.
 * @returns `true` when `c` looks like a tensor, `false` otherwise; never throws.
 */
function isTensor(c) {
    // Reading a property of null/undefined would throw, so reject them first.
    if (c === null || c === undefined) {
        return false;
    }
    const { dataId, shape } = c;
    // `typeof null` is also 'object', hence the explicit null checks.
    return typeof dataId === 'object' && dataId !== null
        && typeof shape === 'object' && shape !== null;
}
|
|
40
|
-
exports.isTensor = isTensor;
|
|
41
|
-
/**
 * Downloads the MobileNet layers model and wraps it as a feature extractor.
 *
 * The network is cut at a named activation layer and placed in a sequential
 * model followed by one reducing layer: `flatten` for version 1,
 * `globalAveragePooling2d` otherwise. The version is currently fixed to 2.
 *
 * @returns A promise resolving to the sequential feature-extractor model.
 */
function loadModel() {
    return __awaiter(this, void 0, void 0, function* () {
        let mobileNetVersion = 2;
        const useV1 = mobileNetVersion == 1;
        const base = yield tf.loadLayersModel(mobileNetURL(mobileNetVersion));
        // Name of the layer whose output becomes the extractor's output.
        const cutLayer = base.getLayer(useV1 ? 'conv_pw_13_relu' : 'out_relu');
        const featureExtractor = tf.model({
            inputs: base.inputs,
            outputs: cutLayer.output
        });
        const model = tf.sequential();
        model.add(featureExtractor);
        model.add(useV1
            ? tf.layers.flatten()
            : tf.layers.globalAveragePooling2d({})); // go from shape [7, 7, 1280] to [1280]
        return model;
    });
}
|
|
72
|
-
exports.loadModel = loadModel;
|
|
73
|
-
function mobileNetURL(version) {
|
|
74
|
-
if (version == 1) {
|
|
75
|
-
return "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/model.json";
|
|
76
|
-
}
|
|
77
|
-
return "https://storage.googleapis.com/teachable-machine-models/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top/model.json";
|
|
78
|
-
}
|
|
79
|
-
exports.mobileNetURL = mobileNetURL;
|
|
80
|
-
function imageToTensor(data) {
|
|
81
|
-
let tensor;
|
|
82
|
-
if (data instanceof tf.Tensor) {
|
|
83
|
-
tensor = data;
|
|
84
|
-
}
|
|
85
|
-
else {
|
|
86
|
-
// MobileNet 모델 로드
|
|
87
|
-
tensor = tf.browser.fromPixels(data);
|
|
88
|
-
}
|
|
89
|
-
return tensor;
|
|
90
|
-
}
|
|
91
|
-
exports.imageToTensor = imageToTensor;
|
|
92
|
-
/**
 * Reads pixels from a raster element and prepares them as a model input batch.
 *
 * Runs inside `tf.tidy`: the pixels are center-cropped through `cropTensor`,
 * given a leading batch dimension of 1, cast to float, divided by 127 and
 * shifted down by 1.
 *
 * @param rasterElement  Pixel source handed to `tf.browser.fromPixels`.
 * @param grayscale      Forwarded to `cropTensor` as its `grayscaleModel` flag.
 * @returns A promise resolving to the normalized, batched tensor.
 */
function capture(rasterElement, grayscale) {
    return __awaiter(this, void 0, void 0, function* () {
        return tf.tidy(() => {
            // Use only the center square of the source.
            const square = cropTensor(tf.browser.fromPixels(rasterElement), grayscale);
            // Add the outermost dimension so the batch size is 1.
            const batch = square.expandDims(0);
            // Map the 0-255 pixel range to roughly -1..1.
            // NOTE(review): 255 / 127 - 1 is about 1.008, so the upper bound slightly
            // exceeds 1; kept as-is because changing it would alter the model input.
            const asFloat = batch.toFloat();
            const scaled = asFloat.div(tf.scalar(127));
            return scaled.sub(tf.scalar(1));
        });
    });
}
|
|
106
|
-
exports.capture = capture;
|
|
107
|
-
function cropTensor(img, grayscaleModel, grayscaleInput) {
|
|
108
|
-
const size = Math.min(img.shape[0], img.shape[1]);
|
|
109
|
-
const centerHeight = img.shape[0] / 2;
|
|
110
|
-
const beginHeight = centerHeight - (size / 2);
|
|
111
|
-
const centerWidth = img.shape[1] / 2;
|
|
112
|
-
const beginWidth = centerWidth - (size / 2);
|
|
113
|
-
if (grayscaleModel && !grayscaleInput) {
|
|
114
|
-
//cropped rgb data
|
|
115
|
-
let grayscale_cropped = img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
|
|
116
|
-
grayscale_cropped = grayscale_cropped.reshape([size * size, 1, 3]);
|
|
117
|
-
const rgb_weights = [0.2989, 0.5870, 0.1140];
|
|
118
|
-
grayscale_cropped = tf.mul(grayscale_cropped, rgb_weights);
|
|
119
|
-
grayscale_cropped = grayscale_cropped.reshape([size, size, 3]);
|
|
120
|
-
grayscale_cropped = tf.sum(grayscale_cropped, -1);
|
|
121
|
-
grayscale_cropped = tf.expandDims(grayscale_cropped, -1);
|
|
122
|
-
return grayscale_cropped;
|
|
123
|
-
}
|
|
124
|
-
return img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
|
|
125
|
-
}
|
|
126
|
-
exports.cropTensor = cropTensor;
|