likelihood 2.2.0.dev1__cp311-cp311-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,866 @@
+ import logging
+ import os
+ 
+ import networkx as nx
+ import pandas as pd
+ 
+ # Silence TensorFlow's C++ and Python loggers; the environment variable must
+ # be set before TensorFlow is imported below.
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ logging.getLogger("tensorflow").setLevel(logging.ERROR)
+ 
+ import sys
+ import warnings
+ from functools import wraps
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ 
+ import numpy as np
+ import seaborn as sns
+ import tensorflow as tf
+ import torch
+ from torch.utils.data import DataLoader, TensorDataset
+ 
+ warnings.filterwarnings("ignore", category=UserWarning)
+ 
+ # Plotting helpers (residual, residual_hist, act_pred, loss_curve, plt, ...)
+ # are provided by this star import.
+ from .figures import *
24
+
25
+
26
+ class suppress_prints:
27
+ def __enter__(self):
28
+ self.original_stdout = sys.stdout
29
+ sys.stdout = open(os.devnull, "w")
30
+
31
+ def __exit__(self, exc_type, exc_value, traceback):
32
+ sys.stdout.close()
33
+ sys.stdout = self.original_stdout
34
+
35
+
36
+ def suppress_warnings(func):
37
+ @wraps(func)
38
+ def wrapper(*args, **kwargs):
39
+ with warnings.catch_warnings():
40
+ warnings.simplefilter("ignore")
41
+ return func(*args, **kwargs)
42
+
43
+ return wrapper
44
+
45
+
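+ # Illustrative sketch (not part of the package API): combining the two
+ # helpers above. `_noisy_demo` is hypothetical and exists only to show that
+ # both its warning and its print output are silenced.
+ @suppress_warnings
+ def _noisy_demo() -> None:
+     warnings.warn("swallowed by the decorator")
+     with suppress_prints():
+         print("never reaches stdout")
+ 
+ 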
+ class TransformRange:
+     """
+     Generates a new DataFrame with ranges represented as strings.
+ 
+     Transforms numerical columns into categorical range bins with descriptive labels.
+     """
+ 
+     def __init__(self, columns_bin_sizes: Dict[str, int]) -> None:
+         """Initializes the transformer with the bin size to use for each column.
+ 
+         Parameters
+         ----------
+         columns_bin_sizes : `dict`
+             A dictionary where the keys are column names and the values are the bin sizes.
+         """
+         self.info = {}
+         self.columns_bin_sizes = columns_bin_sizes
+ 
+     def _create_bins_and_labels(
+         self, min_val: Union[int, float], max_val: Union[int, float], bin_size: int
+     ) -> Tuple[List[float], List[str]]:
+         """
+         Creates the bin edges and their labels.
+ 
+         Parameters
+         ----------
+         min_val : `int` or `float`
+             The minimum value for the range.
+         max_val : `int` or `float`
+             The maximum value for the range.
+         bin_size : `int`
+             The size of each bin.
+ 
+         Returns
+         -------
+         bins : `list`
+             The bin edges, padded with -inf and +inf catch-all edges.
+         labels : `list`
+             The labels for the bins.
+ 
+         Raises
+         ------
+         ValueError
+             If bin_size is not positive or if min_val >= max_val.
+         """
+         if bin_size <= 0:
+             raise ValueError("bin_size must be positive")
+         if min_val >= max_val:
+             raise ValueError("min_val must be less than max_val")
+ 
+         start = int(min_val)
+         end = int(max_val) + bin_size
+ 
+         bins = np.arange(start, end + 1, bin_size)
+ 
+         if bins[-1] <= max_val:
+             bins = np.append(bins, max_val + 1)
+ 
+         lower_bin_edge = -np.inf
+         upper_bin_edge = np.inf
+ 
+         labels = [f"{int(bins[i])}-{int(bins[i+1] - 1)}" for i in range(len(bins) - 1)]
+         end = int(bins[-1] - 1)
+         bins = bins.tolist()
+         bins.insert(0, lower_bin_edge)
+         bins.append(upper_bin_edge)
+         labels.insert(0, f"< {start}")
+         labels.append(f"> {end}")
+         return bins, labels
+ 
+     def _transform_column_to_ranges(
+         self, df: pd.DataFrame, column: str, bin_size: int, fit: bool = True
+     ) -> pd.Series:
+         """
+         Transforms a column in the DataFrame into range bins.
+ 
+         Parameters
+         ----------
+         df : `pd.DataFrame`
+             The original DataFrame to transform.
+         column : `str`
+             The name of the column to transform.
+         bin_size : `int`
+             The size of each bin.
+         fit : `bool`, default=True
+             If True, compute and store the binning from this data; if False, reuse
+             the binning stored during a previous fit.
+ 
+         Returns
+         -------
+         `pd.Series`
+             A Series with the range labels.
+ 
+         Raises
+         ------
+         KeyError
+             If column is not found in the DataFrame.
+         ValueError
+             If bin_size is not positive or if column contains non-numeric data.
+         """
+         if not isinstance(df, pd.DataFrame):
+             raise TypeError("df must be a pandas DataFrame")
+         df_ = df.copy()  # Work on a copy to avoid modifying the caller's DataFrame
+         # Check membership before indexing so the explicit KeyError is reachable.
+         if column not in df_.columns:
+             raise KeyError(f"Column '{column}' not found in DataFrame")
+         numeric_series = pd.to_numeric(df_[column], errors="coerce")
+         if fit:
+             self.df = df_.copy()
+ 
+             if bin_size <= 0:
+                 raise ValueError("bin_size must be positive")
+ 
+             if numeric_series.isna().all():
+                 raise ValueError(f"Column '{column}' contains no valid numeric data")
+ 
+             min_val = numeric_series.min()
+             max_val = numeric_series.max()
+ 
+             if min_val == max_val:
+                 return pd.Series(
+                     [f"{int(min_val)}-{int(max_val)}"] * len(df_), name=f"{column}_range"
+                 )
+             self.info[column] = {"min_value": min_val, "max_value": max_val, "range": bin_size}
+         else:
+             min_val = self.info[column]["min_value"]
+             max_val = self.info[column]["max_value"]
+             bin_size = self.info[column]["range"]
+ 
+         bins, labels = self._create_bins_and_labels(min_val, max_val, bin_size)
+         return pd.cut(numeric_series, bins=bins, labels=labels, right=False, include_lowest=True)
+ 
+     def transform(
+         self, df: pd.DataFrame, drop_original: bool = False, fit: bool = True
+     ) -> pd.DataFrame:
+         """
+         Creates a new DataFrame with range columns.
+ 
+         Parameters
+         ----------
+         df : `pd.DataFrame`
+             The original DataFrame to transform.
+         drop_original : `bool`, optional
+             If True, drops original columns from the result, by default False.
+         fit : `bool`, default=True
+             Whether to compute bin edges based on the data (True) or use predefined binning (False).
+ 
+         Returns
+         -------
+         `pd.DataFrame`
+             A DataFrame with the transformed range columns.
+ 
+         Raises
+         ------
+         TypeError
+             If columns_bin_sizes is not a dictionary.
+         """
+         if not isinstance(self.columns_bin_sizes, dict):
+             raise TypeError("columns_bin_sizes must be a dictionary")
+ 
+         if not self.columns_bin_sizes:
+             return pd.DataFrame()
+ 
+         range_columns = {}
+         for column, bin_size in self.columns_bin_sizes.items():
+             range_columns[f"{column}_range"] = self._transform_column_to_ranges(
+                 df, column, bin_size, fit
+             )
+ 
+         result_df = pd.DataFrame(range_columns)
+ 
+         if not drop_original:
+             original_cols = [col for col in df.columns if col not in self.columns_bin_sizes]
+             if original_cols:
+                 result_df = pd.concat([df[original_cols], result_df], axis=1)
+ 
+         return result_df
+ 
+     def get_range_info(self, column: str) -> Dict[str, Union[int, float, str]]:
+         """
+         Get information about the value range of a specific column.
+ 
+         Requires `transform` to have been called with `fit=True`, which stores the
+         fitted DataFrame on the instance.
+ 
+         Parameters
+         ----------
+         column : `str`
+             The name of the column to analyze.
+ 
+         Returns
+         -------
+         `dict`
+             Dictionary with the keys `min_value`, `max_value`, `range` (max minus min),
+             and `column`.
+         """
+         if column not in self.df.columns:
+             raise KeyError(f"Column '{column}' not found in DataFrame")
+ 
+         numeric_series = pd.to_numeric(self.df[column], errors="coerce")
+         min_val = numeric_series.min()
+         max_val = numeric_series.max()
+ 
+         return {
+             "min_value": min_val,
+             "max_value": max_val,
+             "range": max_val - min_val,
+             "column": column,
+         }
+ 
+ 
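+ # Illustrative sketch (hypothetical data): binning an "age" column into
+ # width-10 ranges with the transformer above.
+ def _demo_transform_range() -> None:
+     frame = pd.DataFrame({"age": [3, 15, 27, 41]})
+     transformer = TransformRange({"age": 10})
+     out = transformer.transform(frame)
+     # Prints ['3-12', '13-22', '23-32', '33-42']
+     print(out["age_range"].tolist())
+ 
+ 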
+ def remove_collinearity(df: pd.DataFrame, threshold: float = 0.9) -> pd.DataFrame:
+     """
+     Removes highly collinear features from the DataFrame based on a correlation threshold.
+ 
+     This function calculates the correlation matrix of the DataFrame and removes columns
+     that are highly correlated with any other column in the DataFrame. It uses an absolute
+     correlation value greater than the specified threshold to identify which columns to drop.
+ 
+     Parameters
+     ----------
+     df : `pd.DataFrame`
+         The input DataFrame containing numerical data.
+     threshold : `float`
+         The correlation threshold above which features will be removed. Default is `0.9`.
+ 
+     Returns
+     -------
+     df_reduced : `pd.DataFrame`
+         A DataFrame with highly collinear features removed.
+     """
+     corr_matrix = df.corr().abs()
+     # Keep only the upper triangle so each pair of columns is inspected once.
+     upper_triangle = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
+     to_drop = [
+         column for column in upper_triangle.columns if any(upper_triangle[column] > threshold)
+     ]
+     df_reduced = df.drop(columns=to_drop)
+ 
+     return df_reduced
+ 
+ 
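+ # Illustrative sketch (hypothetical data): "b" below is almost a copy of "a"
+ # (|corr| > 0.9), so it is dropped, while the weakly correlated "c" survives.
+ def _demo_remove_collinearity() -> None:
+     frame = pd.DataFrame(
+         {"a": [1, 2, 3, 4], "b": [1.1, 2.0, 3.1, 4.2], "c": [4, 1, 3, 2]}
+     )
+     print(remove_collinearity(frame, threshold=0.9).columns.tolist())  # ['a', 'c']
+ 
+ 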
+ def train_and_insights(
+     x_data: np.ndarray,
+     y_act: np.ndarray,
+     model: tf.keras.Model,
+     patience: int = 3,
+     reg: bool = False,
+     frac: float = 1.0,
+     **kwargs: Any,
+ ) -> tf.keras.Model:
+     """
+     Train a Keras model and provide insights on the training and validation metrics.
+ 
+     Parameters
+     ----------
+     x_data : `np.ndarray`
+         Input data for training the model.
+     y_act : `np.ndarray`
+         Actual labels corresponding to x_data.
+     model : `tf.keras.Model`
+         The Keras model to train. It must be compiled with a loss and exactly one
+         metric, since the history columns are read positionally below.
+     patience : `int`
+         The patience parameter for the early stopping callback (default is 3).
+     reg : `bool`
+         Flag to determine if residual analysis should be performed (default is `False`).
+     frac : `float`
+         Fraction of the data to use for the post-training predictions (default is 1.0).
+ 
+     Keyword Arguments
+     -----------------
+     Additional keyword arguments passed to the `model.fit` function, such as validation split and callbacks.
+ 
+     Returns
+     -------
+     `tf.keras.Model`
+         The trained model after fitting.
+     """
+     validation_split = kwargs.get("validation_split", 0.2)
+     callback = kwargs.get(
+         "callback", [tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=patience)]
+     )
+ 
+     # Remove the consumed keys so they are not forwarded to model.fit twice.
+     for key in ["validation_split", "callback"]:
+         if key in kwargs:
+             del kwargs[key]
+ 
+     history = model.fit(
+         x_data,
+         y_act,
+         validation_split=validation_split,
+         verbose=False,
+         callbacks=callback,
+         **kwargs,
+     )
+ 
+     hist = pd.DataFrame(history.history)
+     hist["epoch"] = history.epoch
+ 
+     # Positional layout of the history: loss, metric, val_loss, val_metric.
+     columns = hist.columns
+     train_err, train_metric = columns[0], columns[1]
+     val_err, val_metric = columns[2], columns[3]
+     train_err, val_err = hist[train_err].values, hist[val_err].values
+ 
+     with suppress_prints():
+         n = int(len(x_data) * frac)
+         y_pred = model.predict(x_data[:n])
+         y_act = y_act[:n]
+ 
+     # Plotting helpers below come from the star import of .figures.
+     if reg:
+         residual(y_act, y_pred)
+         residual_hist(y_act, y_pred)
+         act_pred(y_act, y_pred)
+ 
+     loss_curve(hist["epoch"].values, train_err, val_err)
+ 
+     return model
+ 
+ 
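+ # Minimal sketch of the expected call (synthetic data, hypothetical model):
+ # the model is compiled with exactly one metric so the history columns line
+ # up as loss/metric/val_loss/val_metric.
+ def _demo_train_and_insights() -> tf.keras.Model:
+     x = np.random.rand(256, 4)
+     y = x.sum(axis=1, keepdims=True)
+     model = tf.keras.Sequential(
+         [tf.keras.layers.Dense(8, activation="relu"), tf.keras.layers.Dense(1)]
+     )
+     model.compile(optimizer="adam", loss="mse", metrics=["mae"])
+     return train_and_insights(x, y, model, reg=True, epochs=5)
+ 
+ 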
+ @tf.keras.utils.register_keras_serializable(package="Custom", name="LoRALayer")
+ class LoRALayer(tf.keras.layers.Layer):
+     def __init__(self, units, rank=4, **kwargs):
+         super().__init__(**kwargs)
+         self.units = units
+         self.rank = rank
+ 
+     def build(self, input_shape):
+         input_dim = input_shape[-1]
+         print(f"Input shape: {input_shape}")
+ 
+         if self.rank > input_dim:
+             raise ValueError(
+                 f"Rank ({self.rank}) cannot be greater than input dimension ({input_dim})."
+             )
+         if self.rank > self.units:
+             raise ValueError(
+                 f"Rank ({self.rank}) cannot be greater than number of units ({self.units})."
+             )
+ 
+         # Low-rank factorization: the layer computes inputs @ A @ B.
+         self.A = self.add_weight(
+             shape=(input_dim, self.rank), initializer="random_normal", trainable=True, name="A"
+         )
+         self.B = self.add_weight(
+             shape=(self.rank, self.units), initializer="random_normal", trainable=True, name="B"
+         )
+         print(f"Dense weights shape: {input_dim}x{self.units}")
+         print(f"LoRA weights shape: A{self.A.shape}, B{self.B.shape}")
+ 
+     def call(self, inputs):
+         lora_output = tf.matmul(tf.matmul(inputs, self.A), self.B)
+         return lora_output
+ 
+     def get_config(self):
+         # Needed so the registered layer can be re-created from a saved model.
+         config = super().get_config()
+         config.update({"units": self.units, "rank": self.rank})
+         return config
+ 
+ 
+ def apply_lora(model, rank=4):
+     inputs = tf.keras.Input(shape=model.input_shape[1:])
+     x = inputs
+ 
+     # Swap every Dense layer for a LoRALayer; all other layers are reused as-is.
+     for layer in model.layers:
+         if isinstance(layer, tf.keras.layers.Dense):
+             print(f"Applying LoRA to layer {layer.name}")
+             x = LoRALayer(units=layer.units, rank=rank)(x)
+         else:
+             x = layer(x)
+     new_model = tf.keras.Model(inputs=inputs, outputs=x)
+     return new_model
+ 
+ 
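+ # Illustrative sketch: wrapping a small Sequential model. Note that
+ # `apply_lora` as written *replaces* each Dense layer with a rank-limited
+ # LoRALayer rather than adding a low-rank update on top of frozen weights.
+ def _demo_apply_lora() -> tf.keras.Model:
+     base = tf.keras.Sequential(
+         [
+             tf.keras.Input(shape=(4,)),
+             tf.keras.layers.Dense(8, activation="relu"),
+             tf.keras.layers.Dense(2),
+         ]
+     )
+     return apply_lora(base, rank=2)  # rank must not exceed any layer's fan-in or units
+ 
+ 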
+ def graph_metrics(adj_matrix: np.ndarray, eigenvector_threshold: float = 1e-6) -> pd.DataFrame:
+     """
+     Calculate various graph metrics based on the given adjacency matrix and return them in a single DataFrame.
+ 
+     Parameters
+     ----------
+     adj_matrix : `np.ndarray`
+         The adjacency matrix representing the graph, where each element denotes the presence/weight of an edge between nodes.
+     eigenvector_threshold : `float`
+         A threshold for the eigenvector centrality calculation, used to determine the cutoff for small eigenvalues. Default is `1e-6`.
+ 
+     Returns
+     -------
+     metrics_df : `pd.DataFrame`
+         A DataFrame containing the following graph metrics as columns:
+         - `Degree`: The degree of each node, representing the number of edges connected to each node.
+         - `DegreeCentrality`: Degree centrality values for each node, indicating the number of direct connections each node has.
+         - `ClusteringCoefficient`: Clustering coefficient values for each node, representing the degree to which nodes cluster together.
+         - `EigenvectorCentrality`: Eigenvector centrality values, indicating the influence of a node in the graph based on the eigenvectors of the adjacency matrix.
+         - `BetweennessCentrality`: Betweenness centrality values, representing the extent to which a node lies on the shortest paths between other nodes.
+         - `ClosenessCentrality`: Closeness centrality values, indicating the inverse of the average shortest path distance from a node to all other nodes in the graph.
+         - `Assortativity`: The assortativity coefficient of the graph, measuring the tendency of nodes to connect to similar nodes.
+ 
+     Notes
+     -----
+     The returned DataFrame will have one row for each node and one column for each of the computed metrics.
+     """
+     adj_matrix = adj_matrix.astype(int)
+     G = nx.from_numpy_array(adj_matrix)
+     degree_centrality = nx.degree_centrality(G)
+     clustering_coeff = nx.clustering(G)
+     try:
+         eigenvector_centrality = nx.eigenvector_centrality(G, max_iter=500)
+     except nx.PowerIterationFailedConvergence:
+         print("Power iteration failed to converge. Returning NaN for eigenvector centrality.")
+         eigenvector_centrality = {node: float("nan") for node in G.nodes()}
+ 
+     for node, centrality in eigenvector_centrality.items():
+         if centrality < eigenvector_threshold:
+             eigenvector_centrality[node] = 0.0
+     degree = dict(G.degree())
+     betweenness_centrality = nx.betweenness_centrality(G)
+     closeness_centrality = nx.closeness_centrality(G)
+     assortativity = nx.degree_assortativity_coefficient(G)
+     metrics_df = pd.DataFrame(
+         {
+             "Degree": degree,
+             "DegreeCentrality": degree_centrality,
+             "ClusteringCoefficient": clustering_coeff,
+             "EigenvectorCentrality": eigenvector_centrality,
+             "BetweennessCentrality": betweenness_centrality,
+             "ClosenessCentrality": closeness_centrality,
+         }
+     )
+     metrics_df["Assortativity"] = assortativity
+ 
+     return metrics_df
+ 
+ 
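+ # Illustrative sketch (hypothetical graph): a 4-node star. The hub gets
+ # degree 3 and the leaves degree 1; a star graph's assortativity is -1.
+ def _demo_graph_metrics() -> pd.DataFrame:
+     star = np.array(
+         [
+             [0, 1, 1, 1],
+             [1, 0, 0, 0],
+             [1, 0, 0, 0],
+             [1, 0, 0, 0],
+         ]
+     )
+     return graph_metrics(star)
+ 
+ 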
+ def print_trajectory_info(state, selected_option, action, reward, next_state, terminate, done):
+     print("=" * 50)
+     print("TRAJECTORY INFO".center(50, "="))
+     print("=" * 50)
+     print(f"State: {state}")
+     print(f"Selected Option: {selected_option}")
+     print(f"Action: {action}")
+     print(f"Reward: {reward}")
+     print(f"Next State: {next_state}")
+     print(f"Terminate: {terminate}")
+     print(f"Done: {done}")
+     print("=" * 50)
+ 
+ 
+ def collect_experience(
+     env: Any,
+     model: torch.nn.Module,
+     gamma: float = 0.99,
+     lambda_parameter: float = 0.95,
+     penalty_for_done_state: float = -1.0,
+     tolerance: float = float("inf"),
+     verbose: bool = False,
+ ) -> tuple[List[tuple], List[float], List[float], List[float]]:
+     """Gathers experience samples from an environment using a reinforcement learning model.
+ 
+     Parameters
+     ----------
+     env : `Any`
+         The environment to collect experience from.
+     model : `torch.nn.Module`
+         The reinforcement learning model (e.g., a torch neural network).
+     gamma : `float`, optional
+         Discount factor for future rewards, default=0.99.
+     lambda_parameter : `float`, optional
+         TD error correction parameter, default=0.95.
+     penalty_for_done_state : `float`, optional
+         Penalty applied to the reward when the environment reaches a terminal state, default=-1.0.
+     tolerance : `float`, optional
+         Maximum number of steps to collect before stopping, default=inf.
+     verbose : `bool`, optional
+         If True, prints the trajectory information for every step, default=False.
+ 
+     Returns
+     -------
+     trajectory : `list[tuple]`
+         The collected trajectory (state, selected_option, action, reward, next_state, terminate, done).
+     returns : `list[float]`
+         The list of cumulative returns.
+     advantages : `list[float]`
+         The list of advantage terms for each step.
+     old_probs : `list[float]`
+         The list of old policy probabilities for each step.
+     """
+     state = env.reset()
+     done = False
+     trajectory = []
+     old_probs = []
+     tolerance_count = 0
+ 
+     while not done and tolerance_count < tolerance:
+         state = state[0] if isinstance(state, tuple) else state
+         state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
+ 
+         option_probs, action_probs, termination_probs, selected_option, action = model(state_tensor)
+ 
+         # Re-sample the action from the current policy; the trajectory stores
+         # the option the model itself selected.
+         action = torch.multinomial(action_probs, 1).item()
+         option = torch.multinomial(option_probs, 1).item()
+         old_probs.append(action_probs[0, action].item())
+ 
+         terminate = torch.bernoulli(termination_probs).item() > 0.5
+         # Environments whose step() also accepts the option get it as a second argument.
+         signature = env.step.__code__
+         if signature.co_argcount > 2:
+             next_state, reward, done, truncated, info = env.step(action, option)
+         else:
+             next_state, reward, done, truncated, info = env.step(action)
+ 
+         if done:
+             reward = penalty_for_done_state
+         tolerance_count += 1
+         trajectory.append((state, selected_option, action, reward, next_state, terminate, done))
+         state = next_state
+         if verbose:
+             print_trajectory_info(
+                 state, selected_option, action, reward, next_state, terminate, done
+             )
+ 
+     returns = []
+     advantages = []
+     G = 0
+     delta = 0
+ 
+     # Walk the trajectory backwards, accumulating discounted returns and
+     # GAE-style advantages (using the action probabilities as value proxies).
+     for t in reversed(range(len(trajectory))):
+         state, selected_option, action, reward, next_state, terminate, done = trajectory[t]
+ 
+         state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
+         _, action_probs, _, _, _ = model(state_tensor)
+ 
+         if t == len(trajectory) - 1:
+             G = reward
+             advantages.insert(0, G - action_probs[0, action].item())
+         else:
+             next_state_tensor = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
+             _, next_action_probs, _, _, _ = model(next_state_tensor)
+ 
+             delta = (
+                 reward
+                 + gamma * next_action_probs[0, action].item()
+                 - action_probs[0, action].item()
+             )
+             G = reward + gamma * G
+             advantages.insert(
+                 0, delta + gamma * lambda_parameter * advantages[0] if advantages else delta
+             )
+ 
+         returns.insert(0, G)
+ 
+     return trajectory, returns, advantages, old_probs
+ 
+ 
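+ # Minimal sketch of the model interface `collect_experience` assumes: a
+ # forward pass returning (option_probs, action_probs, termination_probs,
+ # selected_option, action). This toy module is hypothetical.
+ class _TinyOptionCritic(torch.nn.Module):
+     def __init__(self, state_dim: int, num_options: int, num_actions: int):
+         super().__init__()
+         self.option_head = torch.nn.Linear(state_dim, num_options)
+         self.action_head = torch.nn.Linear(state_dim, num_actions)
+         self.termination_head = torch.nn.Linear(state_dim, 1)
+ 
+     def forward(self, state):
+         option_probs = torch.softmax(self.option_head(state), dim=-1)
+         action_probs = torch.softmax(self.action_head(state), dim=-1)
+         termination_probs = torch.sigmoid(self.termination_head(state))
+         # Greedy picks returned as tensors so batched calls also work.
+         selected_option = torch.argmax(option_probs, dim=-1)
+         action = torch.argmax(action_probs, dim=-1)
+         return option_probs, action_probs, termination_probs, selected_option, action
+ 
+ 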
+ def ppo_loss(
+     advantages: torch.Tensor,
+     old_action_probs: torch.Tensor,
+     action_probs: torch.Tensor,
+     epsilon: float = 0.2,
+ ):
+     """Computes the Proximal Policy Optimization (PPO) loss using the clipped objective.
+ 
+     Parameters
+     ----------
+     advantages : `torch.Tensor`
+         The advantages (delta) for each action taken, calculated as the difference between returns and value predictions.
+     old_action_probs : `torch.Tensor`
+         The action probabilities from the previous policy (before the current update).
+     action_probs : `torch.Tensor`
+         The action probabilities from the current policy (after the update).
+     epsilon : `float`, optional, default=0.2
+         The clipping parameter that limits how much the policy can change between updates.
+ 
+     Returns
+     -------
+     loss : `torch.Tensor`
+         The PPO loss, averaged across the batch of samples. The loss is computed using the clipped objective to penalize large policy updates.
+     """
+     if advantages.dim() == 1:
+         advantages = advantages.unsqueeze(-1)
+ 
+     log_ratio = torch.log(action_probs + 1e-8) - torch.log(old_action_probs + 1e-8)
+     ratio = torch.exp(log_ratio)  # π(a|s) / π_old(a|s)
+ 
+     clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
+ 
+     loss = -torch.min(ratio * advantages, clipped_ratio * advantages)
+ 
+     return loss.mean()
+ 
+ 
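+ # Illustrative numbers: with an unchanged policy (ratio == 1) the clipped
+ # objective reduces to the negative mean advantage.
+ def _demo_ppo_loss() -> None:
+     advantages = torch.tensor([1.0, -0.5, 0.25])
+     old_probs = torch.tensor([0.4, 0.3, 0.3]).view(-1, 1)
+     new_probs = torch.tensor([0.4, 0.3, 0.3]).view(-1, 1)  # identical policy
+     print(ppo_loss(advantages, old_probs, new_probs))  # tensor(-0.2500)
+ 
+ 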
+ def train_option_critic(
+     model: torch.nn.Module,
+     optimizer: torch.optim.Optimizer,
+     env: Any,
+     num_epochs: int = 1_000,
+     batch_size: int = 32,
+     device: str = "cpu",
+     beta: float = 1e-2,
+     epsilon: float = 0.2,
+     patience: int = 15,
+     verbose: bool = False,
+     **kwargs,
+ ) -> tuple[torch.nn.Module, float, List[float]]:
+     """Trains an option critic model with the provided environment and hyperparameters.
+ 
+     Parameters
+     ----------
+     model : `nn.Module`
+         The neural network model to train.
+     optimizer : `torch.optim.Optimizer`
+         The optimizer for model updates.
+     env : `Any`
+         The environment for training.
+     num_epochs : `int`
+         Number of training epochs.
+     batch_size : `int`
+         Batch size per training step.
+     device : `str`
+         Target device (e.g., "cpu" or "cuda").
+     beta : `float`
+         Entropy term coefficient added to the PPO loss.
+     epsilon : `float`, optional, default=0.2
+         The clipping parameter that limits how much the policy can change between updates.
+     patience : `int`
+         Early stopping patience in epochs.
+     verbose : `bool`
+         If True, prints progress and early-stopping messages.
+ 
+     Returns
+     -------
+     model : `nn.Module`
+         Trained model.
+     avg_epoch_loss : `float`
+         Average loss of the last completed epoch.
+     advantages_per_epoch : `list[float]`
+         Average advantage recorded for each epoch.
+     """
+     losses = []
+     best_loss_so_far = float("inf")
+     best_advantage_so_far = 0.0
+     patience_counter = 0
+     patience_counter_advantage = 0
+     device = torch.device(device)  # honor the caller's choice of device
+     model.to(device)
+     avg_epoch_loss = float("inf")
+     advantages_per_epoch = []
+ 
+     for epoch in range(num_epochs):
+         trajectory, returns, advantages, old_probs = collect_experience(env, model, **kwargs)
+         avg_advantage = sum(advantages) / len(advantages)
+         advantages_per_epoch.append(avg_advantage)
+ 
+         states = torch.tensor(np.array([t[0] for t in trajectory]), dtype=torch.float32).to(device)
+         actions = torch.tensor([t[2] for t in trajectory], dtype=torch.long).to(device)
+         returns_tensor = torch.tensor(returns, dtype=torch.float32).to(device)
+         advantages_tensor = torch.tensor(advantages, dtype=torch.float32).to(device)
+         old_probs_tensor = torch.tensor(old_probs, dtype=torch.float32).view(-1, 1).to(device)
+ 
+         dataset = TensorDataset(
+             states, actions, returns_tensor, advantages_tensor, old_probs_tensor
+         )
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+ 
+         epoch_loss = 0
+         num_batches = 0
+ 
+         for (
+             batch_states,
+             batch_actions,
+             batch_returns,
+             batch_advantages,
+             batch_old_probs,
+         ) in dataloader:
+             optimizer.zero_grad()
+ 
+             option_probs, action_probs, termination_probs, selected_option, action = model(
+                 batch_states
+             )
+ 
+             batch_current_probs = action_probs.gather(1, batch_actions.unsqueeze(1))
+             ppo_loss_value = ppo_loss(
+                 batch_advantages, batch_old_probs, batch_current_probs, epsilon=epsilon
+             )
+ 
+             entropy = -torch.sum(action_probs * torch.log(action_probs + 1e-8), dim=-1)
+ 
+             loss = ppo_loss_value + beta * entropy.mean()
+             loss.backward()
+             optimizer.step()
+ 
+             epoch_loss += loss.item()
+             num_batches += 1
+ 
+         # Early stopping on the epoch-level average advantage.
+         if avg_advantage > best_advantage_so_far:
+             best_advantage_so_far = avg_advantage
+             patience_counter_advantage = 0
+         else:
+             patience_counter_advantage += 1
+             if patience_counter_advantage >= patience:
+                 if verbose:
+                     print(
+                         f"Early stopping at epoch {epoch} after {patience} epochs without advantage improvement."
+                     )
+                 break
+ 
+         if num_batches > 0:
+             avg_epoch_loss = epoch_loss / num_batches
+             losses.append(avg_epoch_loss)
+ 
+             if avg_epoch_loss < best_loss_so_far:
+                 best_loss_so_far = avg_epoch_loss
+                 patience_counter = 0
+             else:
+                 patience_counter += 1
+ 
+             if patience_counter >= patience:
+                 if verbose:
+                     print(
+                         f"Early stopping at epoch {epoch} after {patience} epochs without improvement."
+                     )
+                 break
+ 
+         if verbose and epoch % max(1, num_epochs // 10) == 0:
+             print(f"Epoch {epoch}/{num_epochs} - Avg Loss: {avg_epoch_loss:.4f}")
+ 
+     return model, avg_epoch_loss, advantages_per_epoch
+ 
+ 
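+ # Hedged wiring sketch, assuming a Gymnasium-style environment whose
+ # step(action) returns the 5-tuple unpacked in `collect_experience`, and the
+ # hypothetical `_TinyOptionCritic` interface sketched earlier.
+ def _demo_train_option_critic(env: Any) -> torch.nn.Module:
+     model = _TinyOptionCritic(state_dim=4, num_options=2, num_actions=2)
+     optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
+     trained, avg_loss, advantages = train_option_critic(
+         model, optimizer, env, num_epochs=100, batch_size=32, verbose=True
+     )
+     return trained
+ 
+ 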
+ def train_model_with_episodes(
+     model: torch.nn.Module,
+     optimizer: torch.optim.Optimizer,
+     env: Any,
+     num_episodes: int,
+     episode_patience: int = 5,
+     **kwargs,
+ ):
+     """Trains a model via reinforcement learning episodes.
+ 
+     Parameters
+     ----------
+     model : `torch.nn.Module`
+         The model to be trained.
+     optimizer : `torch.optim.Optimizer`
+         The optimizer for the model.
+     env : `Any`
+         The environment used for the episodes.
+     num_episodes : `int`
+         The number of episodes to train.
+     episode_patience : `int`, optional
+         Number of episodes without improvement before stopping early, default=5.
+ 
+     Keyword Arguments
+     -----------------
+     Additional keyword arguments passed to the `train_option_critic` function.
+ 
+     num_epochs : `int`
+         Number of training epochs.
+     batch_size : `int`
+         Batch size per training step.
+     gamma : `float`
+         Discount factor for future rewards.
+     device : `str`
+         Target device (e.g., "cpu" or "cuda").
+     beta : `float`
+         Entropy term coefficient added to the PPO loss.
+     patience : `int`
+         Early stopping patience in epochs.
+ 
+     Returns
+     -------
+     model : `torch.nn.Module`
+         The trained model.
+     best_loss_so_far : `float`
+         The best loss value observed during training.
+     """
+     # state_dict() returns references, so clone the tensors for a true snapshot.
+     previous_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}
+     best_loss_so_far = float("inf")
+     loss_window = []
+     average_loss = 0.0
+     no_improvement_count = 0
+ 
+     print(f"{'Episode':<10} {'Loss':<12} {'Best Loss':<15} {'Status':<25} {'Avg Loss':<12}")
+     print("=" * 70)
+ 
+     NEW_BEST_COLOR = "\033[92m"
+     REVERT_COLOR = "\033[91m"
+     RESET_COLOR = "\033[0m"
+     advantages_per_episode = []
+ 
+     for episode in range(num_episodes):
+         model, loss, advantages = train_option_critic(model, optimizer, env, **kwargs)
+         advantages_per_episode.extend(advantages)
+ 
+         loss_window.append(loss)
+         average_loss = sum(loss_window) / len(loss_window)
+ 
+         if loss < best_loss_so_far:
+             best_loss_so_far = loss
+             previous_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}
+             no_improvement_count = 0
+             status = f"{NEW_BEST_COLOR}Updated{RESET_COLOR}"
+         else:
+             # Roll back to the best snapshot when the episode did not improve.
+             model.load_state_dict(previous_weights)
+             no_improvement_count += 1
+             status = f"{REVERT_COLOR}No Improvement{RESET_COLOR}"
+         print(
+             f"{episode + 1:<10} {loss:<12.4f} {best_loss_so_far:<15.4f} {status:<25} {average_loss:<12.4f}"
+         )
+         print("=" * 70)
+ 
+         if no_improvement_count >= episode_patience:
+             print(f"\nNo improvement for {episode_patience} episodes. Stopping early.")
+             break
+ 
+     print(f"\nTraining complete. Final best loss: {best_loss_so_far:.4f}")
+     sns.set_theme(style="whitegrid")
+     # plt is provided by the star import from .figures.
+     plt.figure(figsize=(5, 3))
+     plt.plot(
+         range(len(advantages_per_episode)),
+         advantages_per_episode,
+         marker=None,
+         markersize=6,
+         color=sns.color_palette("deep")[0],
+         linestyle="-",
+         linewidth=2,
+     )
+     plt.xscale("log")
+ 
+     plt.xlabel("Epoch", fontsize=12)
+     plt.ylabel("Average Advantages", fontsize=12)
+     plt.grid(True, which="both", axis="both", linestyle="--", linewidth=0.5, alpha=0.6)
+     plt.tight_layout()
+     plt.show()
+ 
+     return model, best_loss_so_far
+ 
+ 
+ if __name__ == "__main__":
+     pass