pyotc 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. pyotc/__init__.py +5 -0
  2. pyotc/examples/__init__.py +0 -0
  3. pyotc/examples/edge_awareness.py +86 -0
  4. pyotc/examples/lollipops.py +54 -0
  5. pyotc/examples/stochastic_block_model.py +57 -0
  6. pyotc/examples/wheel.py +127 -0
  7. pyotc/otc.py +5 -0
  8. pyotc/otc_backend/__init__.py +0 -0
  9. pyotc/otc_backend/graph/__init__.py +3 -0
  10. pyotc/otc_backend/graph/utils.py +109 -0
  11. pyotc/otc_backend/optimal_transport/__init__.py +0 -0
  12. pyotc/otc_backend/optimal_transport/logsinkhorn.py +78 -0
  13. pyotc/otc_backend/optimal_transport/native.py +49 -0
  14. pyotc/otc_backend/optimal_transport/native_refactor.py +51 -0
  15. pyotc/otc_backend/optimal_transport/pot.py +38 -0
  16. pyotc/otc_backend/policy_iteration/__init__.py +0 -0
  17. pyotc/otc_backend/policy_iteration/dense/__init__.py +0 -0
  18. pyotc/otc_backend/policy_iteration/dense/approx_tce.py +42 -0
  19. pyotc/otc_backend/policy_iteration/dense/entropic.py +161 -0
  20. pyotc/otc_backend/policy_iteration/dense/entropic_tci.py +49 -0
  21. pyotc/otc_backend/policy_iteration/dense/exact.py +127 -0
  22. pyotc/otc_backend/policy_iteration/dense/exact_tce.py +56 -0
  23. pyotc/otc_backend/policy_iteration/dense/exact_tci_lp.py +65 -0
  24. pyotc/otc_backend/policy_iteration/dense/exact_tci_pot.py +90 -0
  25. pyotc/otc_backend/policy_iteration/sparse/__init__.py +0 -0
  26. pyotc/otc_backend/policy_iteration/sparse/exact.py +89 -0
  27. pyotc/otc_backend/policy_iteration/sparse/exact_tce.py +78 -0
  28. pyotc/otc_backend/policy_iteration/sparse/exact_tci.py +88 -0
  29. pyotc/otc_backend/policy_iteration/utils.py +112 -0
  30. pyotc-0.2.2.dist-info/METADATA +38 -0
  31. pyotc-0.2.2.dist-info/RECORD +34 -0
  32. pyotc-0.2.2.dist-info/WHEEL +4 -0
  33. pyotc-0.2.2.dist-info/licenses/AUTHORS.rst +12 -0
  34. pyotc-0.2.2.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,42 @@
1
+ import numpy as np
2
+
3
+
4
def approx_tce(P, c, L, T):
    """
    Approximate the Transition Coupling Evaluation (TCE) vectors g and h
    using a truncation-based approximation of the exact TCE method.

    The gain is obtained by repeatedly applying P to the cost vector and
    averaging the result; the bias is a truncated sum of the terms
    P^t (c - g), starting at t = 0.

    Args:
        P (np.ndarray): Transition matrix of shape (d, d), where d = dx*dy.
        c (np.ndarray): Cost vector with d entries; it is reshaped to (d, -1),
            i.e. to a (d, 1) column for a plain cost vector.
        L (int): Maximum number of additional applications of P when
            iterating the gain vector g.
        T (int): Maximum number of P^t (c - g) terms (t >= 1) added to the
            bias vector h.

    Returns:
        g (np.ndarray): Approximated average cost (gain) as a constant column
            of shape (d, 1).
        h (np.ndarray): Approximated bias vector, same shape as the reshaped
            cost (a (d, 1) column for a plain cost vector).
    """
    d = P.shape[0]
    c = np.reshape(c, (d, -1))
    # Both loops stop early once an update is negligible relative to the
    # largest cost entry.
    threshold = 1e-12 * np.max(c)

    g_prev = c
    g = P @ g_prev
    n_applied = 1
    while n_applied <= L and np.max(np.abs(g - g_prev)) > threshold:
        g_prev = g
        g = P @ g_prev
        n_applied += 1

    # Collapse the iterated cost to a single constant gain.
    g = np.mean(g) * np.ones((d, 1))

    # `c - g` allocates a new array, so the in-place updates below never
    # touch the caller's cost vector.
    h = c - g
    # `term` holds P^t (c - g); it is computed once per step and reused for
    # the stopping test, the accumulation, and the next power.
    term = P @ h
    n_terms = 1
    while n_terms <= T and np.max(np.abs(term)) > threshold:
        h += term
        term = P @ term
        n_terms += 1

    return g, h
@@ -0,0 +1,161 @@
1
+ """
2
+ Entropic Optimal Transition Coupling (OTC) solvers.
3
+
4
+ Implements variants of the OTC algorithm using entropic regularization.
5
+ Includes both a custom Sinkhorn implementation and one based on the POT library.
6
+
7
+ References:
8
+ - Section 5, "Optimal Transport for Stationary Markov Chains via Policy Iteration"
9
+ (https://www.jmlr.org/papers/volume23/21-0519/21-0519.pdf)
10
+
11
+ Methods:
12
+ - logsinkhorn: A self-implemented log-scaled Sinkhorn solver.
13
+ - ot_sinkhorn: Sinkhorn solver from POT library.
14
+ (reference: https://pythonot.github.io/gen_modules/ot.bregman.html#ot.bregman.sinkhorn)
15
+ - ot_logsinkhorn: Sinkhorn solver from POT library in log scale.
16
+ (reference: https://pythonot.github.io/gen_modules/ot.bregman.html#ot.bregman.sinkhorn_log)
17
+ - ot_greenkhorn: Sinkhorn solver of greedy version from POT library.
18
+ (reference: https://pythonot.github.io/gen_modules/ot.bregman.html#ot.bregman.greenkhorn)
19
+ """
20
+
21
+ import time
22
+ import numpy as np
23
+ import ot
24
+ from ..utils import get_best_stat_dist
25
+ from .approx_tce import approx_tce
26
+ from .entropic_tci import entropic_tci
27
+ from pyotc.otc_backend.optimal_transport.logsinkhorn import logsinkhorn
28
+
29
+
30
def entropic_otc(
    Px,
    Py,
    c,
    L=100,
    T=100,
    xi=0.1,
    method="logsinkhorn",
    sink_iter=100,
    reg_num=None,
    get_sd=False,
    silent=True,
):
    """
    Solves the Entropic Optimal Transition Coupling (OTC) problem between two Markov chains
    using approximate policy iteration and entropic regularization.

    This method alternates between approximate coupling evaluation (approx_tce)
    and entropic coupling improvement (entropic_tci), and stops as soon as the
    first entry of the gain vector no longer decreases by more than
    1e-5 * max(c) between consecutive iterations.

    Args:
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).
        c (np.ndarray): Cost function of shape (dx, dy).
        L (int): Number of iterations for computing the cost vector g in approx_tce.
        T (int): Number of iterations for computing the bias vector h in approx_tce.
        xi (float): Scaling factor for entropic cost adjustment in entropic_tci.
        method (str): Method for the Sinkhorn algorithm. Must choose from
            ['logsinkhorn', 'ot_sinkhorn', 'ot_logsinkhorn', 'ot_greenkhorn'].
            Default is 'logsinkhorn'.
        sink_iter (int): Number of iterations for 'logsinkhorn' method. Maximum number of
            Sinkhorn iterations (numItermax) for the methods from the POT library.
        reg_num (float): Entropic regularization term, required for (and only used by)
            the methods from the POT package.
        get_sd (bool): If True, compute the stationary distribution and expected cost
            with get_best_stat_dist.
        silent (bool): If False, print convergence info during iterations and running time.

    Returns:
        exp_cost (float): Expected transport cost under the returned transition coupling.
        P (np.ndarray): Transition coupling matrix of shape (dx*dy, dx*dy), clipped to be
            non-negative and row-normalized.
        stat_dist (Optional[np.ndarray]): Stationary distribution of the transition coupling
            of shape (dx, dy), or None if get_sd is False.

    Raises:
        ValueError: If `method` is unknown, or if a POT method is selected
            without `reg_num`.
    """
    start_time = time.time()
    if not silent:
        print(f"Starting entropic otc with {method} method...")

    dx, dy = Px.shape[0], Py.shape[0]
    max_c = np.max(c)
    tol = 1e-5 * max_c

    # Seed g_old/g so that the first convergence test passes (gap of 10*tol).
    g_old = max_c * np.ones(dx * dy)
    g = g_old - 10 * tol
    P = np.kron(Px, Py)

    # All POT-backed solvers need an explicit regularization strength.
    pot_methods = ("ot_sinkhorn", "ot_logsinkhorn", "ot_greenkhorn")
    if method in pot_methods and reg_num is None:
        raise ValueError(f"reg_num must be specified for '{method}'")

    if method == "logsinkhorn":

        def solver_fn(A, a, b):
            return logsinkhorn(A, a, b, sink_iter)
    elif method == "ot_sinkhorn":

        def solver_fn(A, a, b):
            return ot.sinkhorn(a, b, A, reg=reg_num, numItermax=sink_iter)
    elif method == "ot_logsinkhorn":

        def solver_fn(A, a, b):
            return ot.bregman.sinkhorn_log(a, b, A, reg=reg_num, numItermax=sink_iter)
    elif method == "ot_greenkhorn":

        def solver_fn(A, a, b):
            return ot.bregman.greenkhorn(a, b, A, reg=reg_num, numItermax=sink_iter)
    else:
        raise ValueError(f"Unknown method: {method}")

    iter_ctr = 0
    # g may be 1-D (initial seed) or a (d, 1) column (from approx_tce), so the
    # first entry is extracted through ravel to get a plain scalar comparison.
    while float(np.ravel(g_old)[0]) - float(np.ravel(g)[0]) > tol:
        iter_ctr += 1
        P_old = P
        g_old = g
        if not silent:
            print("Iteration:", iter_ctr)
            start_iter = time.time()

        # Approximate transition coupling evaluation
        if not silent:
            print("Computing entropic TCE...")
        g, h = approx_tce(P, c, L, T)

        # Entropic transition coupling improvement (passing solver function to entropic_tci)
        if not silent:
            print("Computing entropic TCI...")
        P = entropic_tci(h=h, P0=P_old, Px=Px, Py=Py, xi=xi, solver_fn=solver_fn)

        if not silent:
            iter_time = time.time() - start_iter
            elapsed = time.time() - start_time
            g0 = float(np.ravel(g)[0])
            g0_old = float(np.ravel(g_old)[0])
            diff = g0_old - g0
            ratio = diff / g0 if g0 != 0 else float("inf")
            print(
                f"[Iter {iter_ctr} taking {iter_time:.2f}s] Δg={diff:.3e}, g[0]={g0:.6f}, Δg/g[0]={ratio:.3e}, total elapsed={elapsed:.2f}s"
            )

    # In case of numerical instability, make non-negative and normalize.
    P = np.maximum(P, 0)
    row_sums = np.sum(P, axis=1, keepdims=True)
    P = P / np.where(row_sums > 0, row_sums, 1)

    if get_sd:
        if not silent:
            print(
                f"Convergence reached in {iter_ctr} iterations. Computing stationary distribution..."
            )
        stat_dist, exp_cost = get_best_stat_dist(P, c)
        stat_dist = np.reshape(stat_dist, (dx, dy))
    else:
        if not silent:
            print(
                f"Convergence reached in {iter_ctr} iterations. No stationary distribution computation requested."
            )
        stat_dist = None
        exp_cost = float(np.ravel(g)[0])

    if not silent:
        print(
            f"[entropic_otc] Finished. Total time elapsed: {time.time() - start_time:.3f} seconds."
        )

    return exp_cost, P, stat_dist
@@ -0,0 +1,49 @@
1
+ import numpy as np
2
+ import ot
3
+
4
+
5
def entropic_tci(h, P0, Px, Py, xi, solver_fn):
    """
    Performs entropic Transition Coupling Improvement (TCI).

    For each (i, j) state pair from the product space of two Markov chains, this function
    restricts the scaled bias matrix -xi * h to the supports of Px[i, :] and Py[j, :] and
    passes it, with the two restricted marginals, to `solver_fn`. The returned plan is
    embedded into a (dx, dy) matrix of zeros and written as row dy*i + j of the result.
    Rows for which either marginal has exactly one positive entry keep their value from P0.

    Args:
        h (np.ndarray): Bias vector with dx*dy entries.
        P0 (np.ndarray): Previous transition coupling matrix of shape (dx*dy, dx*dy).
            It is copied, not modified.
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).
        xi (float): Scaling factor for entropic cost adjustment.
        solver_fn (callable): Called as solver_fn(A, a, b) with the restricted matrix and
            marginals; must return a plan with the same shape as A.

    Returns:
        np.ndarray: Updated transition coupling matrix of shape (dx*dy, dx*dy).
    """
    dx, dy = Px.shape[0], Py.shape[0]
    P = P0.copy()
    K = -xi * np.reshape(h, (dx, dy))

    # The support of Py[j, :] does not depend on i, so compute it once per j.
    y_supports = [np.where(Py[j, :] > 0)[0] for j in range(dy)]

    for i in range(dx):
        dist_x = Px[i, :]
        x_idxs = np.where(dist_x > 0)[0]
        if len(x_idxs) == 1:
            # Deterministic source step: rows dy*i .. dy*i + dy - 1 stay as in P0
            # (P is already a copy of P0).
            continue
        sub_dist_x = dist_x[x_idxs]

        for j in range(dy):
            y_idxs = y_supports[j]
            if len(y_idxs) == 1:
                # Deterministic target step: row keeps its P0 value.
                continue

            block = np.ix_(x_idxs, y_idxs)
            sol = solver_fn(K[block], sub_dist_x, Py[j, y_idxs])

            sol_full = np.zeros((dx, dy))
            sol_full[block] = sol
            P[dy * i + j, :] = sol_full.flatten()

    return P
@@ -0,0 +1,127 @@
1
+ import numpy as np
2
+ import time
3
+
4
+ from .exact_tce import exact_tce
5
+ from .exact_tci_lp import exact_tci as exact_tci_lp
6
+ from .exact_tci_pot import exact_tci as exact_tci_pot
7
+ from ..utils import get_stat_dist
8
+
9
+
10
def exact_otc_lp(Px, Py, c, stat_dist="best"):
    """
    Runs exact policy iteration for the optimal transition coupling, using the
    LP-based improvement step (exact_tci_lp).

    Starting from the independent coupling kron(Px, Py), the loop alternates
    exact_tce and exact_tci_lp until the coupling changes by at most 1e-10
    (max absolute entry) between consecutive iterations. At least one
    evaluation/improvement round is always executed, so trivial inputs such as
    two single-state chains are handled.

    Args:
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).
        c (np.ndarray): Cost function of shape (dx, dy).
        stat_dist (str, optional): Method name forwarded to get_stat_dist, or None to
            skip the stationary distribution. Defaults to 'best'.

    Returns:
        exp_cost (float): First entry of the gain vector from the last evaluation.
        P (np.ndarray): Transition coupling matrix of shape (dx*dy, dx*dy).
        stat_dist (np.ndarray or None): Stationary distribution reshaped to (dx, dy),
            or None when `stat_dist` is None.

        Returns (None, None, None) if the iteration breaks down numerically
        (the change between consecutive couplings is NaN).
    """
    start = time.time()
    print("Starting exact_otc_dense...")

    dx, dy = Px.shape[0], Py.shape[0]
    P = np.kron(Px, Py)

    while True:
        P_old = np.copy(P)

        print("Computing exact TCE...")
        g, h = exact_tce(P, c)

        print("Computing exact TCI...")
        P = exact_tci_lp(g, h, P_old, Px, Py)

        # Check for convergence.
        change = np.max(np.abs(P - P_old))
        if np.isnan(change):
            # NaNs in the coupling: report failure instead of looping forever.
            return None, None, None
        if change <= 1e-10:
            break

    if stat_dist is None:
        print("Convergence reached. No stationary distribution computation requested.")
        stat_dist_mat = None
    else:
        print("Convergence reached. Computing stationary distribution...")
        stat_dist_mat = np.reshape(get_stat_dist(P, method=stat_dist, c=c), (dx, dy))

    exp_cost = float(g[0].item())
    end = time.time()
    print(f"[exact_otc] Finished. Total time elapsed: {end - start:.3f} seconds.")
    return exp_cost, P, stat_dist_mat
52
+
53
+
54
def exact_otc(Px, Py, c, stat_dist="best"):
    """
    Computes the optimal transport coupling (OTC) between two stationary Markov chains represented by transition matrices Px and Py,
    as described in Algorithm 1 of the paper: "Optimal Transport for Stationary Markov Chains via Policy Iteration"
    (https://www.jmlr.org/papers/volume23/21-0519/21-0519.pdf).

    The algorithm iteratively updates the transition coupling matrix until convergence by alternating
    between Transition Coupling Evaluation (TCE) and Transition Coupling Improvement (TCI) steps.
    Convergence means the coupling changes by at most 1e-10 (max absolute entry) between
    consecutive iterations. At least one TCE/TCI round is always executed, so trivial inputs
    such as two single-state chains are handled.

    For a detailed discussion of the connection between the OTC problem and Markov Decision Processes (MDPs), see Section 4 of the paper.
    Additional background on policy iteration methods for solving average-cost MDP problems can be found in Chapters 8 and 9 of
    "Markov Decision Processes: Discrete Stochastic Dynamic Programming" by Martin L. Puterman.

    Args:
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).
        c (np.ndarray): Cost function of shape (dx, dy).
        stat_dist (str, optional): Method to compute the stationary distribution.
            Options include 'best', 'eigen', 'iterative' and None. Defaults to 'best'.

    Returns:
        exp_cost (float): Expected transport cost under the optimal transition coupling.
        R (np.ndarray): Optimal transition coupling matrix of shape (dx*dy, dx*dy).
        stat_dist (np.ndarray): Stationary distribution of the optimal transition coupling of shape (dx, dy),
            or None when `stat_dist` is None.

        Returns (None, None, None) if the iteration breaks down numerically
        (the change between consecutive couplings is NaN).
    """
    start = time.time()
    print("Starting exact_otc_dense...")

    dx, dy = Px.shape[0], Py.shape[0]
    R = np.kron(Px, Py)
    iteration = 0

    while True:
        print("Iteration:", iteration)
        R_old = np.copy(R)

        print("Computing exact TCE...")
        g, h = exact_tce(R, c)

        print("Computing exact TCI...")
        R = exact_tci_pot(g, h, R_old, Px, Py)

        # Check if the transition coupling matrix has converged
        change = np.max(np.abs(R - R_old))
        if np.isnan(change):
            # NaNs in the coupling: report failure instead of looping forever.
            return None, None, None
        if change <= 1e-10:
            break

        iteration += 1

    if stat_dist is None:
        print(
            f"Convergence reached in {iteration + 1} iterations. No stationary distribution computation requested."
        )
        stat_dist_mat = None
    else:
        print(
            f"Convergence reached in {iteration + 1} iterations. Computing stationary distribution..."
        )
        stat_dist_mat = np.reshape(get_stat_dist(R, method=stat_dist, c=c), (dx, dy))

    exp_cost = float(g[0].item())
    end = time.time()
    print(f"[exact_otc] Finished. Total time elapsed: {end - start:.3f} seconds.")
    return exp_cost, R, stat_dist_mat
@@ -0,0 +1,56 @@
1
+ """
2
+ Original Transition Coupling Evaluation (TCE) methods from:
3
+ https://www.jmlr.org/papers/volume23/21-0519/21-0519.pdf
4
+ """
5
+
6
+ import numpy as np
7
+ from numpy.linalg import pinv
8
+
9
+
10
def exact_tce(R, c):
    """
    Computes the exact Transition Coupling Evaluation (TCE) vectors g and h
    using the linear system described in Algorithm 1a of the paper
    "Optimal Transport for Stationary Markov Chains via Policy Iteration"
    (https://www.jmlr.org/papers/volume23/21-0519/21-0519.pdf).

    The method solves the block system

        (I - R) g             = 0
              g + (I - R) h   = c
                  h + (I - R) w = 0

    for (g, h, w) and returns g and h. If `np.linalg.solve` reports a singular
    matrix, the pseudo-inverse (`np.linalg.pinv`) is used as fallback.

    Args:
        R (np.ndarray): Transition matrix of shape (d, d), where d = dx*dy.
        c (np.ndarray): Cost with d entries in any shape that reshapes to a
            (d, 1) column (e.g. (dx, dy) or (dx*dy,)).

    Returns:
        g (np.ndarray): Average cost (gain) vector of shape (d,).
        h (np.ndarray): Total extra cost (bias) vector of shape (d,).

    Notes:
        - Only `np.linalg.LinAlgError` triggers the pseudo-inverse fallback;
          any other exception (shape errors, interrupts) propagates.
        - The pseudo-inverse path may be numerically less stable.
        - Make sure R is a proper stochastic matrix (rows sum to 1).
    """
    d = R.shape[0]
    c = np.reshape(c, (d, -1))

    identity = np.eye(d)
    zero_block = np.zeros((d, d))
    i_minus_r = identity - R

    # Construct the block matrix A and right-hand side vector b
    A = np.block(
        [
            [i_minus_r, zero_block, zero_block],
            [identity, i_minus_r, zero_block],
            [zero_block, identity, i_minus_r],
        ]
    )
    b = np.concatenate([np.zeros((d, 1)), c, np.zeros((d, 1))])

    # Solve the linear system Ax = b
    try:
        sol = np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        sol = np.linalg.pinv(A) @ b

    # Extract g and h from the solution
    g = sol[0:d].flatten()
    h = sol[d : 2 * d].flatten()

    return g, h
@@ -0,0 +1,65 @@
1
+ """
2
+ Original Transition Coupling Improvements (TCI) methods from:
3
+ https://jmlr.csail.mit.edu/papers/volume23/21-0519/21-0519.pdf
4
+
5
+ Use scipy.linprog (LP solver) library to solve optimal transport problem.
6
+ """
7
+
8
+ import numpy as np
9
+ import copy
10
+ from pyotc.otc_backend.optimal_transport.native import computeot_lp
11
+
12
+
13
def check_constant(f, Px, threshold=1e-3):
    """
    Reports whether the vector f is constant up to `threshold`.

    f counts as constant when the spread between its largest and smallest
    entry is at most `threshold`, which is equivalent to every pair of entries
    differing by at most `threshold`. All entries of f are inspected, matching
    the `max(g) - min(g)` test used by the POT-based improvement step.

    Args:
        f (np.ndarray): Vector to test (e.g. the gain vector with dx*dy entries).
        Px (np.ndarray): Unused; kept so existing callers that pass it keep working.
        threshold (float): Largest spread still treated as constant.

    Returns:
        bool: True if f is constant within `threshold` (or empty), else False.
    """
    values = np.ravel(f)
    if values.size == 0:
        # Nothing to compare: vacuously constant.
        return True
    return bool(np.max(values) - np.min(values) <= threshold)
24
+
25
+
26
def setup_ot(f, Px, Py, Pz):
    """
    Fills every row of the transition coupling Pz with a coupling of the
    corresponding rows of Px and Py, using f as the transport cost.

    For each state pair (x_row, y_row), row dy * x_row + y_row of Pz is set to
    the flattened (dx, dy) plan between Px[x_row, :] and Py[y_row, :]. If either
    distribution contains an entry equal to 1, the outer product is used
    directly; otherwise the plan comes from computeot_lp.

    Args:
        f (np.ndarray): Cost vector with dx*dy entries (reshaped to (dx, dy)).
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).
        Pz (np.ndarray): Matrix of shape (dx*dy, dx*dy); overwritten in place.

    Returns:
        np.ndarray: The same Pz object, with all rows rewritten.
    """
    dx, dy = Px.shape[0], Py.shape[0]
    f_mat = np.reshape(f, (dx, dy))

    # Degeneracy of a Py row does not depend on x_row, so test it once per row.
    y_degenerate = [bool(np.any(Py[y_row, :] == 1)) for y_row in range(dy)]

    for x_row in range(dx):
        dist_x = Px[x_row, :]
        x_degenerate = bool(np.any(dist_x == 1))
        for y_row in range(dy):
            dist_y = Py[y_row, :]
            # Check if either distribution is degenerate.
            if x_degenerate or y_degenerate[y_row]:
                sol = np.outer(dist_x, dist_y)
            # If not degenerate, proceed with OT.
            else:
                sol, _ = computeot_lp(f_mat, dist_x, dist_y)
            Pz[dy * x_row + y_row, :] = np.reshape(sol, (-1, dx * dy))
    return Pz
43
+
44
+
45
def exact_tci(g, h, P0, Px, Py):
    """
    Transition Coupling Improvement (TCI) step backed by the LP solver.

    The coupling is first improved against the gain vector g, unless
    check_constant reports g as constant. If that changes P0 @ g by more than
    1e-7 in some entry, the improved coupling is returned immediately.
    Otherwise the coupling is improved against the bias vector h; if that
    changes P0 @ h by at most 1e-4 in every entry, a copy of P0 is returned.

    Args:
        g (np.ndarray): Gain vector from TCE with dx*dy entries.
        h (np.ndarray): Bias vector from TCE with dx*dy entries.
        P0 (np.ndarray): Current transition coupling of shape (dx*dy, dx*dy).
        Px (np.ndarray): Source transition matrix of shape (dx, dx).
        Py (np.ndarray): Target transition matrix of shape (dy, dy).

    Returns:
        np.ndarray: Improved coupling, or a copy of P0 when no improvement was found.
    """
    n_pairs = Px.shape[0] * Py.shape[0]
    Pz = np.zeros((n_pairs, n_pairs))

    # Stage 1: improve against g, but only when g is not (numerically) constant.
    if not check_constant(f=g, Px=Px):
        Pz = setup_ot(f=g, Px=Px, Py=Py, Pz=Pz)
        gain_shift = np.max(np.abs(P0 @ g - Pz @ g))
        if gain_shift <= 1e-7:
            # No real progress on g; restart from the current coupling.
            Pz = copy.deepcopy(P0)
        else:
            return Pz

    # Stage 2: improve against h.
    Pz = setup_ot(f=h, Px=Px, Py=Py, Pz=Pz)
    bias_shift = np.max(np.abs(P0 @ h - Pz @ h))
    if bias_shift <= 1e-4:
        Pz = copy.deepcopy(P0)

    return Pz
@@ -0,0 +1,90 @@
1
+ """
2
+ Original Transition Coupling Improvements (TCI) method from:
3
+ https://www.jmlr.org/papers/volume23/21-0519/21-0519.pdf
4
+
5
+ Use the python optimal transport (POT) library to solve optimal transport problem.
6
+ """
7
+
8
+ import numpy as np
9
+ import copy
10
+ from pyotc.otc_backend.optimal_transport.pot import computeot_pot
11
+
12
+
13
def setup_ot(f, Px, Py, R):
    """
    This improvement step updates the transition coupling matrix R that minimizes the product Rf element-wise.
    In more detail, we may select a transition coupling R such that for each state pair (x, y),
    the corresponding row r = R((x, y), ·) minimizes rf over couplings r in Pi(Px(x, ·), Py(y, ·)).
    This is done by solving the optimal transport problem for each state pair (x, y) in the source
    and target Markov chains. The resulting transition coupling matrix R is updated accordingly.

    This function uses computeot_pot to solve the optimal transport problem for each (x, y)
    state pair. When either Px(x, ·) or Py(y, ·) contains an entry equal to 1, the outer
    product of the two distributions is used directly and no solver call is made.

    Args:
        f (np.ndarray): Cost function with dx*dy entries (reshaped to (dx, dy)).
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).
        R (np.ndarray): Transition coupling matrix to update of shape (dx*dy, dx*dy);
            overwritten in place.

    Returns:
        R (np.ndarray): The same R object, with every row rewritten.
    """
    dx, dy = Px.shape[0], Py.shape[0]
    f_mat = np.reshape(f, (dx, dy))

    # Degeneracy of a Py row does not depend on x_row, so test it once per row.
    y_degenerate = [bool(np.any(Py[y_row, :] == 1)) for y_row in range(dy)]

    for x_row in range(dx):
        dist_x = Px[x_row, :]
        x_degenerate = bool(np.any(dist_x == 1))

        for y_row in range(dy):
            dist_y = Py[y_row, :]

            # Check if either distribution is degenerate.
            if x_degenerate or y_degenerate[y_row]:
                sol = np.outer(dist_x, dist_y)
            # If not degenerate, proceed with OT.
            else:
                sol, _ = computeot_pot(f_mat, dist_x, dist_y)
            R[dy * x_row + y_row, :] = np.reshape(sol, (-1, dx * dy))

    return R
52
+
53
+
54
def exact_tci(g, h, R0, Px, Py):
    """
    Transition Coupling Improvement (TCI) step of the OTC algorithm.

    Tries to improve the current transition coupling R0 using the gain vector g
    and the bias vector h produced by Transition Coupling Evaluation (TCE).
    The coupling is improved against g first, unless max(g) - min(g) <= 1e-3.
    If that changes R0 @ g by more than 1e-7 in some entry, the result is
    returned at once. Otherwise the coupling is improved against h; if that
    changes R0 @ h by at most 1e-4 in every entry, a copy of R0 is returned.

    Args:
        g (np.ndarray): Gain vector from TCE of shape (dx*dy,).
        h (np.ndarray): Bias vector from TCE of shape (dx*dy,).
        R0 (np.ndarray): Current transition coupling matrix of shape (dx*dy, dx*dy).
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).

    Returns:
        R (np.ndarray): Improved transition coupling matrix of shape (dx*dy, dx*dy).
    """
    n_pairs = Px.shape[0] * Py.shape[0]
    R = np.zeros((n_pairs, n_pairs))

    # Stage 1: improve against g, but only when g is not (numerically) constant.
    gain_is_constant = np.max(g) - np.min(g) <= 1e-3
    if not gain_is_constant:
        R = setup_ot(g, Px, Py, R)
        gain_shift = np.max(np.abs(R0 @ g - R @ g))
        if gain_shift <= 1e-7:
            # No real progress on g; restart from the current coupling.
            R = copy.deepcopy(R0)
        else:
            return R

    # Stage 2: improve against h.
    R = setup_ot(h, Px, Py, R)
    bias_shift = np.max(np.abs(R0 @ h - R @ h))
    if bias_shift <= 1e-4:
        R = copy.deepcopy(R0)

    return R
File without changes
@@ -0,0 +1,89 @@
1
+ import numpy as np
2
+ import scipy.sparse as sp
3
+ import time
4
+
5
+ from .exact_tce import exact_tce
6
+ from .exact_tci import exact_tci
7
+ from ..utils import get_stat_dist
8
+
9
+
10
def exact_otc(Px, Py, c, stat_dist="best", max_iter=100):
    """
    Computes the optimal transport coupling (OTC) between two stationary Markov chains represented by transition matrices Px and Py,
    as described in Algorithm 1 of the paper: "Optimal Transport for Stationary Markov Chains via Policy Iteration"
    (https://www.jmlr.org/papers/volume23/21-0519/21-0519.pdf).

    The algorithm iteratively updates the transition coupling matrix until convergence by alternating
    between Transition Coupling Evaluation (TCE) and Transition Coupling Improvement (TCI) steps.

    For a detailed discussion of the connection between the OTC problem and Markov Decision Processes (MDPs), see Section 4 of the paper.
    Additional background on policy iteration methods for solving average-cost MDP problems can be found in Chapters 8 and 9 of
    "Markov Decision Processes: Discrete Stochastic Dynamic Programming" by Martin L. Puterman.

    Note:
        In the TCE step (implemented in exact_tce), we solve a block linear system using functions from scipy.sparse.linalg.
        However, when A in Ax = b is nearly singular, we have observed a few cases where both SciPy solvers (scipy.sparse.linalg.spsolve, scipy.sparse.linalg.lsmr)
        can produce results that differ from NumPy's solver (np.linalg.solve). This leads to discrepancies with the dense implementation and non-convergence.
        This is an issue with SciPy's sparse solvers and remains unresolved. The best approach in such cases is to use the dense implementation.

    Args:
        Px (np.ndarray): Transition matrix of the source Markov chain of shape (dx, dx).
        Py (np.ndarray): Transition matrix of the target Markov chain of shape (dy, dy).
        c (np.ndarray): Cost function of shape (dx, dy).
        stat_dist (str, optional): Method to compute the stationary distribution.
            Options include 'best', 'eigen', 'iterative' and None. Defaults to 'best'.
        max_iter (int, optional): Maximum number of iterations for the convergence process. Defaults to 100.

    Returns:
        exp_cost (float): Expected transport cost under the optimal transition coupling.
        R (scipy.sparse.csr_matrix): Optimal transition coupling matrix of shape (dx*dy, dx*dy).
        stat_dist (np.ndarray): Stationary distribution of the optimal transition coupling of shape (dx, dy),
            or None when `stat_dist` is None.

        If convergence is not reached within max_iter iterations, returns (None, None, None).
    """
    start = time.time()
    print("Starting exact_otc_sparse...")
    dx, dy = Px.shape[0], Py.shape[0]

    # Initial coupling matrix using Kronecker product
    R = sp.kron(sp.csr_matrix(Px), sp.csr_matrix(Py), format="csr")

    for iteration in range(max_iter):
        print("Iteration:", iteration)
        R_old = R.copy()

        print("Computing exact TCE...")
        g, h = exact_tce(R, c)

        print("Computing exact TCI...")
        R = exact_tci(g, h, R_old, Px, Py)

        # Check if the transition coupling matrix has converged
        if (R != R_old).nnz != 0:
            continue

        if stat_dist is None:
            print(
                f"Convergence reached in {iteration + 1} iterations. No stationary distribution computation requested."
            )
            stat_dist_mat = None
        else:
            print(
                f"Convergence reached in {iteration + 1} iterations. Computing stationary distribution..."
            )
            stat_dist_mat = np.reshape(
                get_stat_dist(R, method=stat_dist, c=c), (dx, dy)
            )

        exp_cost = float(g[0].item())
        end = time.time()
        print(f"[exact_otc] Finished. Total time elapsed: {end - start:.3f} seconds.")
        return exp_cost, R, stat_dist_mat

    # Return None if convergence is not achieved. Report the iteration budget
    # itself: the loop variable is one less than the number of iterations run
    # and is undefined when max_iter is 0.
    print(f"Convergence not achieved after {max_iter} iterations. Returning None.")
    return None, None, None