tfjs-evolution 0.0.3 → 0.0.5

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,165 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
export class Util {
  /**
   * Captures an image from a raster element, centre-crops it to a square and
   * scales the pixel values from 0-255 to approximately [-1, 1].
   *
   * NOTE(review): values are divided by 127 (not 127.5), so 255 maps to
   * ~1.008 rather than exactly 1. Kept unchanged so models trained with this
   * preprocessing keep receiving identical inputs.
   *
   * @param rasterElement the element with pixels to convert to a Tensor
   * @param grayscale optional flag that collapses the crop to one channel,
   *   giving shape [1, size, size, 1] instead of [1, size, size, 3]
   * @returns {tf.Tensor4D} float32 batch of one image, where
   *   size = min(height, width) of the source
   */
  capture(rasterElement, grayscale) {
    return tf.tidy(() => {
      const pixels = tf.browser.fromPixels(rasterElement);
      // Crop the image so we're using the centre square.
      const cropped = this.cropTensor(pixels, grayscale);
      // Expand the outermost dimension so we have a batch size of 1.
      const batchedImage = cropped.expandDims(0);
      // Normalize the image to roughly [-1, 1]: it comes in as 0-255, so we
      // divide by 127 and subtract 1.
      return tf.cast(batchedImage, 'float32').div(tf.scalar(127)).sub(tf.scalar(1));
    });
  }

  /**
   * Crops the centre square out of an image tensor, optionally converting the
   * RGB data to a single weighted-luminance channel.
   *
   * @param img image tensor indexed as [height, width, channel]; the first
   *   3 channels are read (presumably R, G, B — TODO confirm for all callers)
   * @param grayscaleModel when truthy (and `grayscaleInput` is falsy), the
   *   crop is collapsed to one channel
   * @param grayscaleInput when truthy, the grayscale conversion is skipped
   * @returns {tf.Tensor3D} shape [size, size, 3], or [size, size, 1] for the
   *   grayscale path, where size = min(height, width)
   */
  cropTensor(img, grayscaleModel, grayscaleInput) {
    const [height, width] = img.shape;
    const size = Math.min(height, width);
    // Centre the square. Math.floor keeps the offsets integral when the
    // difference between height and width is odd (e.g. 481x640 would
    // otherwise produce a begin offset of 79.5, which is not a valid index).
    const beginHeight = Math.floor((height - size) / 2);
    const beginWidth = Math.floor((width - size) / 2);
    // tidy() releases the grayscale intermediates even when this method is
    // called outside another tidy() scope; the returned tensor is kept.
    return tf.tidy(() => {
      const rgbCrop = img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
      if (!grayscaleModel || grayscaleInput) {
        return rgbCrop;
      }
      // Per-pixel weighted sum of the three channels (weights broadcast over
      // the last axis), then restore a trailing channel axis of size 1.
      const rgbWeights = [0.2989, 0.5870, 0.1140];
      const luminance = tf.sum(tf.mul(rgbCrop, rgbWeights), -1);
      return tf.expandDims(luminance, -1);
    });
  }

  /**
   * Copies the weights of `originalModel` into `recipient`, layer by layer,
   * by position. This is an attempt to avoid influencing the new model when
   * the old one is eliminated.
   *
   * @param originalModel - the model whose layer weights are read
   * @param recipient - the model receiving the weights; its layers must line
   *   up positionally with those of `originalModel`
   * @returns {Promise<void>}
   * @throws {Error} (as a rejected Promise) when `recipient` has fewer layers
   *   than `originalModel`
   */
  async copyModel_v3(originalModel, recipient) {
    const sourceLayers = originalModel.layers;
    const targetLayers = recipient.layers;
    if (targetLayers.length < sourceLayers.length) {
      throw new Error(
        `copyModel_v3: recipient has ${targetLayers.length} layers but originalModel has ${sourceLayers.length}`,
      );
    }
    sourceLayers.forEach((layer, index) => {
      targetLayers[index].setWeights(layer.getWeights());
    });
  }

  /**
   * Makes a copy of a TFJS model, so that it would be possible to erase the
   * original. The architecture is rebuilt from the serialized topology and the
   * current weights are copied over.
   *
   * NOTE(review): the copy is compiled with the *same* optimizer object as the
   * original, so any optimizer state is shared between the two models.
   *
   * @param originalModel - model to be copied
   * @returns {Promise<tf.LayersModel>} the copied model
   */
  async copyModel_v2(originalModel) {
    // toJSON() returns a JSON *string* by default; returnString=false yields
    // the topology object that the in-memory loader expects.
    const modelTopology = originalModel.toJSON(null, false);
    const copiedModel = await tf.loadLayersModel(tf.io.fromMemory({ modelTopology }));
    // The topology carries no weights, so the loaded model is freshly
    // initialised; transfer the original's current weights explicitly.
    copiedModel.setWeights(originalModel.getWeights());
    // compile() requires an optimizer, so an uncompiled original yields an
    // uncompiled copy instead of throwing.
    if (originalModel.optimizer != null) {
      copiedModel.compile({
        loss: originalModel.loss,
        optimizer: originalModel.optimizer,
      });
    }
    return copiedModel;
  }

  /**
   * Builds a new `tf.sequential()` container from the layers of `model` and
   * compiles it with the same loss and optimizer objects.
   *
   * NOTE(review): the layer *instances* themselves are added to the new
   * container, so the result shares its layers (and therefore weights) with
   * `model`; it is not an independent copy. Disposing `model` may therefore
   * affect the returned model — confirm before relying on it.
   *
   * @param model - model whose layers are reused
   * @returns the new, compiled Sequential model
   */
  copyModel(model) {
    const copy = tf.sequential();
    for (const layer of model.layers) {
      copy.add(layer);
    }
    copy.compile({
      loss: model.loss,
      optimizer: model.optimizer,
    });
    return copy;
  }

  /**
   * Removes, in place, the element at `index` when it is within bounds.
   *
   * @param arr - array to modify (mutated)
   * @param index - position to remove; out-of-range values are ignored
   * @returns the same (possibly modified) array instance
   */
  removeElementByIndex(arr, index) {
    if (index >= 0 && index < arr.length) {
      arr.splice(index, 1);
    }
    return arr;
  }

  /**
   * Returns a new array without any occurrence of `element` (strict equality).
   *
   * @param arr - source array (not modified)
   * @param element - value to drop
   * @returns filtered copy of `arr`
   */
  removeElement(arr, element) {
    return arr.filter((item) => item !== element);
  }

  /**
   * Calls `dispose()` on every element of `tensors`.
   *
   * @param tensors - iterable of disposable objects (e.g. tf.Tensor)
   */
  clean_array_of_tensors(tensors) {
    for (const tensor of tensors) {
      tensor.dispose();
    }
  }

  /**
   * Finds the name of the first class whose `signature` equals `signature`
   * element-wise (strict equality, same length). Classes without a signature
   * never match.
   *
   * @param classes - array of objects with `name` and optional `signature`
   * @param signature - array-like (Array or TypedArray) to look up
   * @returns the matching class name, or "not found"
   */
  getClassNameBySignature(classes, signature) {
    if (signature == null) {
      return "not found";
    }
    const match = classes.find(({ signature: candidate }) =>
      candidate != null &&
      candidate.length === signature.length &&
      Array.from(candidate).every((elem, index) => elem === signature[index]));
    return match ? match.name : "not found";
  }

  /**
   * Builds an n x n identity matrix as nested plain arrays.
   *
   * @param n - matrix dimension
   * @returns {number[][]} 1 on the diagonal, 0 elsewhere
   */
  identityMatrix(n) {
    return Array.from({ length: n }, (_, i) => Array.from({ length: n }, (_, j) => (i === j ? 1 : 0)));
  }

  /**
   * Returns the index of the largest value (first occurrence on ties).
   *
   * @param arr - array-like of comparable values
   * @returns index of the maximum, or -1 for an empty array
   */
  indexOfMax(arr) {
    if (arr.length === 0) {
      return -1;
    }
    let max = arr[0];
    let maxIndex = 0;
    for (let i = 1; i < arr.length; i++) {
      if (arr[i] > max) {
        maxIndex = i;
        max = arr[i];
      }
    }
    return maxIndex;
  }

  /**
   * Shuffles two arrays in place with the same random permutation
   * (Fisher-Yates driven by `array1.length`), keeping pairs aligned.
   *
   * @param array1 - first array (mutated)
   * @param array2 - second array, assumed to be the same length (mutated)
   */
  suffle(array1, array2) {
    for (let i = array1.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      // Swap the same positions in both arrays.
      [array1[i], array1[j]] = [array1[j], array1[i]];
      [array2[i], array2[j]] = [array2[j], array2[i]];
    }
  }

  /**
   * Orders the elements of `arr2` by the corresponding values in `arr1`,
   * largest first. Ties keep their original relative order (stable sort).
   * Neither input is modified.
   *
   * @param arr1 - numeric sort keys
   * @param arr2 - elements paired positionally with `arr1`
   * @returns new array of `arr2` elements in descending key order
   */
  sortByValuePreservingIndex(arr1, arr2) {
    return arr1
      .map((value, index) => ({ value, element: arr2[index] }))
      .sort((a, b) => b.value - a.value)
      .map(({ element }) => element);
  }
}
165
+ //# sourceMappingURL=data:application/json;base64,