stream-dataset 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stream_dataset/__init__.py +98 -0
- stream_dataset/evals.py +206 -0
- stream_dataset/libs/CSL.py +342 -0
- stream_dataset/tasks.py +679 -0
- stream_dataset-0.1.0.dist-info/METADATA +302 -0
- stream_dataset-0.1.0.dist-info/RECORD +8 -0
- stream_dataset-0.1.0.dist-info/WHEEL +4 -0
- stream_dataset-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import stream_dataset.evals as evals
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
# Names of all tasks accepted by build_task, in configuration order.
# Derived from the benchmark configuration (the keys of evals.stream_small, which is
# what build_task validates task names against) so that this public list cannot
# advertise a task build_task rejects, or miss one it accepts.
tasks = list(evals.stream_small)
|
|
18
|
+
|
|
19
|
+
def compute_score(Y, Y_hat, prediction_timesteps, classification):
    """
    Compute the error of the model on the prediction timesteps (lower is better).

    Parameters:
    - Y (np.ndarray): Target array [B, T, O]
    - Y_hat (np.ndarray): Predicted array [B, T, O]
    - prediction_timesteps (list): For each sample j, the timesteps of sample j to score
      (an int, a list/array of indices or a slice). Samples may score different numbers of timesteps.
    - classification (bool): Whether the task is a classification task -> error rate or MSE

    Returns:
    - score (float): Error rate (1 - accuracy, comparing argmax classes) if classification,
      otherwise the mean squared error over all scored timesteps and outputs.

    Raises:
    - ValueError: If Y is an empty batch (nothing to score).
    """
    # Make sure Y_hat and Y are numpy arrays
    if not isinstance(Y_hat, np.ndarray) or not isinstance(Y, np.ndarray):
        Y = np.array(Y, dtype=np.float32)
        Y_hat = np.array(Y_hat, dtype=np.float32)

    # Select only the prediction timesteps, as rows of shape [n_scored_j, O].
    # reshape(-1, O) also handles a single integer index per sample, and concatenating
    # (instead of stacking) allows per-sample index lists of different lengths.
    preds = [Y_hat[j, prediction_timesteps[j], :].reshape(-1, Y_hat.shape[-1]) for j in range(Y.shape[0])]
    truths = [Y[j, prediction_timesteps[j], :].reshape(-1, Y.shape[-1]) for j in range(Y.shape[0])]
    if not preds:
        raise ValueError("compute_score needs at least one sample to score.")
    preds = np.concatenate(preds, axis=0)    # [total scored timesteps, O]
    truths = np.concatenate(truths, axis=0)  # [total scored timesteps, O]

    if classification:
        # Error rate: fraction of scored timesteps whose predicted class differs from the target class
        accuracy = np.mean(np.argmax(preds, axis=-1) == np.argmax(truths, axis=-1))
        score = 1 - accuracy
    else:
        # Compute the MSE
        score = np.mean((preds - truths) ** 2)

    return score
|
|
58
|
+
|
|
59
|
+
def build_task(task_name, difficulty='small', seed=None, **kwargs):
    """
    Build the task.

    Parameters:
    - task_name (str): Name of the task; must be a key of the evals stream configuration:
      'sinus_forecasting', 'chaotic_forecasting', 'discrete_postcasting', 'continuous_postcasting',
      'discrete_pattern_completion', 'continuous_pattern_completion', 'bracket_matching', 'simple_copy',
      'selective_copy', 'adding_problem', 'sorting_problem' or 'cross_situation'.
    - difficulty (str): Difficulty level of the task ('small', 'medium' or 'large')
    - seed (int, optional): Seed for reproducibility. Default is None.

    The other optional parameters are given as arguments in the task generation function; they
    override the configured parameters for this call only.

    Returns:
    - Task: Task object, as returned by the configured task generation function.

    Raises:
    - ValueError: If the task name or the difficulty level is unknown.
    """
    # Check if the task name is valid
    if task_name not in evals.stream_small:
        raise ValueError(f"Task {task_name} not found. Available tasks are: {list(evals.stream_small.keys())}")
    # Check if the difficulty level is valid
    if difficulty not in ['small', 'medium', 'large']:
        raise ValueError("Difficulty level must be 'small', 'medium' or 'large'.")

    # Get the corresponding stream configuration
    stream = {
        'small': evals.stream_small,
        'medium': evals.stream_medium,
        'large': evals.stream_large,
    }[difficulty]
    config = stream[task_name]

    # Build a fresh keyword dict rather than updating config['params'] in place: the config
    # dicts are shared module-level state, and mutating them would leak this call's seed and
    # overrides into every later build_task call.
    params = {**config['params'], 'seed': seed, **kwargs}

    # Generate the task
    return config['fct'](**params)
|
stream_dataset/evals.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from stream_dataset.tasks import *
|
|
2
|
+
|
|
3
|
+
# ----- Define the set configs -----
|
|
4
|
+
# Benchmark configuration for difficulty 'small' (selected by build_task(difficulty='small');
# build_task also validates task names against the keys of this dict).
# Each entry maps a task name to:
#   'fct'            - task generation function (star-imported from stream_dataset.tasks)
#   'params'         - keyword arguments for 'fct'; build_task adds 'seed' and the caller's kwargs
#   'classification' - not read by build_task; presumably the `classification` flag for
#                      compute_score (True -> error rate, False -> MSE) - verify against the caller
stream_small = {
    'sinus_forecasting': {
        'fct': generate_sinus_forecasting,
        'params': {"sequence_length": 200, "forecast_length": 5, "training_ratio": 0.45, "validation_ratio": 0.1, "testing_ratio": 0.45},
        'classification': False,
    },
    'chaotic_forecasting': {
        'fct': generate_chaotic_forecasting,
        'params': {"sequence_length": 200, "forecast_length": 5, "training_ratio": 0.45, "validation_ratio": 0.1, "testing_ratio": 0.45},
        'classification': False,
    },
    'discrete_postcasting': {
        'fct': generate_discrete_postcasting,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 50, "delay": 5, "n_symbols": 3},
        'classification': True,
    },
    'continuous_postcasting': {
        'fct': generate_continuous_postcasting,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 50, "delay": 5},
        'classification': False,
    },
    'discrete_pattern_completion': {
        'fct': generate_discrete_pattern_completion,
        'classification': True,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 60, "n_symbols": 3, "base_length": 4, "mask_ratio": 0.2},
    },
    'continuous_pattern_completion': {
        'fct': generate_continuous_pattern_completion,
        'classification': False,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 60, "base_length": 4, "mask_ratio": 0.2},
    },
    'bracket_matching': {
        'fct': generate_bracket_matching,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 50, "max_depth": 5},
        'classification': True,
    },
    'simple_copy': {
        'fct': generate_simple_copy,
        'classification': True,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 22, "delay": 5, "n_symbols": 3},
    },
    'selective_copy': {
        'fct': generate_selective_copy,
        'classification': True,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 40, "delay": 5, "n_markers": 5, "n_symbols": 3},
    },
    'adding_problem': {
        'fct': generate_adding_problem,
        'classification': True,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 10, "max_number": 3},
    },
    'sorting_problem': {
        'fct': generate_sorting_problem,
        'classification': True,
        'params': {"n_train": 100, "n_valid": 20, "n_test": 100, "sequence_length": 10, "n_symbols": 3},
    },
    'cross_situation': {
        'fct': generate_csl,
        'classification': True,
        'params': {
            "n_train": 100, "n_valid": 20, "n_test": 100,
            # NOTE(review): 'orange' is both an object and a color; presumably an intentional
            # ambiguity resolved by word role - confirm.
            "objects": ['glass', 'orange'],
            "colors": ['blue', 'orange'],
            "positions": ['left', 'right']
        },
    },
}
|
|
71
|
+
|
|
72
|
+
# Benchmark configuration for difficulty 'medium' (selected by build_task(difficulty='medium')).
# Same keys and entry layout as stream_small ('fct', 'params', 'classification'), with
# larger datasets, longer sequences and more symbols.
stream_medium = {
    'sinus_forecasting': {
        'fct': generate_sinus_forecasting,
        'params': {"sequence_length": 2000, "forecast_length": 15, "training_ratio": 0.45, "validation_ratio": 0.1, "testing_ratio": 0.45},
        'classification': False,
    },
    'chaotic_forecasting': {
        'fct': generate_chaotic_forecasting,
        'params': {"sequence_length": 2000, "forecast_length": 15, "training_ratio": 0.45, "validation_ratio": 0.1, "testing_ratio": 0.45},
        'classification': False,
    },
    'discrete_postcasting': {
        'fct': generate_discrete_postcasting,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 100, "delay": 15, "n_symbols": 8},
        'classification': True,
    },
    'continuous_postcasting': {
        'fct': generate_continuous_postcasting,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 100, "delay": 15},
        'classification': False,
    },
    'discrete_pattern_completion': {
        'fct': generate_discrete_pattern_completion,
        'classification': True,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 150, "n_symbols": 8, "base_length": 10, "mask_ratio": 0.2},
    },
    'continuous_pattern_completion': {
        'fct': generate_continuous_pattern_completion,
        'classification': False,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 150, "base_length": 10, "mask_ratio": 0.2},
    },
    'bracket_matching': {
        'fct': generate_bracket_matching,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 100, "max_depth": 10},
        'classification': True,
    },
    'simple_copy': {
        'fct': generate_simple_copy,
        'classification': True,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 50, "delay": 10, "n_symbols": 8},
    },
    'selective_copy': {
        'fct': generate_selective_copy,
        'classification': True,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 80, "delay": 10, "n_markers": 10, "n_symbols": 8},
    },
    'adding_problem': {
        'fct': generate_adding_problem,
        'classification': True,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 20, "max_number": 8},
    },
    'sorting_problem': {
        'fct': generate_sorting_problem,
        'classification': True,
        'params': {"n_train": 1000, "n_valid": 200, "n_test": 1000, "sequence_length": 20, "n_symbols": 8},
    },
    'cross_situation': {
        'fct': generate_csl,
        'classification': True,
        'params': {
            "n_train": 1000, "n_valid": 200, "n_test": 1000,
            "objects": ['glass', 'orange', 'cup', 'bowl'],
            "colors": ['blue', 'orange', 'green', 'red'],
            # A tuple lists synonyms that share one label (see create_dict_from_labels in libs/CSL.py)
            "positions": ['left', 'right', ('center', 'middle')]
        },
    },
}
|
|
139
|
+
|
|
140
|
+
# Benchmark configuration for difficulty 'large' (selected by build_task(difficulty='large')).
# Same keys and entry layout as stream_small ('fct', 'params', 'classification').
stream_large = {
    'sinus_forecasting': {
        'fct': generate_sinus_forecasting,
        'params': {"sequence_length": 20000, "forecast_length": 50, "training_ratio": 0.45, "validation_ratio": 0.1, "testing_ratio": 0.45},
        'classification': False,
    },
    'chaotic_forecasting': {
        'fct': generate_chaotic_forecasting,
        'params': {"sequence_length": 20000, "forecast_length": 50, "training_ratio": 0.45, "validation_ratio": 0.1, "testing_ratio": 0.45},
        'classification': False,
    },
    'discrete_postcasting': {
        'fct': generate_discrete_postcasting,
        # NOTE(review): n_symbols stays at 8 here while the other large discrete tasks use 20 - confirm intended.
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 300, "delay": 50, "n_symbols": 8},
        'classification': True,
    },
    'continuous_postcasting': {
        'fct': generate_continuous_postcasting,
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 300, "delay": 50},
        'classification': False,
    },
    'discrete_pattern_completion': {
        'fct': generate_discrete_pattern_completion,
        'classification': True,
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 500, "n_symbols": 20, "base_length": 25, "mask_ratio": 0.2},
    },
    'continuous_pattern_completion': {
        'fct': generate_continuous_pattern_completion,
        'classification': False,
        # NOTE(review): base_length is 20 here but 25 for the discrete variant (they match in small/medium) - confirm intended.
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 500, "base_length": 20, "mask_ratio": 0.2},
    },
    'bracket_matching': {
        'fct': generate_bracket_matching,
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 300, "max_depth": 30},
        'classification': True,
    },
    'simple_copy': {
        'fct': generate_simple_copy,
        'classification': True,
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 150, "delay": 30, "n_symbols": 20},
    },
    'selective_copy': {
        'fct': generate_selective_copy,
        'classification': True,
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 240, "delay": 30, "n_markers": 30, "n_symbols": 20},
    },
    'adding_problem': {
        'fct': generate_adding_problem,
        'classification': True,
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 50, "max_number": 20},
    },
    'sorting_problem': {
        'fct': generate_sorting_problem,
        'classification': True,
        'params': {"n_train": 10000, "n_valid": 2000, "n_test": 10000, "sequence_length": 50, "n_symbols": 20},
    },
    'cross_situation': {
        'fct': generate_csl,
        'classification': True,
        'params': {
            "n_train": 10000, "n_valid": 2000, "n_test": 10000,
            "objects": ['glass', 'orange', 'cup', 'bowl', 'plate', 'spoon'],
            "colors": ['blue', 'orange', 'green', 'red', 'yellow', 'purple'],
            # A tuple lists synonyms that share one label (see create_dict_from_labels in libs/CSL.py)
            "positions": ['left', 'right', ('center', 'middle'), 'top', 'bottom']
        },
    },
}
|
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
# Role identifiers attached to each word of a generated sentence:
# NO_ROLE for filler words, ACTION for the word that names the predicate,
# OBJECT and COLOR for its arguments.
NO_ROLE, ACTION, OBJECT, COLOR = range(4)
# Number of distinct role identifiers defined above.
NB_ROLES = 4
|
|
8
|
+
|
|
9
|
+
class TwoSituationCSLDataset:
    """Dataset of sentences that join two one-situation sentences with 'and'.

    Every pair (s1, s2) of one-situation sentences, including s1 == s2, yields
    the sentence "s1 and s2". The sentences, roles, predicates, X and Y attributes
    are built in the same pair order (first sentence outer, second inner).
    """

    def __init__(self, objects, colors, positions):
        """Build the two-situation dataset from the given vocabulary.

        Args:
            objects (list): A list of objects.
            colors (list): A list of colors.
            positions (list): A list of positions.
        """
        self.objects = objects
        self.colors = colors
        self.positions = positions

        # Underlying one-situation dataset that every two-situation sample is built from.
        self.osb = OneSituationCSLDataset(objects, colors, positions)
        single = self.osb
        self.sentences, self.roles = self._combine_situation()
        self.predicates = [
            [first, second]
            for first in single.predicates
            for second in single.predicates
        ]

        # Inputs: one hot word sequences over the combined vocabulary.
        # Labels: the two one-situation label vectors placed end to end.
        self.input_encoder = OneHotEncoder(self.sentences)
        self.output_encoder = single.output_encoder
        self.X = np.array([self.input_encoder.encode(text) for text in self.sentences])
        self.Y = np.array([
            np.concatenate([label_a, label_b])
            for label_a in single.Y
            for label_b in single.Y
        ])

    def _combine_situation(self):
        """Join every ordered pair of one-situation sentences with 'and'.

        Returns:
            tuple: (list of sentences, list of per-word role lists).
        """
        joiner = {'and': [NO_ROLE]}
        single = dict(zip(self.osb.sentences, self.osb.roles))
        pairs = create_sentences(single, joiner, single)
        sentences, roles = zip(*pairs.items())
        return list(sentences), list(roles)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class OneSituationCSLDataset:
    """Class to create a dataset of one-situation sentences, their roles and predicates.

    After construction, sentences, roles, predicates, X and Y are aligned: entry i of
    each describes the i-th generated sentence. X stacks OneHotEncoder.encode outputs
    and Y stacks Labeler.encode outputs.
    """

    def __init__(self, objects, colors, positions):
        """Create a dataset of sentences of one situation, their roles and predicates.

        Args:
            objects (list): A list of objects (tuples are synonyms, see create_dict_from_labels).
            colors (list): A list of colors.
            positions (list): A list of positions. Must be a list: it is concatenated
                with the list self.others below.
        """
        # Define objects, colors and positions
        self.objects = objects
        self.colors = colors
        self.positions = positions
        # Extra words that can fill the ACTION role ("this is ...", "that is ...", "... is <color>")
        self.others = ['this', 'that', 'is']

        # Create dataset
        self.sentences, self.roles = self._create_dataset()
        # Positions plus the extra words above form the vocabulary of predicate names (actions)
        self.predicates = [Predicates(s, r, objects, colors, positions+self.others) for s, r in zip(self.sentences, self.roles)]

        # Create one hot vectors & labels
        self.input_encoder = OneHotEncoder(self.sentences)
        self.output_encoder = Labeler(objects, colors, positions, self.others)
        self.X = np.array([self.input_encoder.encode(s) for s in self.sentences])
        self.Y = np.array([self.output_encoder.encode(s, r) for s, r in zip(self.sentences, self.roles)])


    def _create_dataset(self):
        """Create a dataset of sentences and their roles.

        Sentences are produced by composing small grammars (dicts mapping a word
        sequence to its per-word roles) with create_sentences. The dict order fixes
        the order of the returned sentences. If two patterns produce the same sentence,
        the later pattern's roles win (dict merge order).

        Returns:
            tuple: A tuple containing a list of sentences and a list of roles.
        """
        # Objects, Colors, Positions
        obj = create_dict_from_labels(self.objects, [OBJECT])
        col = create_dict_from_labels(self.colors, [COLOR])
        # The position word carries the ACTION role: it becomes the predicate name
        pos = create_dict_from_labels(self.positions, [ACTION])

        # Complementary words
        is_action = {'is': [ACTION]}
        is_norole = {'is': [NO_ROLE]}
        to_the = {'on the': [NO_ROLE, NO_ROLE]}
        this_is = {'this is': [ACTION, NO_ROLE], 'that is': [ACTION, NO_ROLE]}
        there_is = {'there is': [NO_ROLE, NO_ROLE]}
        det = {'a': [NO_ROLE], 'the': [NO_ROLE]}

        # Create sentences with corresponding roles
        # The '' entry makes the color optional ("a blue glass" and "a glass")
        a_color_object = create_sentences(det, {**col, '': []}, obj) # A/The (color) object
        to_the_position = create_sentences(to_the, pos) # On the position
        one_situation = {
            **create_sentences(this_is, a_color_object), # This/That is a_color_object
            **create_sentences(det, obj, is_action, col), # An object is a color
            **create_sentences(det, obj, to_the_position, is_norole, col), # An object on the position is a color
            **create_sentences(a_color_object, is_norole, to_the_position), # a_color_object is on the position
            **create_sentences(there_is, a_color_object, to_the_position), # There is a_color_object on the position
            **create_sentences(to_the_position, {**is_norole, **there_is}, a_color_object) # On the position is / there is a_color_object
        }

        # Return sentences and roles
        sentences, roles = zip(*one_situation.items())
        return list(sentences), list(roles)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class OneHotEncoder:
    """Class to encode sentences as start-padded sequences of one hot vectors."""

    def __init__(self, sentences):
        """Create a one hot encoder from a list of sentences.

        Args:
            sentences (list): A list of sentences; words are separated by whitespace.
        """
        # Sorted so word indices are reproducible across runs: iterating a set of
        # strings follows str hashes, which Python randomises per process.
        self.words = sorted(set(' '.join(sentences).split()))
        self.word2idx = {w: i for i, w in enumerate(self.words)}
        self.vocab_size = len(self.words)
        # default=0 lets an empty sentence list build an empty encoder instead of raising
        self.max_length = max((len(s.split()) for s in sentences), default=0)

    def encode(self, sentence):
        """Encode a sentence into a matrix of one hot vectors.

        Args:
            sentence (str): The sentence to encode. Every word must be in the vocabulary.

        Returns:
            np.array: Matrix of shape (max_length, vocab_size). The last len(words) rows
            are the one hot vectors of the sentence's words, in order. Any earlier rows are
            zero padding.

        Raises:
            KeyError: If a word is not in the vocabulary.
            ValueError: If the sentence has more than max_length words.
        """
        # Create matrix and fill with one hot vectors
        words = sentence.split()
        matrix = np.zeros((len(words), self.vocab_size))
        for i, word in enumerate(words):
            matrix[i, self.word2idx[word]] = 1

        # Pad start of sequence with zeros so every encoded sentence has max_length rows
        matrix = np.pad(matrix, ((self.max_length - len(words), 0), (0, 0)))

        return matrix
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class Labeler:
    """Multi-hot encoder/decoder for the objects, colors and actions named in a sentence."""

    def __init__(self, objects, colors, positions, others):
        """Build the label vocabulary.

        The label vector is laid out as [objects | colors | positions + others].
        An entry given as a tuple is a group of synonyms sharing one slot.

        Args:
            objects (list): A list of objects.
            colors (list): A list of colors.
            positions (list): A list of positions.
            others (list): A list of other words that can fill the action role.
        """
        self.labels = objects + colors + positions + others

        # Slot index -> display word; a tuple of synonyms is shown by its first spelling.
        self.idx2label = {}
        for slot, entry in enumerate(self.labels):
            self.idx2label[slot] = entry if isinstance(entry, str) else entry[0]

        colors_start = len(objects)
        actions_start = colors_start + len(colors)
        self.objects2idx = create_dict_from_labels(objects)
        self.colors2idx = create_dict_from_labels(colors, base_index=colors_start)
        self.actions2idx = create_dict_from_labels(positions+others, base_index=actions_start)

    def encode(self, sentence, roles):
        """Turn one situation's sentence and roles into a multi-hot label vector.

        Args:
            sentence (str): The sentence of one situation to encode.
            roles (list): The role of each word of the sentence.

        Returns:
            np.array: Vector of len(self.labels) with a 1 at the slot of every word
            whose role is OBJECT, COLOR or ACTION, and 0 elsewhere.
        """
        label = np.zeros(len(self.labels))
        for k, word in enumerate(sentence.split()):
            role = roles[k]
            if role == OBJECT:
                table = self.objects2idx
            elif role == COLOR:
                table = self.colors2idx
            elif role == ACTION:
                table = self.actions2idx
            else:
                continue
            label[table[word]] = 1
        return label

    def decode(self, label):
        """Turn a label vector covering one or more situations back into words.

        Args:
            label (np.array): Concatenation of one label vector per situation.

        Returns:
            list: One space-separated string of active words per situation, or the
            string 'Invalid label shape' when the length is not a multiple of the
            number of labels.
        """
        n_labels = len(self.labels)
        if label.shape[0] % n_labels != 0:
            return 'Invalid label shape'

        per_situation = label.reshape(label.shape[0] // n_labels, n_labels)
        decoded = []
        for row in per_situation:
            active_slots = np.where(row == 1)[0]
            decoded.append(' '.join([self.idx2label[slot] for slot in active_slots]))
        return decoded
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class Predicates:
    """Class to create a predicate such as ``left(glass, blue)`` from a sentence and its roles."""

    def __init__(self, sentence, roles, objects, colors, actions):
        """Create a predicate from a sentence and its roles.

        The ACTION word gives the predicate name, the OBJECT word its first argument and
        the optional COLOR word its second argument. Synonyms given as tuples are
        normalised to the first spelling of their tuple.

        The predicate is marked invalid (``is_invalid = True``) when a role occurs more
        than once, or when the ACTION or OBJECT role is missing. In that case
        ``action``, ``object`` and ``color`` are None.

        Args:
            sentence (str): The sentence to create the predicate from (words separated by single spaces).
            roles (list): The role of each word (NO_ROLE, ACTION, OBJECT or COLOR).
            objects (list): A list of objects.
            colors (list): A list of colors.
            actions (list): A list of words that can fill the ACTION role.
        """
        # Define every attribute up front so invalid predicates can still be inspected
        # (previously action/object/color were left undefined and raised AttributeError).
        self.action = None
        self.object = None
        self.color = None
        self.is_invalid = False

        # Define objects, colors and actions predicates (surface word -> canonical spelling)
        obj_pred = create_dict_from_labels(objects, value='first')
        col_pred = create_dict_from_labels(colors, value='first')
        act_pred = create_dict_from_labels(actions, value='first')

        # Split sentence & prepare role list
        words = sentence.split(' ')
        found_roles = {x: None for x in [ACTION, OBJECT, COLOR]}

        # Check if the sentence is valid
        for i, role in enumerate(roles):
            if role == NO_ROLE:
                continue
            # Each role may be filled only once in a single situation
            if found_roles[role] is not None:
                self.is_invalid = True
                return

            # Set the role
            if role == ACTION:
                found_roles[role] = act_pred[words[i]]
            elif role == OBJECT:
                found_roles[role] = obj_pred[words[i]]
            elif role == COLOR:
                found_roles[role] = col_pred[words[i]]

        # Check if all mandatory roles are present (COLOR is optional)
        if found_roles[ACTION] is None or found_roles[OBJECT] is None:
            self.is_invalid = True
            return

        # Set the roles
        self.action = found_roles[ACTION]
        self.object = found_roles[OBJECT]
        self.color = found_roles[COLOR]

    def __str__(self):
        """Return the predicate as a string, or 'INVALID' for an invalid predicate."""
        if self.is_invalid:
            return 'INVALID'
        if self.color is None:
            return self.action + '(' + self.object + ')'
        return self.action + '(' + self.object + ', ' + self.color + ')'

    def __repr__(self):
        """Return the predicate as a string."""
        return self.__str__()
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def create_sentences(*grammars):
    """Expand a sequence of grammars into every sentence they generate.

    A grammar maps a word sequence to the list of roles of its words. The result
    contains the concatenation of one entry from each grammar, in order, with their
    role lists concatenated too. Empty entries ('') add no word. When two
    combinations yield the same sentence, the later one's roles are kept.

    Args:
        grammars (list): A list of grammars to create sentences from.

    Returns:
        dict: A dictionary of sentences (the first grammar itself when only one is given).
    """
    if len(grammars) == 0:
        return {}

    combined, *remaining = grammars
    for grammar in remaining:
        combined = {
            f"{prefix} {suffix}".strip(): prefix_roles + suffix_roles
            for prefix, prefix_roles in combined.items()
            for suffix, suffix_roles in grammar.items()
        }
    return combined
|
|
293
|
+
|
|
294
|
+
def create_dict_from_labels(labels, value=None, base_index=0):
    """Create a dictionary from a list of labels.

    Args:
        labels (list): A list of labels, can contain tuples. Each value in the tuple will be assigned the same value.
        value : If None, the value will be the index in the list (plus base_index). If 'first', the value will be
            the first element of the tuple (or the label itself for a plain label). Otherwise the value will be
            the one given (the same object is shared by every key).
        base_index (int): The base index to start from.

    Returns:
        dict: A dictionary with labels as keys and value or index as values.
    """
    # Compare to 'first' only for strings, so array-like values never reach an
    # element-wise (ambiguous-truth) comparison.
    use_first = isinstance(value, str) and value == 'first'

    mapping = {}
    for i, item in enumerate(labels):
        # A plain label behaves like a one-element tuple of synonyms
        keys = item if isinstance(item, tuple) else (item,)
        for key in keys:
            if value is None:
                mapping[key] = base_index + i
            elif use_first:
                mapping[key] = keys[0]
            else:
                mapping[key] = value
    return mapping
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
# Example usage: build a two-situation dataset and print a few random samples.
if __name__ == '__main__':
    dataset = TwoSituationCSLDataset(
        objects=['glass', 'orange', 'cup', 'bowl'],
        colors=['blue', 'orange', 'green', 'red'],
        positions=['left', 'right', ('center', 'middle')],
    )
    print(f'Shape of X: {dataset.X.shape}')
    print(f'Shape of Y: {dataset.Y.shape}')
    print()

    sample_ids = np.random.choice(len(dataset.sentences), 5, replace=False)
    print("5 random sentences:")
    for idx in sample_ids:
        print(dataset.sentences[idx])
        print(dataset.predicates[idx])
        print()
|