job-shop-lib 0.5.1__py3-none-any.whl → 1.0.0a1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- job_shop_lib/__init__.py +16 -8
- job_shop_lib/{base_solver.py → _base_solver.py} +1 -1
- job_shop_lib/{job_shop_instance.py → _job_shop_instance.py} +9 -4
- job_shop_lib/_operation.py +95 -0
- job_shop_lib/{schedule.py → _schedule.py} +73 -54
- job_shop_lib/{scheduled_operation.py → _scheduled_operation.py} +13 -37
- job_shop_lib/benchmarking/__init__.py +66 -43
- job_shop_lib/benchmarking/_load_benchmark.py +88 -0
- job_shop_lib/constraint_programming/__init__.py +13 -0
- job_shop_lib/{cp_sat/ortools_solver.py → constraint_programming/_ortools_solver.py} +57 -18
- job_shop_lib/dispatching/__init__.py +45 -41
- job_shop_lib/dispatching/{dispatcher.py → _dispatcher.py} +153 -80
- job_shop_lib/dispatching/_dispatcher_observer_config.py +54 -0
- job_shop_lib/dispatching/_factories.py +125 -0
- job_shop_lib/dispatching/{history_tracker.py → _history_observer.py} +4 -6
- job_shop_lib/dispatching/{pruning_functions.py → _ready_operation_filters.py} +6 -35
- job_shop_lib/dispatching/_unscheduled_operations_observer.py +69 -0
- job_shop_lib/dispatching/feature_observers/__init__.py +16 -10
- job_shop_lib/dispatching/feature_observers/{composite_feature_observer.py → _composite_feature_observer.py} +84 -2
- job_shop_lib/dispatching/feature_observers/{duration_observer.py → _duration_observer.py} +6 -17
- job_shop_lib/dispatching/feature_observers/{earliest_start_time_observer.py → _earliest_start_time_observer.py} +114 -35
- job_shop_lib/dispatching/feature_observers/{factory.py → _factory.py} +31 -5
- job_shop_lib/dispatching/feature_observers/{feature_observer.py → _feature_observer.py} +59 -16
- job_shop_lib/dispatching/feature_observers/_is_completed_observer.py +97 -0
- job_shop_lib/dispatching/feature_observers/_is_ready_observer.py +33 -0
- job_shop_lib/dispatching/feature_observers/{position_in_job_observer.py → _position_in_job_observer.py} +1 -8
- job_shop_lib/dispatching/feature_observers/{remaining_operations_observer.py → _remaining_operations_observer.py} +8 -26
- job_shop_lib/dispatching/rules/__init__.py +51 -0
- job_shop_lib/dispatching/rules/_dispatching_rule_factory.py +82 -0
- job_shop_lib/dispatching/{dispatching_rule_solver.py → rules/_dispatching_rule_solver.py} +44 -15
- job_shop_lib/dispatching/{dispatching_rules.py → rules/_dispatching_rules_functions.py} +74 -21
- job_shop_lib/dispatching/rules/_machine_chooser_factory.py +69 -0
- job_shop_lib/dispatching/rules/_utils.py +127 -0
- job_shop_lib/exceptions.py +18 -0
- job_shop_lib/generation/__init__.py +2 -2
- job_shop_lib/generation/{general_instance_generator.py → _general_instance_generator.py} +26 -7
- job_shop_lib/generation/{instance_generator.py → _instance_generator.py} +13 -3
- job_shop_lib/graphs/__init__.py +17 -6
- job_shop_lib/graphs/{job_shop_graph.py → _job_shop_graph.py} +81 -2
- job_shop_lib/graphs/{node.py → _node.py} +18 -12
- job_shop_lib/graphs/graph_updaters/__init__.py +13 -0
- job_shop_lib/graphs/graph_updaters/_graph_updater.py +59 -0
- job_shop_lib/graphs/graph_updaters/_residual_graph_updater.py +154 -0
- job_shop_lib/graphs/graph_updaters/_utils.py +25 -0
- job_shop_lib/reinforcement_learning/__init__.py +41 -0
- job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py +366 -0
- job_shop_lib/reinforcement_learning/_reward_observers.py +85 -0
- job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py +337 -0
- job_shop_lib/reinforcement_learning/_types_and_constants.py +61 -0
- job_shop_lib/reinforcement_learning/_utils.py +96 -0
- job_shop_lib/visualization/__init__.py +20 -4
- job_shop_lib/visualization/{agent_task_graph.py → _agent_task_graph.py} +28 -9
- job_shop_lib/visualization/_gantt_chart_creator.py +219 -0
- job_shop_lib/visualization/_gantt_chart_video_and_gif_creation.py +388 -0
- {job_shop_lib-0.5.1.dist-info → job_shop_lib-1.0.0a1.dist-info}/METADATA +68 -44
- job_shop_lib-1.0.0a1.dist-info/RECORD +66 -0
- job_shop_lib/benchmarking/load_benchmark.py +0 -142
- job_shop_lib/cp_sat/__init__.py +0 -5
- job_shop_lib/dispatching/factories.py +0 -206
- job_shop_lib/dispatching/feature_observers/is_completed_observer.py +0 -98
- job_shop_lib/dispatching/feature_observers/is_ready_observer.py +0 -40
- job_shop_lib/generators/__init__.py +0 -8
- job_shop_lib/generators/basic_generator.py +0 -200
- job_shop_lib/generators/transformations.py +0 -164
- job_shop_lib/operation.py +0 -122
- job_shop_lib/visualization/create_gif.py +0 -209
- job_shop_lib-0.5.1.dist-info/RECORD +0 -52
- /job_shop_lib/dispatching/feature_observers/{is_scheduled_observer.py → _is_scheduled_observer.py} +0 -0
- /job_shop_lib/generation/{transformations.py → _transformations.py} +0 -0
- /job_shop_lib/graphs/{build_agent_task_graph.py → _build_agent_task_graph.py} +0 -0
- /job_shop_lib/graphs/{build_disjunctive_graph.py → _build_disjunctive_graph.py} +0 -0
- /job_shop_lib/graphs/{constants.py → _constants.py} +0 -0
- /job_shop_lib/visualization/{disjunctive_graph.py → _disjunctive_graph.py} +0 -0
- /job_shop_lib/visualization/{gantt_chart.py → _gantt_chart.py} +0 -0
- {job_shop_lib-0.5.1.dist-info → job_shop_lib-1.0.0a1.dist-info}/LICENSE +0 -0
- {job_shop_lib-0.5.1.dist-info → job_shop_lib-1.0.0a1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,337 @@
|
|
1
|
+
"""Home of the `SingleJobShopGraphEnv` class."""
|
2
|
+
|
3
|
+
from copy import deepcopy
|
4
|
+
from collections.abc import Callable, Sequence
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import matplotlib.pyplot as plt
|
8
|
+
import gymnasium as gym
|
9
|
+
import numpy as np
|
10
|
+
from numpy.typing import NDArray
|
11
|
+
|
12
|
+
from job_shop_lib import JobShopInstance, Operation
|
13
|
+
from job_shop_lib.graphs import JobShopGraph
|
14
|
+
from job_shop_lib.graphs.graph_updaters import (
|
15
|
+
GraphUpdater,
|
16
|
+
ResidualGraphUpdater,
|
17
|
+
)
|
18
|
+
from job_shop_lib.dispatching import (
|
19
|
+
Dispatcher,
|
20
|
+
filter_dominated_operations,
|
21
|
+
DispatcherObserverConfig,
|
22
|
+
)
|
23
|
+
from job_shop_lib.dispatching.feature_observers import (
|
24
|
+
FeatureObserverConfig,
|
25
|
+
CompositeFeatureObserver,
|
26
|
+
)
|
27
|
+
from job_shop_lib.visualization import GanttChartCreator
|
28
|
+
from job_shop_lib.reinforcement_learning import (
|
29
|
+
RewardObserver,
|
30
|
+
MakespanReward,
|
31
|
+
add_padding,
|
32
|
+
RenderConfig,
|
33
|
+
ObservationSpaceKey,
|
34
|
+
ObservationDict,
|
35
|
+
)
|
36
|
+
|
37
|
+
|
38
|
+
class SingleJobShopGraphEnv(gym.Env):
    """A Gymnasium environment for solving a specific instance of the Job Shop
    Scheduling Problem represented as a graph.

    This environment manages the scheduling process for a single Job Shop
    instance, using a graph representation and various observers to track the
    state and compute rewards.

    Observation Space:
        A dictionary with the following keys:

        - "removed_nodes": Binary vector indicating removed graph nodes.
        - "edge_index": Matrix of graph edges in COO format, with shape
          ``(2, num_edges)``.
        - Feature matrices: Keys corresponding to the composite observer
          features (e.g., "operations", "jobs", "machines").

    Action Space:
        MultiDiscrete space representing (job_id, machine_id) pairs. The
        job_id takes values in ``[0, num_jobs - 1]`` and the machine_id in
        ``[-1, num_machines - 1]``, where -1 selects the machine of an
        operation that can only be scheduled in one machine.

    Render Modes:
        - "human": Displays the current Gantt chart.
        - "save_video": Saves a video of the complete Gantt chart.
        - "save_gif": Saves a GIF of the complete Gantt chart.

    Attributes:
        initial_job_shop_graph:
            Deep copy of the graph passed to the constructor, taken before
            any operation is scheduled.
        dispatcher:
            Manages the scheduling process. See
            :class:`~job_shop_lib.dispatching.Dispatcher`.
        composite_observer:
            A :class:`~job_shop_lib.dispatching.feature_observers.
            CompositeFeatureObserver` which aggregates features from multiple
            observers.
        graph_updater:
            Updates the graph representation after each action. See
            :class:`~job_shop_lib.graphs.graph_updaters.GraphUpdater`.
        reward_function:
            Computes rewards for actions taken. See
            :class:`~job_shop_lib.reinforcement_learning.RewardObserver`.
        action_space:
            Defines the action space. The action is a tuple of two integers
            (job_id, machine_id). The machine_id can be -1 if the selected
            operation can only be scheduled in one machine.
        observation_space:
            Defines the observation space. The observation is a dictionary
            with the following keys:

            - "removed_nodes": Binary vector indicating removed graph nodes.
            - "edge_index": Matrix of graph edges in COO format.
            - Feature matrices: Keys corresponding to the composite observer
              features (e.g., "operations", "jobs", "machines").
        render_mode:
            The mode for rendering the environment ("human", "save_video",
            "save_gif").
        gantt_chart_creator:
            Creates Gantt chart visualizations. See
            :class:`~job_shop_lib.visualization.GanttChartCreator`.
        use_padding:
            Whether to use padding in observations. Padding keeps the edge
            index at the fixed shape declared in the observation space,
            filling the unused entries with -1.
    """

    metadata = {"render_modes": ["human", "save_video", "save_gif"]}

    # I think the class is easier to use this way. We could initialize the
    # class from Dispatcher or an already initialized RewardFunction. However,
    # it would be impossible to add good default values.
    # pylint: disable=too-many-arguments
    def __init__(
        self,
        job_shop_graph: JobShopGraph,
        feature_observer_configs: Sequence[FeatureObserverConfig],
        reward_function_config: DispatcherObserverConfig[
            type[RewardObserver]
        ] = DispatcherObserverConfig(class_type=MakespanReward),
        graph_updater_config: DispatcherObserverConfig[
            type[GraphUpdater]
        ] = DispatcherObserverConfig(class_type=ResidualGraphUpdater),
        ready_operations_filter: (
            Callable[[Dispatcher, list[Operation]], list[Operation]] | None
        ) = filter_dominated_operations,
        render_mode: str | None = None,
        render_config: RenderConfig | None = None,
        use_padding: bool = True,
    ) -> None:
        """Initializes the SingleJobShopGraphEnv environment.

        Args:
            job_shop_graph:
                The JobShopGraph instance representing the job shop problem.
            feature_observer_configs:
                A list of FeatureObserverConfig instances for the feature
                observers.
            reward_function_config:
                The configuration for the reward function. Its ``class_type``
                is instantiated with the dispatcher plus its ``kwargs``.
            graph_updater_config:
                The configuration for the graph updater. Its ``class_type``
                is instantiated with the dispatcher, the graph and its
                ``kwargs``.
            ready_operations_filter:
                The function the dispatcher uses to filter the ready
                operations (by default, pruning dominated operations). The
                value, including ``None``, is passed to the dispatcher as is.
            render_mode:
                The mode for rendering the environment ("human", "save_video",
                "save_gif").
            render_config:
                Keyword arguments forwarded to the ``GanttChartCreator``
                (e.g., video or GIF options).
            use_padding:
                Whether to use padding for the edge index.
        """
        super().__init__()
        # Snapshot of the graph before any operation is scheduled.
        self.initial_job_shop_graph = deepcopy(job_shop_graph)

        self.dispatcher = Dispatcher(
            job_shop_graph.instance,
            ready_operations_filter=ready_operations_filter,
        )

        # Observers added to track the environment state
        self.composite_observer = (
            CompositeFeatureObserver.from_feature_observer_configs(
                self.dispatcher, feature_observer_configs
            )
        )
        self.graph_updater = graph_updater_config.class_type(
            dispatcher=self.dispatcher,
            job_shop_graph=job_shop_graph,
            **graph_updater_config.kwargs,
        )
        self.reward_function = reward_function_config.class_type(
            dispatcher=self.dispatcher, **reward_function_config.kwargs
        )
        # ``MultiDiscrete`` with ``start`` yields values in
        # ``[start, start + nvec - 1]``. The machine dimension therefore needs
        # ``num_machines + 1`` values to cover -1 ("use the operation's only
        # machine") as well as every machine id ``0..num_machines - 1``.
        self.action_space = gym.spaces.MultiDiscrete(
            [self.instance.num_jobs, self.instance.num_machines + 1],
            start=[0, -1],
        )
        self.observation_space: gym.spaces.Dict = self._get_observation_space()
        self.render_mode = render_mode
        if render_config is None:
            render_config = {}
        self.gantt_chart_creator = GanttChartCreator(
            dispatcher=self.dispatcher, **render_config
        )
        self.use_padding = use_padding

    @property
    def instance(self) -> JobShopInstance:
        """Returns the instance the environment is working on."""
        return self.job_shop_graph.instance

    @property
    def job_shop_graph(self) -> JobShopGraph:
        """Returns the job shop graph currently held by the graph updater."""
        return self.graph_updater.job_shop_graph

    def _get_observation_space(self) -> gym.spaces.Dict:
        """Returns the observation space dictionary."""
        num_edges = self.job_shop_graph.num_edges
        num_nodes = len(self.job_shop_graph.nodes)
        dict_space: dict[str, gym.Space] = {
            ObservationSpaceKey.REMOVED_NODES.value: gym.spaces.MultiBinary(
                num_nodes
            ),
            # nvec = num_nodes + 1 with start = -1 gives values in
            # [-1, num_nodes - 1]: every node id plus the padding value.
            ObservationSpaceKey.EDGE_INDEX.value: gym.spaces.MultiDiscrete(
                np.full(
                    (2, num_edges),
                    fill_value=num_nodes + 1,
                    dtype=np.int32,
                ),
                start=np.full(
                    (2, num_edges),
                    fill_value=-1,  # -1 is used for padding
                    dtype=np.int32,
                ),
            ),
        }
        for feature_type, matrix in self.composite_observer.features.items():
            dict_space[feature_type.value] = gym.spaces.Box(
                low=-np.inf, high=np.inf, shape=matrix.shape
            )
        return gym.spaces.Dict(dict_space)

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObservationDict, dict]:
        """Resets the environment.

        Resets the dispatcher and returns the initial observation together
        with an empty info dictionary.
        """
        super().reset(seed=seed, options=options)
        self.dispatcher.reset()
        obs = self.get_observation()
        return obs, {}

    def step(
        self, action: tuple[int, int]
    ) -> tuple[ObservationDict, float, bool, bool, dict[str, Any]]:
        """Takes a step in the environment.

        Args:
            action:
                The action to take. The action is a tuple of two integers
                (job_id, machine_id): the job ID and the machine ID in which
                to schedule the operation. A machine_id of -1 means "use the
                operation's ``machine_id``".

        Returns:
            A tuple containing the following elements:

            - The observation of the environment.
            - The reward obtained.
            - Whether the schedule is complete (the episode terminated).
            - Whether the episode was truncated (always False).
            - A dictionary with additional information. The dictionary
              contains the following keys:

              - "feature_names": The names of the features in the
                observation.
              - "available_operations": The operations that are ready to be
                scheduled.
        """
        job_id, machine_id = action
        operation = self.dispatcher.next_operation(job_id)
        if machine_id == -1:
            machine_id = operation.machine_id

        self.dispatcher.dispatch(operation, machine_id)

        obs = self.get_observation()
        reward = self.reward_function.last_reward
        done = self.dispatcher.schedule.is_complete()
        truncated = False
        info: dict[str, Any] = {
            "feature_names": self.composite_observer.column_names,
            "available_operations": self.dispatcher.ready_operations(),
        }
        return obs, reward, done, truncated, info

    def get_observation(self) -> ObservationDict:
        """Returns the current observation of the environment."""
        observation: ObservationDict = {
            ObservationSpaceKey.REMOVED_NODES.value: np.array(
                self.job_shop_graph.removed_nodes, dtype=bool
            ),
            ObservationSpaceKey.EDGE_INDEX.value: self._get_edge_index(),
        }
        for feature_type, matrix in self.composite_observer.features.items():
            observation[feature_type.value] = matrix
        return observation

    def _get_edge_index(self) -> NDArray[np.int32]:
        """Returns the edge index matrix, padded if ``use_padding`` is set."""
        edge_index = np.array(
            self.job_shop_graph.graph.edges(), dtype=np.int32
        ).T

        if self.use_padding:
            output_shape = self.observation_space[
                ObservationSpaceKey.EDGE_INDEX.value
            ].shape
            assert output_shape is not None  # For the type checker
            edge_index = add_padding(
                edge_index, output_shape=output_shape, dtype=np.int32
            )
        return edge_index

    def render(self):
        """Renders the environment.

        The rendering mode is set by the `render_mode` attribute:

        - human: Renders the current Gantt chart.
        - save_video: Saves a video of the Gantt chart. Intended to be used
          once the schedule is completed.
        - save_gif: Saves a GIF of the Gantt chart. Intended to be used once
          the schedule is completed.
        """
        if self.render_mode == "human":
            self.gantt_chart_creator.plot_gantt_chart()
            plt.show(block=False)
        elif self.render_mode == "save_video":
            self.gantt_chart_creator.create_video()
        elif self.render_mode == "save_gif":
            self.gantt_chart_creator.create_gif()
|
313
|
+
|
314
|
+
|
315
|
+
if __name__ == "__main__":
    from job_shop_lib.dispatching.feature_observers import (
        FeatureObserverType,
        FeatureType,
    )
    from job_shop_lib.graphs import build_disjunctive_graph
    from job_shop_lib.benchmarking import load_benchmark_instance

    # Demo: build the environment for the ft06 benchmark using a single
    # "is ready" feature observer computed over the jobs.
    ft06_instance = load_benchmark_instance("ft06")
    demo_graph = build_disjunctive_graph(ft06_instance)
    is_ready_config = DispatcherObserverConfig(
        FeatureObserverType.IS_READY,
        kwargs={"feature_types": [FeatureType.JOBS]},
    )
    demo_observer_configs = [is_ready_config]

    env = SingleJobShopGraphEnv(
        job_shop_graph=demo_graph,
        feature_observer_configs=demo_observer_configs,
        render_mode="save_video",
        render_config={"video_config": {"fps": 4}},
    )
|
@@ -0,0 +1,61 @@
|
|
1
|
+
"""Contains types and enumerations used in the reinforcement learning
|
2
|
+
module."""
|
3
|
+
|
4
|
+
from enum import Enum
|
5
|
+
from typing import TypedDict
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from job_shop_lib.dispatching.feature_observers import FeatureType
|
10
|
+
from job_shop_lib.visualization import (
|
11
|
+
GanttChartWrapperConfig,
|
12
|
+
GifConfig,
|
13
|
+
VideoConfig,
|
14
|
+
)
|
15
|
+
|
16
|
+
|
17
|
+
class RenderConfig(TypedDict, total=False):
    """Keyword arguments used to build a `GanttChartCreator`.

    Every key is optional (``total=False``). The environments unpack this
    dictionary directly into the ``GanttChartCreator`` constructor.
    """

    # Passed to GanttChartCreator as ``gantt_chart_wrapper_config``.
    gantt_chart_wrapper_config: GanttChartWrapperConfig
    # Passed to GanttChartCreator as ``video_config``.
    video_config: VideoConfig
    # Passed to GanttChartCreator as ``gif_config``.
    gif_config: GifConfig
|
23
|
+
|
24
|
+
|
25
|
+
class ObservationSpaceKey(str, Enum):
    """Names of the entries in the observation space dictionary.

    The first two members describe the graph structure. The remaining members
    reuse the values of
    :class:`~job_shop_lib.dispatching.feature_observers.FeatureType`, so that
    feature matrices are stored under their feature type name.
    """

    # Graph-structure entries.
    REMOVED_NODES = "removed_nodes"
    EDGE_INDEX = "edge_index"
    # Feature-matrix entries, mirroring the FeatureType values.
    OPERATIONS = FeatureType.OPERATIONS.value
    JOBS = FeatureType.JOBS.value
    MACHINES = FeatureType.MACHINES.value
|
33
|
+
|
34
|
+
|
35
|
+
class _ObservationDictRequired(TypedDict):
|
36
|
+
"""Required fields for the observation dictionary."""
|
37
|
+
|
38
|
+
removed_nodes: np.ndarray
|
39
|
+
edge_index: np.ndarray
|
40
|
+
|
41
|
+
|
42
|
+
class _ObservationDictOptional(TypedDict, total=False):
|
43
|
+
"""Optional fields for the observation dictionary."""
|
44
|
+
|
45
|
+
operations: np.ndarray
|
46
|
+
jobs: np.ndarray
|
47
|
+
machines: np.ndarray
|
48
|
+
|
49
|
+
|
50
|
+
class ObservationDict(_ObservationDictRequired, _ObservationDictOptional):
    """Observation of the environment, as a typed dictionary.

    Always present:
        removed_nodes (np.ndarray): Binary vector flagging removed nodes.
        edge_index (np.ndarray): Graph edges in COO format.

    Present only when the composite feature observer provides that feature
    type:
        operations (np.ndarray): Matrix of operation features.
        jobs (np.ndarray): Matrix of job features.
        machines (np.ndarray): Matrix of machine features.
    """
|
@@ -0,0 +1,96 @@
|
|
1
|
+
"""Utility functions for reinforcement learning."""
|
2
|
+
|
3
|
+
from typing import TypeVar, Any
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
from numpy.typing import NDArray
|
7
|
+
|
8
|
+
from job_shop_lib.exceptions import ValidationError
|
9
|
+
|
10
|
+
T = TypeVar("T", bound=np.number)
|
11
|
+
|
12
|
+
|
13
|
+
def add_padding(
|
14
|
+
array: NDArray[Any],
|
15
|
+
output_shape: tuple[int, ...],
|
16
|
+
padding_value: float = -1,
|
17
|
+
dtype: type[T] | None = None,
|
18
|
+
) -> NDArray[T]:
|
19
|
+
"""Adds padding to the array.
|
20
|
+
|
21
|
+
Pads the input array to the specified output shape with a given padding
|
22
|
+
value. If the ``dtype`` is not specified, the ``dtype`` of the input array
|
23
|
+
is used.
|
24
|
+
|
25
|
+
Args:
|
26
|
+
array:
|
27
|
+
The input array to be padded.
|
28
|
+
output_shape:
|
29
|
+
The desired shape of the output array.
|
30
|
+
padding_value:
|
31
|
+
The value to use for padding. Defaults to -1.
|
32
|
+
dtype:
|
33
|
+
The data type for the output array. Defaults to ``None``, in which
|
34
|
+
case the dtype of the input array is used.
|
35
|
+
|
36
|
+
Returns:
|
37
|
+
The padded array with the specified output shape.
|
38
|
+
|
39
|
+
Raises:
|
40
|
+
ValidationError:
|
41
|
+
If the output shape is smaller than the input shape.
|
42
|
+
|
43
|
+
Examples:
|
44
|
+
|
45
|
+
.. doctest::
|
46
|
+
|
47
|
+
>>> array = np.array([[1, 2], [3, 4]])
|
48
|
+
>>> add_padding(array, (3, 3))
|
49
|
+
array([[ 1, 2, -1],
|
50
|
+
[ 3, 4, -1],
|
51
|
+
[-1, -1, -1]])
|
52
|
+
|
53
|
+
>>> add_padding(array, (3, 3), padding_value=0)
|
54
|
+
array([[1, 2, 0],
|
55
|
+
[3, 4, 0],
|
56
|
+
[0, 0, 0]])
|
57
|
+
|
58
|
+
>>> bool_array = np.array([[True, False], [False, True]])
|
59
|
+
>>> add_padding(bool_array, (3, 3), padding_value=False, dtype=int)
|
60
|
+
array([[1, 0, 0],
|
61
|
+
[0, 1, 0],
|
62
|
+
[0, 0, 0]])
|
63
|
+
|
64
|
+
>>> add_padding(bool_array, (3, 3), dtype=int)
|
65
|
+
array([[ 1, 0, -1],
|
66
|
+
[ 0, 1, -1],
|
67
|
+
[-1, -1, -1]])
|
68
|
+
"""
|
69
|
+
|
70
|
+
if np.any(np.less(output_shape, array.shape)):
|
71
|
+
raise ValidationError(
|
72
|
+
"Output shape must be greater than the input shape. "
|
73
|
+
f"Got output shape: {output_shape}, input shape: {array.shape}."
|
74
|
+
)
|
75
|
+
|
76
|
+
if dtype is None:
|
77
|
+
dtype = array.dtype.type
|
78
|
+
|
79
|
+
padded_array = np.full(
|
80
|
+
output_shape,
|
81
|
+
fill_value=padding_value,
|
82
|
+
dtype=dtype,
|
83
|
+
)
|
84
|
+
|
85
|
+
if array.size == 0:
|
86
|
+
return padded_array
|
87
|
+
|
88
|
+
slices = tuple(slice(0, dim) for dim in array.shape)
|
89
|
+
padded_array[slices] = array
|
90
|
+
return padded_array
|
91
|
+
|
92
|
+
|
93
|
+
if __name__ == "__main__":
    # Run the examples embedded in this module's docstrings.
    from doctest import testmod

    testmod()
|
@@ -1,25 +1,41 @@
|
|
1
1
|
"""Package for visualization."""
|
2
2
|
|
3
|
-
from job_shop_lib.visualization.
|
4
|
-
from job_shop_lib.visualization.
|
3
|
+
from job_shop_lib.visualization._gantt_chart import plot_gantt_chart
|
4
|
+
from job_shop_lib.visualization._gantt_chart_video_and_gif_creation import (
|
5
5
|
create_gif,
|
6
|
+
create_gantt_chart_video,
|
6
7
|
create_gantt_chart_frames,
|
7
8
|
plot_gantt_chart_wrapper,
|
9
|
+
create_video_from_frames,
|
8
10
|
create_gif_from_frames,
|
9
11
|
)
|
10
|
-
from job_shop_lib.visualization.
|
11
|
-
|
12
|
+
from job_shop_lib.visualization._disjunctive_graph import (
|
13
|
+
plot_disjunctive_graph,
|
14
|
+
)
|
15
|
+
from job_shop_lib.visualization._agent_task_graph import (
|
12
16
|
plot_agent_task_graph,
|
13
17
|
three_columns_layout,
|
14
18
|
)
|
19
|
+
from job_shop_lib.visualization._gantt_chart_creator import (
|
20
|
+
GanttChartCreator,
|
21
|
+
GanttChartWrapperConfig,
|
22
|
+
GifConfig,
|
23
|
+
VideoConfig,
|
24
|
+
)
|
15
25
|
|
16
26
|
__all__ = [
    # Gantt charts: static plots, frames, videos and GIFs.
    "plot_gantt_chart",
    "create_gantt_chart_video",
    "create_gif",
    "create_gantt_chart_frames",
    "plot_gantt_chart_wrapper",
    "create_gif_from_frames",
    "create_video_from_frames",
    # Graph plots.
    "plot_disjunctive_graph",
    "plot_agent_task_graph",
    "three_columns_layout",
    # Gantt chart creator and its configuration types.
    "GanttChartCreator",
    "GanttChartWrapperConfig",
    "GifConfig",
    "VideoConfig",
]
|
@@ -51,6 +51,7 @@ def plot_agent_task_graph(
|
|
51
51
|
|
52
52
|
# Create the networkx graph
|
53
53
|
graph = job_shop_graph.graph
|
54
|
+
nodes = job_shop_graph.non_removed_nodes()
|
54
55
|
|
55
56
|
# Create the layout if it was not provided
|
56
57
|
if layout is None:
|
@@ -64,16 +65,17 @@ def plot_agent_task_graph(
|
|
64
65
|
job_shop_graph.nodes_by_type[NodeType.MACHINE]
|
65
66
|
)
|
66
67
|
}
|
68
|
+
|
67
69
|
node_colors = [
|
68
70
|
_get_node_color(node, machine_colors) for node in job_shop_graph.nodes
|
69
|
-
]
|
71
|
+
] # We need to get the color of all nodes to avoid an index error
|
70
72
|
node_shapes = {"machine": "s", "job": "d", "operation": "o", "global": "o"}
|
71
73
|
|
72
74
|
# Draw nodes with different shapes based on their type
|
73
75
|
for node_type, shape in node_shapes.items():
|
74
76
|
current_nodes = [
|
75
77
|
node.node_id
|
76
|
-
for node in
|
78
|
+
for node in nodes
|
77
79
|
if node.node_type.name.lower() == node_type
|
78
80
|
]
|
79
81
|
nx.draw_networkx_nodes(
|
@@ -90,9 +92,7 @@ def plot_agent_task_graph(
|
|
90
92
|
# Draw edges
|
91
93
|
nx.draw_networkx_edges(graph, layout, ax=ax)
|
92
94
|
|
93
|
-
node_labels = {
|
94
|
-
node.node_id: _get_node_label(node) for node in job_shop_graph.nodes
|
95
|
-
}
|
95
|
+
node_labels = {node.node_id: _get_node_label(node) for node in nodes}
|
96
96
|
nx.draw_networkx_labels(graph, layout, node_labels, ax=ax)
|
97
97
|
|
98
98
|
ax.set_axis_off()
|
@@ -181,10 +181,29 @@ def three_columns_layout(
|
|
181
181
|
|
182
182
|
x_positions = _get_x_positions(leftmost_position, rightmost_position)
|
183
183
|
|
184
|
-
operation_nodes =
|
185
|
-
|
186
|
-
|
187
|
-
|
184
|
+
operation_nodes = [
|
185
|
+
node
|
186
|
+
for node in job_shop_graph.nodes_by_type[NodeType.OPERATION]
|
187
|
+
if not job_shop_graph.is_removed(node)
|
188
|
+
]
|
189
|
+
machine_nodes = [
|
190
|
+
node
|
191
|
+
for node in job_shop_graph.nodes_by_type[NodeType.MACHINE]
|
192
|
+
if not job_shop_graph.is_removed(node)
|
193
|
+
]
|
194
|
+
job_nodes = [
|
195
|
+
node
|
196
|
+
for node in job_shop_graph.nodes_by_type[NodeType.JOB]
|
197
|
+
if not job_shop_graph.is_removed(node)
|
198
|
+
]
|
199
|
+
global_nodes = [
|
200
|
+
node
|
201
|
+
for node in job_shop_graph.nodes_by_type[NodeType.GLOBAL]
|
202
|
+
if not job_shop_graph.is_removed(node)
|
203
|
+
]
|
204
|
+
|
205
|
+
# job_nodes = job_shop_graph.nodes_by_type[NodeType.JOB]
|
206
|
+
# global_nodes = job_shop_graph.nodes_by_type[NodeType.GLOBAL]
|
188
207
|
|
189
208
|
total_positions = len(operation_nodes) + len(global_nodes) * 2
|
190
209
|
y_spacing = (topmost_position - bottommost_position) / total_positions
|