miding 3.1.4__py3-none-any.whl → 3.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
miding/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from prediction import Predict
2
+ from preparation1 import *
3
+
4
# Package name, available as ``miding.name`` once the package is imported.
name: str = 'miding'
miding/distribution.py ADDED
@@ -0,0 +1,19 @@
1
+ import numpy as np
2
+
3
+
4
class RandomDistributionReform:
    """Randomly re-mix the last axis of an array with a random weight matrix.

    A ``(k, k)`` matrix is drawn from the requested distribution, where ``k``
    is the size of the array's last axis, and the array is multiplied by it
    with ``np.dot``. The result keeps the input's shape: every output feature
    is a random weighted sum of all input features.

    :param array: array-like whose last axis holds the features to re-mix
        (``Predict`` passes the model output, presumably of shape ``(1, 4)``).
    :param args: the two distribution parameters -- ``(a, b)`` for
        :meth:`beta_distribution`, ``(shape, scale)`` for
        :meth:`gamma_distribution`.
    :param rng: optional ``numpy.random.Generator`` for reproducible draws;
        defaults to the global ``numpy.random`` state.
    """

    def __init__(self, array, args: tuple, rng=None):
        self.array = np.asarray(array)
        self.args = args
        self.rng = np.random if rng is None else rng
        if self.array.ndim == 0:
            # Scalar input: draw a scalar weight.
            self.shape = ()
        else:
            # Square mixing matrix over the feature (last) axis. A plain
            # (k,) vector would make np.dot collapse 1-D input to a scalar,
            # and shape[1] is the wrong axis for inputs of rank > 2.
            features = self.array.shape[-1]
            self.shape = (features, features)

    def beta_distribution(self):
        """Return ``array`` multiplied by a Beta(a, b) random matrix, ``(a, b) = args``."""
        beta = self.rng.beta(a=self.args[0], b=self.args[1], size=self.shape)
        return np.dot(self.array, beta)

    def gamma_distribution(self):
        """Return ``array`` multiplied by a Gamma(shape, scale) random matrix, ``(shape, scale) = args``."""
        gamma = self.rng.gamma(shape=self.args[0], scale=self.args[1], size=self.shape)
        return np.dot(self.array, gamma)
miding/model.py ADDED
@@ -0,0 +1,62 @@
1
+ from os import environ
2
+
3
+ environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
4
+ environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
5
+ from time import time
6
+ import matplotlib.pyplot as plt
7
+
8
+ from keras.models import Sequential
9
+ from keras.layers import Dense, Input, GRU
10
+ from keras.optimizers import Adam
11
+ from keras.losses import MeanSquaredError
12
+ from keras.callbacks import EarlyStopping, ModelCheckpoint
13
+
14
+ from preparation1 import create_databases
15
+
16
+
17
# Training hyper-parameters (name kept as-is for compatibility).
superparameters = {'batch_size': 256, 'train_length': 8}
# Layer sizes and optimiser settings; also printed in the plot title. 'Dense'
# doubles as the number of features per event (flag, note, velocity, time)
# on both the input and the output side.
structure = {'GRU1': 20, 'GRU2': 16, 'Dense': 4, 'optimizer': 'Adam', 'lr': 0.0156}
train_x, train_y, validate_x, validate_y = create_databases(
    midi_path='midi', train_length=superparameters['train_length'], step=1
)

epochs = 1024
# Unix timestamp that tags every artefact of this run (checkpoint, plot, final model).
version = int(time())

model = Sequential([
    # No fixed batch_size: pinning it to 256 baked that batch dimension into the
    # saved model, although fit() ends with a smaller last batch and
    # prediction.py calls predict() on a single window (batch of 1).
    Input(shape=(superparameters['train_length'], structure['Dense'])),
    GRU(units=structure['GRU1'], return_sequences=True),
    GRU(units=structure['GRU2'], return_sequences=False),
    # NOTE(review): softmax makes the four outputs sum to 1, which does not fit
    # independent regression targets trained with MSE (and 'accuracy' is not a
    # meaningful metric for them). Kept because the generation step in
    # prediction.py appears to rely on outputs in [0, 1] -- revisit together.
    Dense(units=structure['Dense'], activation='softmax'),
])

optimizer = Adam(learning_rate=structure['lr'])
model.compile(optimizer=optimizer, loss=MeanSquaredError(), metrics=['accuracy'])
callbacks_list = [
    # Stop once validation accuracy has not improved for 128 epochs.
    EarlyStopping(monitor='val_accuracy', patience=128, mode='max'),
    # Best checkpoint; preparation1.read_model() loads exactly this file name.
    ModelCheckpoint(filepath=f'model_{version}_best.keras', monitor='val_accuracy', save_best_only=True),
]
model.summary()

history = model.fit(
    x=train_x,
    y=train_y,
    batch_size=superparameters['batch_size'],
    epochs=epochs,
    callbacks=callbacks_list,
    validation_data=(validate_x, validate_y),
)

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']

# Persist the final model before plotting: plt.show() blocks until the window
# is closed and the high-dpi savefig below is memory hungry, so either could
# otherwise keep the trained model from ever reaching disk.
model.save(filepath=f'Model_{version}_ep{len(val_accuracy)}_va{round(max(val_accuracy), 3)}.keras')

plt.plot(range(len(accuracy)), accuracy, 'b--', label='Train Accuracy')
plt.plot(range(len(val_accuracy)), val_accuracy, label='Validation Accuracy')
plt.suptitle(f'Effect of GRU Model:{version}')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title(f'Parameters: {str(structure)}')
plt.legend()
# NOTE(review): dpi=1980 gives a very large image (roughly 12672x9504 px if the
# default 6.4x4.8 in figure size is used) -- confirm this is intended.
plt.savefig(fname=f'Model_{version}_Analysis.png', dpi=1980)
plt.show()
miding/prediction.py ADDED
@@ -0,0 +1,128 @@
1
+ import numpy as np
2
+ from time import time
3
+ from os import environ
4
+ from mido import Message, MidiFile, MidiTrack, MetaMessage, bpm2tempo
5
+ environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
6
+ environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
7
+
8
+ from distribution import RandomDistributionReform
9
+ from preparation1 import read_model, load_midi_file, get_seed
10
+
11
+
12
class FormateError(TypeError):
    """Raised by ``Seed`` when the seed file name does not look like a MIDI file."""
14
+
15
+
16
class Predict:
    """Generate a MIDI score with a trained GRU model and write it to disk.

    All work happens in the constructor:
    - the model is loaded;
    - ``epoch`` new events are predicted autoregressively from ``seed``;
    - the result is saved as ``<result_save_path>/<unix-time>-ß.mid``.

    The saved file contains a meta track (time signature, key, tempo, name)
    and one instrument track.
    """

    def __init__(self,
                 seed,
                 epoch: int,
                 model_version,
                 instrument_code: int = 0,
                 result_save_path: str = 'result',
                 ):
        """
        The main class, called to predict scores.

        :param seed: the start score, an array of shape (1, 8, 4) with one
            (flag, note, velocity, time) event per row, presumably scaled by
            1/128 like the training windows -- confirm. It is copied, so the
            caller's array is left untouched.
        :param epoch: number of events to predict. The model always sees the
            latest window of seed length, and each prediction is appended to it.
            Only the predicted events (not the seed) go into the file.
        :param model_version: timestamp used by ``read_model`` to locate
            ``model_<version>_best.keras``.
        :param instrument_code: program number sent as a program_change on track 1.
        :param result_save_path: output directory; created if it does not exist.
        """
        from os import makedirs

        self.save_path = result_save_path
        self.model = read_model(version=model_version)
        # Private float copy: cycle() slides this window in place and must not
        # overwrite the caller's array (e.g. the one held by a Seed instance).
        self.seed = np.array(seed, dtype=float)
        self.epoch = epoch
        self.instrument = instrument_code
        self.prediction = None
        self.sequence = []
        self.save_file_name = f'{int(time())}-ß'
        self.mid = MidiFile()
        self.track0 = MidiTrack()
        self.track1 = MidiTrack()
        self.mid.tracks.append(self.track0)
        self.mid.tracks.append(self.track1)
        self.cycle()
        self.save_track0()
        self.save_track1()
        if self.save_path:
            makedirs(self.save_path, exist_ok=True)
        self.mid.save(f'{self.save_path}/{self.save_file_name}.mid')

    def cycle(self):
        """Predict ``self.epoch`` events, sliding the seed window after each one.

        Each model output is re-mixed with a random Beta(5, 5) matrix via
        ``RandomDistributionReform``. The result is then fed back into the
        window and recorded in ``self.sequence``.
        """
        for _ in range(self.epoch):
            raw = self.model.predict(self.seed)
            print(self.seed[0, -1, :])
            self.prediction = RandomDistributionReform(raw, args=(5, 5)).beta_distribution()
            # Drop the oldest event and append the new one. NumPy handles the
            # overlapping slice assignment correctly, and this works for any
            # window length rather than a hard-coded 8.
            self.seed[0, :-1, :] = self.seed[0, 1:, :]
            self.seed[0, -1, :] = self.prediction[0, :]
            self.sequence.append(self.prediction[0])

    def save_track0(self):
        """Fill the meta track: 4/4 time, key 'C', 80 BPM tempo and the file name."""
        self.track0.append(MetaMessage('time_signature', numerator=4, denominator=4, time=0))
        self.track0.append(MetaMessage('key_signature', key='C', time=0))
        self.track0.append(MetaMessage('set_tempo', tempo=bpm2tempo(80), time=0))
        self.track0.append(MetaMessage('track_name', name=self.save_file_name, time=0))
        self.track0.append(MetaMessage('end_of_track', time=1))

    def save_track1(self):
        """Write the generated events to the instrument track (track 1).

        Predicted rows are quantised by ``_to_event``. Unusable rows are
        skipped, as are note_on events never followed by a note_off of the
        same pitch.
        """
        self.track1.append(MetaMessage('track_name', name='Instrument', time=0))
        self.track1.append(Message(type='program_change', program=self.instrument, time=0))
        events = [event for event in map(self._to_event, self.sequence) if event is not None]
        # NOTE(review): skipped rows also drop their delta time, which shifts
        # the timing of the following events.
        for flag, note, velocity, delta in self._drop_unmatched_note_ons(events):
            event_flag = 'note_on' if flag == 1 else 'note_off'
            self.track1.append(Message(type=event_flag, note=note, velocity=velocity, time=delta))
        self.track1.append(MetaMessage('end_of_track', time=1))

    @staticmethod
    def _to_event(unit, max_delta: int = 1000):
        """Quantise one predicted row into ``(flag, note, velocity, time)`` ints.

        ``unit[0]`` is thresholded at 0.5 into note_on (1) / note_off (0). The
        other three values are multiplied by 128 and truncated to int.

        Returns ``None`` (row skipped) when any of these holds:
        - note or velocity is outside 0..127;
        - the delta time is negative;
        - the delta time exceeds ``max_delta``.

        The row is not modified.
        """
        flag = 0 if unit[0] < 0.5 else 1
        note, velocity, delta = (int(value * 128) for value in unit[1:4])
        if not (0 <= note <= 127 and 0 <= velocity <= 127):
            return None
        if not 0 <= delta <= max_delta:
            return None
        return flag, note, velocity, delta

    @staticmethod
    def _drop_unmatched_note_ons(events):
        """Return ``events`` without note_on entries lacking a later note_off of the same pitch.

        One backward pass records which pitches are released later, so each
        note_on is matched row-wise on (flag == 0, same note) in O(n).
        """
        released = set()
        kept = []
        for event in reversed(events):
            flag, note = event[0], event[1]
            if flag == 0:
                released.add(note)
            elif note not in released:
                continue
            kept.append(event)
        kept.reverse()
        return kept
100
+
101
+
102
class Seed:
    """Build a (1, 8, 4) model seed from the first 8 note events of a MIDI file."""

    def __init__(self, midi_file: str):
        """
        :param midi_file: path to a ``.mid`` / ``.midi`` file.
        :raises FormateError: if the path does not have a MIDI extension.
        :raises ValueError: if the file has fewer than 8 note events.
        """
        self.midi = midi_file
        self.check_formate()
        self.seed = np.zeros(shape=(1, 8, 4))
        self.score = load_midi_file(self.midi)
        self.form_seed_array(self.score)

    def check_formate(self):
        """Raise ``FormateError`` unless ``self.midi`` ends with .mid or .midi (any case)."""
        # Message goes to the constructor: BaseException.add_note only exists on
        # Python 3.11+, and a substring test accepted names such as 'x.mid.bak'.
        if not self.midi.lower().endswith(('.mid', '.midi')):
            raise FormateError(f'{self.midi} is not a correct input formate.')

    def form_seed_array(self, score: list):
        """Copy the first events of ``score`` into ``self.seed``, each field divided by 128.

        create_databases() scales every field by 1/128 for training, so the seed
        uses the same scale to resemble the model's training input.
        """
        rows, features = self.seed.shape[1], self.seed.shape[2]
        if len(score) < rows:
            raise ValueError(
                f'{self.midi} has {len(score)} note events; at least {rows} are needed for a seed.'
            )
        for i in range(rows):
            for j in range(features):
                self.seed[0, i, j] = int(score[i][j]) / 128

    def get_seed(self):
        """Return the seed array built from the MIDI file."""
        return self.seed
123
+
124
+
125
if __name__ == '__main__':
    # Alternative seed source, a random training window:
    # Predict(seed=get_seed(), epoch=256, model_version=1751770203)
    example_seed = Seed(midi_file='example_seed.mid')
    Predict(seed=example_seed.get_seed(), epoch=128, model_version=1751770203)
miding/preparation1.py ADDED
@@ -0,0 +1,88 @@
1
+ import numpy as np
2
+ from mido import Message, MidiFile
3
+ from random import randint
4
+
5
+ from os import listdir, environ
6
+ environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
7
+ environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
8
+
9
+ from keras.models import load_model
10
+
11
+
12
def load_midi_file(file: str):
    """Read the note events of a MIDI file as ``(flag, note, velocity, time)`` tuples.

    ``flag`` is 1 for a sounding note_on and 0 for a release. ``time`` is the
    message's delta time (``msg.time``). All other messages are ignored.

    Which tracks are read:
    - multi-track files skip track 0 (presumably the tempo/meta track of
      type-1 files);
    - type-0 files keep every event in track 0, so that track is read there.

    A note_on with velocity 0 is recorded as a release (flag 0), since MIDI
    uses it as a note_off.
    """
    score = []
    midi = MidiFile(file)
    tracks = midi.tracks if midi.type == 0 else midi.tracks[1:]
    for track in tracks:
        for msg in track:
            # Meta/sysex messages never have these types, so no class check is needed.
            if msg.type == 'note_on' and msg.velocity > 0:
                flag = 1
            elif msg.type in ('note_off', 'note_on'):
                flag = 0
            else:
                continue
            score.append((flag, msg.note, msg.velocity, msg.time))
    return score
29
+
30
+
31
def create_midi_data(midi_path: str):
    """Concatenate the note events of every MIDI file in ``midi_path``.

    Files are read in sorted name order. os.listdir returns entries in
    arbitrary order, and create_databases splits train/validation by
    position, so sorting keeps runs reproducible. The whole stream is then
    extended with its own reverse as data augmentation.
    """
    database_list = []
    for file in sorted(listdir(path=midi_path)):
        # Match the extension, not any '.mid' substring (e.g. 'notes.mid.txt').
        if file.lower().endswith(('.mid', '.midi')):
            database_list.extend(load_midi_file(f'{midi_path}/{file}'))
    # NOTE(review): the reversed copy also reverses delta times, and because
    # create_databases takes validation windows from the end of this list, they
    # appear to be mirrored copies of events at the start of the training data.
    database_list += database_list[::-1]
    return database_list
43
+
44
def create_databases(midi_path: str, train_length: int, step: int):
    """Build sliding-window training and validation arrays from a MIDI folder.

    Every event field is divided by 128. Each sample is ``train_length``
    consecutive events; its target is the event that follows. A window starts
    every ``step`` events. The first 90% of samples are for training and the
    rest for validation, without shuffling.

    :returns: ``(train_x, train_y, validate_x, validate_y)``; ``*_x`` has shape
        ``(n, train_length, 4)`` and ``*_y`` has shape ``(n, 4)``.
    """
    database_list = create_midi_data(midi_path=midi_path)
    # (events, 4) integer matrix, scaled exactly like int(value) / 128.
    events = np.asarray(database_list, dtype=np.int64).reshape(-1, 4) / 128
    # Python's range keeps the original start/stop/step semantics (incl. step == 0 error).
    starts = np.array(range(0, len(events) - train_length, step), dtype=np.intp)
    # Gather all windows at once: (n, train_length) indices -> (n, train_length, 4).
    x = events[starts[:, np.newaxis] + np.arange(train_length)]
    y = events[starts + train_length]

    train_size = int(len(x) * 0.9)
    train_x, validate_x = x[:train_size], x[train_size:]
    train_y, validate_y = y[:train_size], y[train_size:]
    return train_x, train_y, validate_x, validate_y
71
+
72
def read_model(version: int):
    """Load the best checkpoint ``model_<version>_best.keras`` saved during training."""
    checkpoint_path = f'model_{version}_best.keras'
    return load_model(filepath=checkpoint_path)
75
+
76
def get_seed():
    """Return a random (1, 8, 4) seed window from the training data in ``midi``.

    The window index is drawn with ``randint(16, 256)`` as before, but the
    range is clamped to the windows actually available, so small datasets no
    longer raise IndexError. The index and the seed are printed.

    :raises ValueError: if no training windows can be built from ``midi``.
    """
    train_x, _, _, _ = create_databases(midi_path='midi', train_length=8, step=1)
    if len(train_x) == 0:
        raise ValueError("no training windows could be built from the 'midi' folder")
    upper = min(256, len(train_x) - 1)
    split_position = randint(min(16, upper), upper)
    seed = np.zeros(shape=(1, 8, 4))
    # A single copy suffices; the old loop repeated this identical assignment 8 times.
    seed[0, :, :] = train_x[split_position, :, :]
    print(split_position)
    print(seed)
    return seed
85
+
86
if __name__ == '__main__':
    # Manual check: draw and print one random seed window from ./midi.
    get_seed()
88
+
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: miding
3
- Version: 3.1.4
3
+ Version: 3.1.6
4
4
  Summary: A generator of midi score based on GRU.
5
5
  Author-email: Jerry_Skywolf <jerryskywolf@outlook.com>
6
- License-Expression: GPL-2.0
6
+ License-Expression: GPL-3.0
7
7
  Project-URL: Homepage, https://github.com/JerrySkywolf/miding
8
8
  Project-URL: Issues, https://github.com/JerrySkywolf/miding/issues
9
9
  Project-URL: DOWNLOAD, https://github.com/JerrySkywolf/miding/releases
@@ -0,0 +1,11 @@
1
+ miding/__init__.py,sha256=iMh6ucfGUnEF3lMvVLI_wizGpoMFj_ln0k-kifqCOws,77
2
+ miding/distribution.py,sha256=xoHE4GMTFiS6Fld-2em21r87zc7waehtvmgEAhQFfNE,647
3
+ miding/model.py,sha256=B5bVHCjKZCA56N5pktlVBfpW73-JSvQ3sQalahI5hT0,2306
4
+ miding/prediction.py,sha256=0Zi2U1LhVBEM6diAke5XL5MugK-ae9vdVEfh62hHB6U,4750
5
+ miding/preparation1.py,sha256=0POCRHtePWNYAZjasbAjoNA6xv6OE0i76zBKDM4WFrY,2673
6
+ miding-3.1.6.dist-info/licenses/LICENSE.txt,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
7
+ miding-3.1.6.dist-info/METADATA,sha256=2QfkCWRDHNADuXri-P2kl9Ve6dvP2lGEq_JytHVpKZk,1706
8
+ miding-3.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ miding-3.1.6.dist-info/entry_points.txt,sha256=x20XsdgK_5BjEIaoggcmSR1xFLBdrvcQSi6oI3f9f7w,45
10
+ miding-3.1.6.dist-info/top_level.txt,sha256=wOjSKSx1Fwtcm2G2fTlXklPxK8bkPIQZI_ff25Upm_w,7
11
+ miding-3.1.6.dist-info/RECORD,,
@@ -0,0 +1 @@
1
+ miding
@@ -1,6 +0,0 @@
1
- miding-3.1.4.dist-info/licenses/LICENSE.txt,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
2
- miding-3.1.4.dist-info/METADATA,sha256=ZQAD-jRaDaWOa7QStUXLZVj0S5XROy5Yr288sMmNQ9g,1706
3
- miding-3.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
4
- miding-3.1.4.dist-info/entry_points.txt,sha256=x20XsdgK_5BjEIaoggcmSR1xFLBdrvcQSi6oI3f9f7w,45
5
- miding-3.1.4.dist-info/top_level.txt,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
6
- miding-3.1.4.dist-info/RECORD,,
@@ -1 +0,0 @@
1
-
File without changes