stream-dataset 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,98 @@
1
+ import stream_dataset.evals as evals
2
+ import numpy as np
3
+
4
# Names of the tasks advertised by this module.
# NOTE(review): build_task validates names against the keys of
# evals.stream_small, not against this list. In the configs visible in this
# package 'sequential_mnist' has no entry while 'cross_situation' does --
# confirm which of the two is intended.
tasks = [
    # Forecasting
    'sinus_forecasting', 'chaotic_forecasting',
    # Postcasting
    'discrete_postcasting', 'continuous_postcasting',
    # Pattern completion
    'discrete_pattern_completion', 'continuous_pattern_completion',
    # Symbolic / memory problems
    'bracket_matching',
    'simple_copy', 'selective_copy',
    'adding_problem', 'sorting_problem',
    # Sequential image classification
    'sequential_mnist',
]
18
+
19
def compute_score(Y, Y_hat, prediction_timesteps, classification):
    """
    Compute the error of the model on the selected prediction timesteps.

    The returned score is an error in both modes (lower is better):
    the classification error rate (1 - accuracy) or the mean squared error.

    Parameters:
    - Y (np.ndarray): Target array [B, T, O]
    - Y_hat (np.ndarray): Predicted array [B, T, O]
    - prediction_timesteps (list): For each sample, the timestep indices to evaluate.
        Samples may select a different number of timesteps.
    - classification (bool): Whether the task is a classification task -> error rate or MSE

    Returns:
    - score (float): Error rate (1 - accuracy) if classification, MSE otherwise
    """
    # Make sure Y_hat and Y are numpy arrays (both are converted as soon as one is not)
    if not isinstance(Y_hat, np.ndarray) or not isinstance(Y, np.ndarray):
        Y = np.array(Y, dtype=np.float32)
        Y_hat = np.array(Y_hat, dtype=np.float32)

    # Select only the prediction timesteps of every sample. Concatenating
    # (instead of stacking) the per-sample selections keeps the same element
    # order when all samples select the same number of timesteps, and also
    # supports samples that select different numbers of timesteps.
    batch_size = Y.shape[0]
    preds = np.concatenate(
        [np.atleast_2d(Y_hat[j, prediction_timesteps[j], :]) for j in range(batch_size)],
        axis=0,
    )  # [N, O] with N the total number of selected timesteps
    truths = np.concatenate(
        [np.atleast_2d(Y[j, prediction_timesteps[j], :]) for j in range(batch_size)],
        axis=0,
    )  # [N, O]

    if classification:
        # Compare the predicted class with the target class at every selected timestep
        pred_classes = np.argmax(preds, axis=-1)    # [N] int: class
        true_classes = np.argmax(truths, axis=-1)   # [N] int: class
        accuracy = np.sum(pred_classes == true_classes) / true_classes.size
        # Report the error rate so that lower is better, like the MSE
        return 1 - accuracy

    # Compute the MSE over every selected timestep and output dimension
    return np.mean((preds - truths) ** 2)
58
+
59
def build_task(task_name, difficulty='small', seed=None, **kwargs):
    """
    Build the task.

    Parameters:
    - task_name (str): Name of the task. It must be a key of the stream
        configuration (`evals.stream_small`), e.g. 'sinus_forecasting',
        'chaotic_forecasting', 'discrete_postcasting', 'continuous_postcasting',
        'discrete_pattern_completion', 'continuous_pattern_completion',
        'bracket_matching', 'simple_copy', 'selective_copy', 'adding_problem'
        or 'sorting_problem'.
    - difficulty (str): Difficulty level of the task ('small', 'medium' or 'large')
    - seed (int, optional): Seed for reproducibility. Default is None.

    The other optional parameters are given as arguments in the task generation
    function and override the values of the stream configuration.

    Returns:
    - Task: Whatever the task generation function returns

    Raises:
    - ValueError: If the task name or the difficulty level is unknown
    """
    # Check if the task name is valid
    if task_name not in evals.stream_small:
        raise ValueError(f"Task {task_name} not found. Available tasks are: {list(evals.stream_small.keys())}")
    # Check if the difficulty level is valid
    if difficulty not in ['small', 'medium', 'large']:
        raise ValueError("Difficulty level must be 'small', 'medium' or 'large'.")

    # Get the corresponding stream configuration
    stream = {
        'small': evals.stream_small,
        'medium': evals.stream_medium,
        'large': evals.stream_large,
    }[difficulty]
    config = stream[task_name]

    # Build a fresh parameter dict: the configured defaults, then the seed,
    # then the caller overrides. The configuration dict is shared at module
    # level, so it must not be mutated here, otherwise the seed and the
    # overrides of one call would leak into every later call.
    params = {**config['params'], 'seed': seed, **kwargs}

    # Generate the task
    return config['fct'](**params)
@@ -0,0 +1,206 @@
1
+ from stream_dataset.tasks import *
2
+
3
+ # ----- Define the set configs -----
4
# 'small' difficulty: for every task name, the generator function ('fct'),
# the keyword arguments passed to it ('params') and a 'classification' flag.
stream_small = dict(
    sinus_forecasting=dict(
        fct=generate_sinus_forecasting,
        params=dict(sequence_length=200, forecast_length=5,
                    training_ratio=0.45, validation_ratio=0.1, testing_ratio=0.45),
        classification=False,
    ),
    chaotic_forecasting=dict(
        fct=generate_chaotic_forecasting,
        params=dict(sequence_length=200, forecast_length=5,
                    training_ratio=0.45, validation_ratio=0.1, testing_ratio=0.45),
        classification=False,
    ),
    discrete_postcasting=dict(
        fct=generate_discrete_postcasting,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=50, delay=5, n_symbols=3),
        classification=True,
    ),
    continuous_postcasting=dict(
        fct=generate_continuous_postcasting,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=50, delay=5),
        classification=False,
    ),
    discrete_pattern_completion=dict(
        fct=generate_discrete_pattern_completion,
        classification=True,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=60, n_symbols=3, base_length=4, mask_ratio=0.2),
    ),
    continuous_pattern_completion=dict(
        fct=generate_continuous_pattern_completion,
        classification=False,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=60, base_length=4, mask_ratio=0.2),
    ),
    bracket_matching=dict(
        fct=generate_bracket_matching,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=50, max_depth=5),
        classification=True,
    ),
    simple_copy=dict(
        fct=generate_simple_copy,
        classification=True,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=22, delay=5, n_symbols=3),
    ),
    selective_copy=dict(
        fct=generate_selective_copy,
        classification=True,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=40, delay=5, n_markers=5, n_symbols=3),
    ),
    adding_problem=dict(
        fct=generate_adding_problem,
        classification=True,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=10, max_number=3),
    ),
    sorting_problem=dict(
        fct=generate_sorting_problem,
        classification=True,
        params=dict(n_train=100, n_valid=20, n_test=100,
                    sequence_length=10, n_symbols=3),
    ),
    cross_situation=dict(
        fct=generate_csl,
        classification=True,
        params=dict(
            n_train=100, n_valid=20, n_test=100,
            objects=['glass', 'orange'],
            colors=['blue', 'orange'],
            positions=['left', 'right'],
        ),
    ),
)
71
+
72
# 'medium' difficulty: for every task name, the generator function ('fct'),
# the keyword arguments passed to it ('params') and a 'classification' flag.
stream_medium = dict(
    sinus_forecasting=dict(
        fct=generate_sinus_forecasting,
        params=dict(sequence_length=2000, forecast_length=15,
                    training_ratio=0.45, validation_ratio=0.1, testing_ratio=0.45),
        classification=False,
    ),
    chaotic_forecasting=dict(
        fct=generate_chaotic_forecasting,
        params=dict(sequence_length=2000, forecast_length=15,
                    training_ratio=0.45, validation_ratio=0.1, testing_ratio=0.45),
        classification=False,
    ),
    discrete_postcasting=dict(
        fct=generate_discrete_postcasting,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=100, delay=15, n_symbols=8),
        classification=True,
    ),
    continuous_postcasting=dict(
        fct=generate_continuous_postcasting,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=100, delay=15),
        classification=False,
    ),
    discrete_pattern_completion=dict(
        fct=generate_discrete_pattern_completion,
        classification=True,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=150, n_symbols=8, base_length=10, mask_ratio=0.2),
    ),
    continuous_pattern_completion=dict(
        fct=generate_continuous_pattern_completion,
        classification=False,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=150, base_length=10, mask_ratio=0.2),
    ),
    bracket_matching=dict(
        fct=generate_bracket_matching,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=100, max_depth=10),
        classification=True,
    ),
    simple_copy=dict(
        fct=generate_simple_copy,
        classification=True,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=50, delay=10, n_symbols=8),
    ),
    selective_copy=dict(
        fct=generate_selective_copy,
        classification=True,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=80, delay=10, n_markers=10, n_symbols=8),
    ),
    adding_problem=dict(
        fct=generate_adding_problem,
        classification=True,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=20, max_number=8),
    ),
    sorting_problem=dict(
        fct=generate_sorting_problem,
        classification=True,
        params=dict(n_train=1000, n_valid=200, n_test=1000,
                    sequence_length=20, n_symbols=8),
    ),
    cross_situation=dict(
        fct=generate_csl,
        classification=True,
        params=dict(
            n_train=1000, n_valid=200, n_test=1000,
            objects=['glass', 'orange', 'cup', 'bowl'],
            colors=['blue', 'orange', 'green', 'red'],
            # A tuple groups several words under one position label
            positions=['left', 'right', ('center', 'middle')],
        ),
    ),
)
139
+
140
# 'large' difficulty: for every task name, the generator function ('fct'),
# the keyword arguments passed to it ('params') and a 'classification' flag.
# NOTE(review): 'discrete_postcasting' keeps n_symbols=8 while the other
# discrete tasks here use 20, and the two pattern completion tasks use
# different base_length values (25 vs 20) -- confirm these are intentional.
stream_large = dict(
    sinus_forecasting=dict(
        fct=generate_sinus_forecasting,
        params=dict(sequence_length=20000, forecast_length=50,
                    training_ratio=0.45, validation_ratio=0.1, testing_ratio=0.45),
        classification=False,
    ),
    chaotic_forecasting=dict(
        fct=generate_chaotic_forecasting,
        params=dict(sequence_length=20000, forecast_length=50,
                    training_ratio=0.45, validation_ratio=0.1, testing_ratio=0.45),
        classification=False,
    ),
    discrete_postcasting=dict(
        fct=generate_discrete_postcasting,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=300, delay=50, n_symbols=8),
        classification=True,
    ),
    continuous_postcasting=dict(
        fct=generate_continuous_postcasting,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=300, delay=50),
        classification=False,
    ),
    discrete_pattern_completion=dict(
        fct=generate_discrete_pattern_completion,
        classification=True,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=500, n_symbols=20, base_length=25, mask_ratio=0.2),
    ),
    continuous_pattern_completion=dict(
        fct=generate_continuous_pattern_completion,
        classification=False,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=500, base_length=20, mask_ratio=0.2),
    ),
    bracket_matching=dict(
        fct=generate_bracket_matching,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=300, max_depth=30),
        classification=True,
    ),
    simple_copy=dict(
        fct=generate_simple_copy,
        classification=True,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=150, delay=30, n_symbols=20),
    ),
    selective_copy=dict(
        fct=generate_selective_copy,
        classification=True,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=240, delay=30, n_markers=30, n_symbols=20),
    ),
    adding_problem=dict(
        fct=generate_adding_problem,
        classification=True,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=50, max_number=20),
    ),
    sorting_problem=dict(
        fct=generate_sorting_problem,
        classification=True,
        params=dict(n_train=10000, n_valid=2000, n_test=10000,
                    sequence_length=50, n_symbols=20),
    ),
    cross_situation=dict(
        fct=generate_csl,
        classification=True,
        params=dict(
            n_train=10000, n_valid=2000, n_test=10000,
            objects=['glass', 'orange', 'cup', 'bowl', 'plate', 'spoon'],
            colors=['blue', 'orange', 'green', 'red', 'yellow', 'purple'],
            # A tuple groups several words under one position label
            positions=['left', 'right', ('center', 'middle'), 'top', 'bottom'],
        ),
    ),
)
@@ -0,0 +1,342 @@
1
+ import numpy as np
2
+
3
# Role identifiers attached to every word of a sentence.
NO_ROLE, ACTION, OBJECT, COLOR = 0, 1, 2, 3
# Number of role identifiers defined above.
NB_ROLES = 4
8
+
9
class TwoSituationCSLDataset:
    """Dataset of sentences describing two situations joined by 'and'."""

    def __init__(self, objects, colors, positions):
        """Build every pair of one-situation sentences with roles, predicates and encodings.

        Args:
            objects (list): A list of objects.
            colors (list): A list of colors.
            positions (list): A list of positions.
        """
        # Keep the vocabulary the dataset was built from
        self.objects = objects
        self.colors = colors
        self.positions = positions

        # The one-situation dataset is the building block of every pair
        self.osb = OneSituationCSLDataset(objects, colors, positions)
        self.sentences, self.roles = self._combine_situation()
        single_predicates = self.osb.predicates
        self.predicates = [
            [first, second]
            for first in single_predicates
            for second in single_predicates
        ]

        # Inputs are one-hot encoded over the two-situation vocabulary;
        # outputs reuse the one-situation labeler, one label vector per situation
        self.input_encoder = OneHotEncoder(self.sentences)
        self.output_encoder = self.osb.output_encoder
        self.X = np.array([self.input_encoder.encode(sentence) for sentence in self.sentences])
        single_labels = self.osb.Y
        self.Y = np.array([
            np.concatenate([first, second])
            for first in single_labels
            for second in single_labels
        ])

    def _combine_situation(self):
        """Join every ordered pair of one-situation sentences with 'and'.

        Returns:
            tuple: A tuple containing a list of sentences and a list of roles.
        """
        joiner = {'and': [NO_ROLE]}
        single = dict(zip(self.osb.sentences, self.osb.roles))
        paired = create_sentences(single, joiner, single)
        pair_sentences, pair_roles = zip(*paired.items())
        return list(pair_sentences), list(pair_roles)
48
+
49
+
50
+
51
class OneSituationCSLDataset:
    """Dataset of sentences describing one situation, with roles and predicates."""

    def __init__(self, objects, colors, positions):
        """Generate the one-situation sentences, their roles, predicates and encodings.

        Args:
            objects (list): A list of objects.
            colors (list): A list of colors.
            positions (list): A list of positions.
        """
        # Keep the vocabulary the dataset was built from
        self.objects = objects
        self.colors = colors
        self.positions = positions
        self.others = ['this', 'that', 'is']

        # Generate the sentences, then one predicate per sentence
        self.sentences, self.roles = self._create_dataset()
        actions = positions + self.others
        self.predicates = [
            Predicates(sentence, roles, objects, colors, actions)
            for sentence, roles in zip(self.sentences, self.roles)
        ]

        # One-hot inputs and multi-hot label vectors
        self.input_encoder = OneHotEncoder(self.sentences)
        self.output_encoder = Labeler(objects, colors, positions, self.others)
        self.X = np.array([self.input_encoder.encode(sentence) for sentence in self.sentences])
        self.Y = np.array([
            self.output_encoder.encode(sentence, roles)
            for sentence, roles in zip(self.sentences, self.roles)
        ])

    def _create_dataset(self):
        """Generate every sentence of the grammar together with its per-word roles.

        Returns:
            tuple: A tuple containing a list of sentences and a list of roles.
        """
        # Vocabulary: each word carries a single role
        object_words = create_dict_from_labels(self.objects, [OBJECT])
        color_words = create_dict_from_labels(self.colors, [COLOR])
        position_words = create_dict_from_labels(self.positions, [ACTION])

        # Function words ('is' carries the ACTION role only in some sentence forms)
        copula_action = {'is': [ACTION]}
        copula_plain = {'is': [NO_ROLE]}
        on_the = {'on the': [NO_ROLE, NO_ROLE]}
        demonstrative = {'this is': [ACTION, NO_ROLE], 'that is': [ACTION, NO_ROLE]}
        existential = {'there is': [NO_ROLE, NO_ROLE]}
        determiner = {'a': [NO_ROLE], 'the': [NO_ROLE]}

        # Reusable phrases; the empty entry makes the color optional
        optional_color = {**color_words, '': []}
        noun_phrase = create_sentences(determiner, optional_color, object_words)  # a (color) object
        location = create_sentences(on_the, position_words)                       # on the position
        is_or_there_is = {**copula_plain, **existential}

        # Sentence families, merged in this order (a later family overrides
        # the roles of an identical sentence produced by an earlier one)
        families = (
            create_sentences(demonstrative, noun_phrase),                                # this is a (color) object
            create_sentences(determiner, object_words, copula_action, color_words),      # an object is color
            create_sentences(determiner, object_words, location, copula_plain, color_words),  # an object on the position is color
            create_sentences(noun_phrase, copula_plain, location),                       # a (color) object is on the position
            create_sentences(existential, noun_phrase, location),                        # there is a (color) object on the position
            create_sentences(location, is_or_there_is, noun_phrase),                     # on the position (there) is a (color) object
        )
        corpus = {}
        for family in families:
            corpus.update(family)

        # Split the mapping into parallel lists
        all_sentences, all_roles = zip(*corpus.items())
        return list(all_sentences), list(all_roles)
113
+
114
+
115
class OneHotEncoder:
    """Class to encode sentences into sequences of one hot vectors."""

    def __init__(self, sentences):
        """Create a one hot encoder from a list of sentences.

        Args:
            sentences (list): A list of sentences (whitespace separated words).

        Raises:
            ValueError: If `sentences` is empty.
        """
        # Sort the vocabulary: iterating a set of strings follows their hash
        # order, which changes between interpreter runs, so an unsorted
        # vocabulary would give every run a different word -> index mapping.
        self.words = sorted(set(' '.join(sentences).split()))
        self.word2idx = {w: i for i, w in enumerate(self.words)}
        self.vocab_size = len(self.words)
        self.max_length = max(len(s.split()) for s in sentences)

    def encode(self, sentence):
        """Encode a sentence into a matrix of one hot vectors.

        Args:
            sentence (str): The sentence to encode.

        Returns:
            np.array: Matrix of shape (max_length, vocab_size) with one one hot
                row per word, left-padded with zero rows.

        Raises:
            KeyError: If a word is not part of the vocabulary.
        """
        # Create matrix and fill with one hot vectors
        words = sentence.split()
        matrix = np.zeros((len(words), self.vocab_size))
        for i, word in enumerate(words):
            matrix[i, self.word2idx[word]] = 1

        # Pad the start of the sequence with zero rows up to max_length
        matrix = np.pad(matrix, ((self.max_length - len(words), 0), (0, 0)))

        return matrix
148
+
149
+
150
class Labeler:
    """Class to encode sentences into label vectors and decode them back."""

    def __init__(self, objects, colors, positions, others):
        """Create a labeler over the objects, colors, positions and other words.

        Args:
            objects (list): A list of objects.
            colors (list): A list of colors.
            positions (list): A list of positions.
            others (list): A list of other words.
        """
        # One label slot per entry, laid out as objects | colors | positions | others
        self.labels = objects + colors + positions + others

        # A tuple entry is displayed with its first element
        self.idx2label = {}
        for index, entry in enumerate(self.labels):
            self.idx2label[index] = entry if isinstance(entry, str) else entry[0]

        # Word -> slot lookups, one per role, offset to match the layout above
        color_offset = len(objects)
        action_offset = color_offset + len(colors)
        self.objects2idx = create_dict_from_labels(objects)
        self.colors2idx = create_dict_from_labels(colors, base_index=color_offset)
        self.actions2idx = create_dict_from_labels(positions + others, base_index=action_offset)

    def encode(self, sentence, roles):
        """Encode a sentence of one situation and its roles into a labels vector.

        Args:
            sentence (str): The sentence of one situation to encode.
            roles (list): The roles of the sentence, one per word.

        Returns:
            np.array: Vector with a 1 at the slot of every word that has a role.
        """
        lookup_by_role = {
            OBJECT: self.objects2idx,
            COLOR: self.colors2idx,
            ACTION: self.actions2idx,
        }
        label = np.zeros(len(self.labels))
        for position, word in enumerate(sentence.split()):
            lookup = lookup_by_role.get(roles[position])
            # Words with any other role (e.g. NO_ROLE) do not set a slot
            if lookup is not None:
                label[lookup[word]] = 1
        return label

    def decode(self, label):
        """Decode a labels vector from one or several situations into corresponding words.

        Args:
            label (np.array): The label to decode.

        Returns:
            list: One space-joined string of active labels per situation, or the
                string 'Invalid label shape' when the length of `label` is not
                a multiple of the number of labels.
        """
        width = len(self.labels)
        if label.shape[0] % width != 0:
            return 'Invalid label shape'

        rows = label.reshape(label.shape[0] // width, width)
        decoded = []
        for row in rows:
            active = np.where(row == 1)[0]
            decoded.append(' '.join([self.idx2label[i] for i in active]))
        return decoded
203
+
204
+
205
class Predicates:
    """Class to create a predicate from a sentence and its roles."""

    def __init__(self, sentence, roles, objects, colors, actions):
        """Create a predicate from a sentence and its roles.

        The predicate is marked invalid (`is_invalid` is True) when a role
        appears more than once or when the action or the object is missing.

        Args:
            sentence (str): The sentence to create the predicate from.
            roles (list): The roles of the sentence, one per word.
            objects (list): A list of objects.
            colors (list): A list of colors.
            actions (list): A list of action words.
        """
        # Define every attribute up front so that an invalid predicate still
        # exposes action/object/color (as None) instead of raising AttributeError.
        self.action = None
        self.object = None
        self.color = None
        self.is_invalid = False

        # Word -> canonical word lookups (a tuple maps to its first element)
        lookup_by_role = {
            ACTION: create_dict_from_labels(actions, value='first'),
            OBJECT: create_dict_from_labels(objects, value='first'),
            COLOR: create_dict_from_labels(colors, value='first'),
        }

        # Split on whitespace, like the encoders do, so that words stay
        # aligned with their roles
        words = sentence.split()
        found_roles = {ACTION: None, OBJECT: None, COLOR: None}

        for i, role in enumerate(roles):
            if role == NO_ROLE:
                continue
            # Each role may be filled at most once
            if found_roles[role] is not None:
                self.is_invalid = True
                return
            found_roles[role] = lookup_by_role[role][words[i]]

        # The action and the object are mandatory, the color is optional
        if found_roles[ACTION] is None or found_roles[OBJECT] is None:
            self.is_invalid = True
            return

        # Set the roles
        self.action = found_roles[ACTION]
        self.object = found_roles[OBJECT]
        self.color = found_roles[COLOR]

    def __str__(self):
        """Return the predicate as 'action(object)' or 'action(object, color)'."""
        if self.is_invalid:
            return 'INVALID'
        if self.color is None:
            return f"{self.action}({self.object})"
        return f"{self.action}({self.object}, {self.color})"

    def __repr__(self):
        """Return the predicate as a string."""
        return self.__str__()
267
+
268
+
269
+
270
def create_sentences(*grammars):
    """Combine grammars left to right into every possible sentence.

    Each grammar maps a phrase to its list of roles. Phrases are joined with
    a space (surrounding whitespace is stripped, so an empty phrase is simply
    skipped) and their role lists are concatenated.

    Args:
        grammars (list): The grammars to chain, in sentence order.

    Returns:
        dict: A dictionary mapping each sentence to its roles. With no grammar
            an empty dictionary is returned; with a single grammar that same
            dictionary is returned.
    """
    if len(grammars) == 0:
        return {}

    combined, *remaining = grammars
    for extension in remaining:
        combined = {
            f"{prefix} {suffix}".strip(): prefix_roles + suffix_roles
            for prefix, prefix_roles in combined.items()
            for suffix, suffix_roles in extension.items()
        }
    return combined
293
+
294
def create_dict_from_labels(labels, value=None, base_index=0):
    """Create a dictionary from a list of labels.

    Args:
        labels (list): A list of labels, can contain tuples. Each element of a tuple becomes a key and they all share the same value.
        value : If None, the value is the position of the label in the list plus `base_index`. If 'first', the value is the label itself (the first element for a tuple). Otherwise the given value is used for every key.
        base_index (int): The base index to start from.

    Returns:
        dict: A dictionary with labels as keys and value or index as values.
    """
    # Only the exact string 'first' selects the "first element" mode; checking
    # the type first keeps array-like values from being compared element-wise.
    use_first = isinstance(value, str) and value == 'first'

    mapping = {}
    for i, item in enumerate(labels):
        # Treat a plain label as a group of one so both cases share one code path
        keys = item if isinstance(item, tuple) else (item,)
        for key in keys:
            if value is None:
                mapping[key] = base_index + i
            elif use_first:
                mapping[key] = keys[0]
            else:
                mapping[key] = value
    return mapping
323
+
324
+
325
+
326
+ # Example usage
327
if __name__ == '__main__':
    # Build a demo dataset; the tuple groups two words under one position label
    objects = ['glass', 'orange', 'cup', 'bowl']
    colors = ['blue', 'orange', 'green', 'red']
    positions = ['left', 'right', ('center', 'middle')]
    dataset = TwoSituationCSLDataset(objects=objects, colors=colors, positions=positions)

    # Report the shapes of the encoded inputs and labels
    for axis_name, array in (('X', dataset.X), ('Y', dataset.Y)):
        print(f'Shape of {axis_name}: {array.shape}')
    print()

    # Show a few sentences, drawn without replacement, with their predicates
    random_indices = np.random.choice(len(dataset.sentences), 5, replace=False)
    print("5 random sentences:")
    for i in random_indices:
        print(dataset.sentences[i])
        print(dataset.predicates[i])
        print()