pyotc 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. pyotc/__init__.py +5 -0
  2. pyotc/examples/__init__.py +0 -0
  3. pyotc/examples/edge_awareness.py +86 -0
  4. pyotc/examples/lollipops.py +54 -0
  5. pyotc/examples/stochastic_block_model.py +57 -0
  6. pyotc/examples/wheel.py +127 -0
  7. pyotc/otc.py +5 -0
  8. pyotc/otc_backend/__init__.py +0 -0
  9. pyotc/otc_backend/graph/__init__.py +3 -0
  10. pyotc/otc_backend/graph/utils.py +109 -0
  11. pyotc/otc_backend/optimal_transport/__init__.py +0 -0
  12. pyotc/otc_backend/optimal_transport/logsinkhorn.py +78 -0
  13. pyotc/otc_backend/optimal_transport/native.py +49 -0
  14. pyotc/otc_backend/optimal_transport/native_refactor.py +51 -0
  15. pyotc/otc_backend/optimal_transport/pot.py +38 -0
  16. pyotc/otc_backend/policy_iteration/__init__.py +0 -0
  17. pyotc/otc_backend/policy_iteration/dense/__init__.py +0 -0
  18. pyotc/otc_backend/policy_iteration/dense/approx_tce.py +42 -0
  19. pyotc/otc_backend/policy_iteration/dense/entropic.py +161 -0
  20. pyotc/otc_backend/policy_iteration/dense/entropic_tci.py +49 -0
  21. pyotc/otc_backend/policy_iteration/dense/exact.py +127 -0
  22. pyotc/otc_backend/policy_iteration/dense/exact_tce.py +56 -0
  23. pyotc/otc_backend/policy_iteration/dense/exact_tci_lp.py +65 -0
  24. pyotc/otc_backend/policy_iteration/dense/exact_tci_pot.py +90 -0
  25. pyotc/otc_backend/policy_iteration/sparse/__init__.py +0 -0
  26. pyotc/otc_backend/policy_iteration/sparse/exact.py +89 -0
  27. pyotc/otc_backend/policy_iteration/sparse/exact_tce.py +78 -0
  28. pyotc/otc_backend/policy_iteration/sparse/exact_tci.py +88 -0
  29. pyotc/otc_backend/policy_iteration/utils.py +112 -0
  30. pyotc-0.2.2.dist-info/METADATA +38 -0
  31. pyotc-0.2.2.dist-info/RECORD +34 -0
  32. pyotc-0.2.2.dist-info/WHEEL +4 -0
  33. pyotc-0.2.2.dist-info/licenses/AUTHORS.rst +12 -0
  34. pyotc-0.2.2.dist-info/licenses/LICENSE +22 -0
pyotc/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Top-level package for pyotc."""
2
+
3
+ __author__ = """Jay Hineman"""
4
+ __email__ = "jay.hineman@gmail.com"
5
+ __version__ = "0.1.0"
File without changes
@@ -0,0 +1,86 @@
1
+ """
2
+ networkx graphs for edge awareness example and corresponding costs from Table 1
3
+
4
+ [Alignment and Comparison of Directed Networks via Transition Couplings of Random Walks](https://arxiv.org/abs/2106.07106)
5
+
6
+ Figure 4 graphs
7
+
8
+ * G_1 is the regular octagon
9
+ * G_2 is the regular octagon removing 1 edge
10
+ * G_3 is uniform edge lengths of the octagon removing 1 edge
11
+ """
12
+
13
+ import numpy as np
14
+ import networkx as nx
15
+
16
+ # Define graphs G1, G2, G3
17
+ edge_awareness_1 = {
18
+ "nodes": [{"id": i} for i in range(1, 9)],
19
+ "edges": [
20
+ {"source": 1, "target": 2},
21
+ {"source": 2, "target": 3},
22
+ {"source": 3, "target": 4},
23
+ {"source": 4, "target": 5},
24
+ {"source": 5, "target": 6},
25
+ {"source": 6, "target": 7},
26
+ {"source": 7, "target": 8},
27
+ {"source": 8, "target": 1},
28
+ ],
29
+ "name": "edge awareness graph 1",
30
+ }
31
+
32
+ edge_awareness_2_3 = {
33
+ "nodes": [{"id": i} for i in range(1, 9)],
34
+ "edges": [
35
+ {"source": 1, "target": 2},
36
+ {"source": 2, "target": 3},
37
+ {"source": 3, "target": 4},
38
+ {"source": 4, "target": 5},
39
+ {"source": 5, "target": 6},
40
+ {"source": 6, "target": 7},
41
+ {"source": 7, "target": 8},
42
+ ],
43
+ "name": "edge awareness graph 2, 3",
44
+ }
45
+
46
+ graph_1 = nx.node_link_graph(data=edge_awareness_1, edges="edges")
47
+ graph_2 = nx.node_link_graph(data=edge_awareness_2_3, edges="edges")
48
+ graph_3 = nx.node_link_graph(data=edge_awareness_2_3, edges="edges")
49
+
50
+ # Define the coordinates of G_1, G_2, G_3
51
+ # All vertices are located on the unit circle in R^2
52
+ # d1: coordinate of G_1 vertices (regular octagon)
53
+ # d2: coordinate of G_2 vertices (regular octagon)
54
+ # d3: coordinate of G_3 vertices (the vertices are uniformly distributed in the left semicircle)
55
+
56
+ d1 = np.zeros((8, 2))
57
+ for i in range(8):
58
+ d1[i, 0] = np.cos(np.pi / 8 + np.pi / 4 * i)
59
+ d1[i, 1] = np.sin(np.pi / 8 + np.pi / 4 * i)
60
+
61
+ d2 = d1.copy()
62
+
63
+ d3 = np.zeros((8, 2))
64
+ for i in range(8):
65
+ d3[i, 0] = np.cos(np.pi / 2 + np.pi / 7 * i)
66
+ d3[i, 1] = np.sin(np.pi / 2 + np.pi / 7 * i)
67
+
68
+ # Get cost matrices
69
+ # Define a cost function equal to the squared Euclidean distance between vertex positions
70
+ # c21: cost function between G_2 and G_1
71
+ # c23: cost function between G_2 and G_3
72
+
73
+
74
+ def euclidean_cost(v1, v2):
75
+ n1 = v1.shape[0]
76
+ n2 = v2.shape[0]
77
+ c = np.zeros((n1, n2))
78
+ for i in range(n1):
79
+ for j in range(n2):
80
+ c[i, j] = np.sum((v1[i, :] - v2[j, :]) ** 2)
81
+
82
+ return c
83
+
84
+
85
+ c21 = euclidean_cost(d2, d1)
86
+ c23 = euclidean_cost(d2, d3)
@@ -0,0 +1,54 @@
1
+ """
2
+ networkx graphs for lollipop examples given in
3
+
4
+ [Alignment and Comparison of Directed Networks via Transition Couplings of Random Walks](https://arxiv.org/abs/2106.07106)
5
+
6
+ Figure 5 graphs
7
+
8
+ * lollipop_1 is the left graph
9
+ * lollipop_2 is the right graph
10
+ """
11
+
12
+ import networkx as nx
13
+
14
+ # Define graphs
15
+ left_lollipop_graph = {
16
+ "nodes": [{"id": i} for i in range(1, 13)],
17
+ "edges": [
18
+ {"source": 1, "target": 2},
19
+ {"source": 2, "target": 3},
20
+ {"source": 3, "target": 4},
21
+ {"source": 4, "target": 5},
22
+ {"source": 5, "target": 6},
23
+ {"source": 6, "target": 7},
24
+ {"source": 7, "target": 8},
25
+ {"source": 8, "target": 9},
26
+ {"source": 9, "target": 10},
27
+ {"source": 10, "target": 11},
28
+ {"source": 11, "target": 12},
29
+ {"source": 12, "target": 4},
30
+ ],
31
+ "name": "left lollipop graph",
32
+ }
33
+
34
+ right_lollipop_graph = {
35
+ "nodes": [{"id": i} for i in range(1, 13)],
36
+ "edges": [
37
+ {"source": 7, "target": 9},
38
+ {"source": 9, "target": 4},
39
+ {"source": 4, "target": 6},
40
+ {"source": 6, "target": 1},
41
+ {"source": 1, "target": 2},
42
+ {"source": 2, "target": 11},
43
+ {"source": 11, "target": 8},
44
+ {"source": 8, "target": 10},
45
+ {"source": 10, "target": 5},
46
+ {"source": 5, "target": 3},
47
+ {"source": 3, "target": 12},
48
+ {"source": 12, "target": 6},
49
+ ],
50
+ "name": "right lollipop graph",
51
+ }
52
+
53
+ lollipop_1 = nx.node_link_graph(data=left_lollipop_graph, edges="edges")
54
+ lollipop_2 = nx.node_link_graph(data=right_lollipop_graph, edges="edges")
@@ -0,0 +1,57 @@
1
+ import numpy as np
2
+
3
+
4
+ def stochastic_block_model(sizes: tuple, probs: np.ndarray) -> np.ndarray:
5
+ """Generate the adjacency for a stochastic block model SBM from a tuple (length n)
6
+ of sizes and an (nxn) matrix of probabilities.
7
+
8
+ Args:
9
+ sizes (tuple): tuple of node sizes with length of number of blocks
10
+ probs (np.ndarray): nxn symmetric matrix
11
+
12
+ Raises:
13
+ ValueError: If probs is not a square numpy array
14
+ ValueError: If probs is not symmetric
15
+ ValueError: If sizes and probs dimensions do not match
16
+
17
+ Returns:
18
+ np.ndarray: adjacency matrix for SBM
19
+ """
20
+ # Check input type
21
+ if not isinstance(probs, np.ndarray) or probs.shape[0] != probs.shape[1]:
22
+ raise ValueError("'probs' must be a square numpy array.")
23
+ elif not np.allclose(probs, probs.T):
24
+ raise ValueError("'probs' must be a symmetric matrix.")
25
+ elif len(sizes) != probs.shape[0]:
26
+ raise ValueError("'sizes' and 'probs' dimensions do not match.")
27
+
28
+ n = sum(sizes) # Total number of nodes
29
+ n_b = len(sizes) # Total number of blocks
30
+ A = np.zeros((n, n))
31
+
32
+ # Column index of each block's start
33
+ cumsum = 0
34
+ start = [0]
35
+ for size in sizes:
36
+ cumsum += size
37
+ start.append(cumsum)
38
+
39
+ # Generating Adjacency Matrix (upper)
40
+ # Generate diagonal blocks
41
+ for i in range(n_b):
42
+ p = probs[i, i]
43
+ for j in range(start[i], start[i + 1]):
44
+ for k in range(j + 1, start[i + 1]):
45
+ A[j, k] = np.random.choice([0, 1], p=[1 - p, p])
46
+
47
+ # Generate Nondiagonal blocks
48
+ for i in range(n_b - 1):
49
+ for j in range(i + 1, n_b):
50
+ A[start[i] : start[i + 1], start[j] : start[j + 1]] = np.random.choice(
51
+ [0, 1], size=(sizes[i], sizes[j]), p=[1 - probs[i, j], probs[i, j]]
52
+ )
53
+
54
+ # Fill lower triangular matrix
55
+ A = A + A.T
56
+
57
+ return A
@@ -0,0 +1,127 @@
1
+ """
2
+ networkx graphs for wheel graph examples given in Version 1 of
3
+
4
+ [Alignment and Comparison of Directed Networks via Transition Couplings of Random Walks](https://arxiv.org/pdf/2106.07106v1.pdf)
5
+
6
+ Figure 2 graphs
7
+
8
+ * G_1 is a wheel graph of order 16
9
+ * G_2 is a wheel graph removing 1 spoke edge
10
+ * G_3 is a wheel graph removing 1 wheel edge
11
+ """
12
+
13
+ import networkx as nx
14
+
15
+ # Define graphs G1, G2, G3
16
+ wheel_graph_1 = {
17
+ "nodes": [{"id": i} for i in range(1, 17)],
18
+ "edges": [
19
+ {"source": 1, "target": 2},
20
+ {"source": 1, "target": 3},
21
+ {"source": 1, "target": 4},
22
+ {"source": 1, "target": 5},
23
+ {"source": 1, "target": 6},
24
+ {"source": 1, "target": 7},
25
+ {"source": 1, "target": 8},
26
+ {"source": 1, "target": 9},
27
+ {"source": 1, "target": 10},
28
+ {"source": 1, "target": 11},
29
+ {"source": 1, "target": 12},
30
+ {"source": 1, "target": 13},
31
+ {"source": 1, "target": 14},
32
+ {"source": 1, "target": 15},
33
+ {"source": 1, "target": 16},
34
+ {"source": 2, "target": 3},
35
+ {"source": 3, "target": 4},
36
+ {"source": 4, "target": 5},
37
+ {"source": 5, "target": 6},
38
+ {"source": 6, "target": 7},
39
+ {"source": 7, "target": 8},
40
+ {"source": 8, "target": 9},
41
+ {"source": 9, "target": 10},
42
+ {"source": 10, "target": 11},
43
+ {"source": 11, "target": 12},
44
+ {"source": 12, "target": 13},
45
+ {"source": 13, "target": 14},
46
+ {"source": 14, "target": 15},
47
+ {"source": 15, "target": 16},
48
+ {"source": 16, "target": 2},
49
+ ],
50
+ "name": "wheel graph 1",
51
+ }
52
+
53
+ wheel_graph_2 = {
54
+ "nodes": [{"id": i} for i in range(1, 17)],
55
+ "edges": [
56
+ {"source": 1, "target": 3},
57
+ {"source": 1, "target": 4},
58
+ {"source": 1, "target": 5},
59
+ {"source": 1, "target": 6},
60
+ {"source": 1, "target": 7},
61
+ {"source": 1, "target": 8},
62
+ {"source": 1, "target": 9},
63
+ {"source": 1, "target": 10},
64
+ {"source": 1, "target": 11},
65
+ {"source": 1, "target": 12},
66
+ {"source": 1, "target": 13},
67
+ {"source": 1, "target": 14},
68
+ {"source": 1, "target": 15},
69
+ {"source": 1, "target": 16},
70
+ {"source": 2, "target": 3},
71
+ {"source": 3, "target": 4},
72
+ {"source": 4, "target": 5},
73
+ {"source": 5, "target": 6},
74
+ {"source": 6, "target": 7},
75
+ {"source": 7, "target": 8},
76
+ {"source": 8, "target": 9},
77
+ {"source": 9, "target": 10},
78
+ {"source": 10, "target": 11},
79
+ {"source": 11, "target": 12},
80
+ {"source": 12, "target": 13},
81
+ {"source": 13, "target": 14},
82
+ {"source": 14, "target": 15},
83
+ {"source": 15, "target": 16},
84
+ {"source": 16, "target": 2},
85
+ ],
86
+ "name": "wheel graph 2",
87
+ }
88
+
89
+ wheel_graph_3 = {
90
+ "nodes": [{"id": i} for i in range(1, 17)],
91
+ "edges": [
92
+ {"source": 1, "target": 2},
93
+ {"source": 1, "target": 3},
94
+ {"source": 1, "target": 4},
95
+ {"source": 1, "target": 5},
96
+ {"source": 1, "target": 6},
97
+ {"source": 1, "target": 7},
98
+ {"source": 1, "target": 8},
99
+ {"source": 1, "target": 9},
100
+ {"source": 1, "target": 10},
101
+ {"source": 1, "target": 11},
102
+ {"source": 1, "target": 12},
103
+ {"source": 1, "target": 13},
104
+ {"source": 1, "target": 14},
105
+ {"source": 1, "target": 15},
106
+ {"source": 1, "target": 16},
107
+ {"source": 2, "target": 3},
108
+ {"source": 3, "target": 4},
109
+ {"source": 4, "target": 5},
110
+ {"source": 5, "target": 6},
111
+ {"source": 6, "target": 7},
112
+ {"source": 7, "target": 8},
113
+ {"source": 8, "target": 9},
114
+ {"source": 9, "target": 10},
115
+ {"source": 10, "target": 11},
116
+ {"source": 11, "target": 12},
117
+ {"source": 12, "target": 13},
118
+ {"source": 13, "target": 14},
119
+ {"source": 14, "target": 15},
120
+ {"source": 15, "target": 16},
121
+ ],
122
+ "name": "wheel graph 3",
123
+ }
124
+
125
+ wheel_1 = nx.node_link_graph(data=wheel_graph_1, edges="edges")
126
+ wheel_2 = nx.node_link_graph(data=wheel_graph_2, edges="edges")
127
+ wheel_3 = nx.node_link_graph(data=wheel_graph_3, edges="edges")
pyotc/otc.py ADDED
@@ -0,0 +1,5 @@
1
+ """Main entry point for otc funcitonality"""
2
+
3
+
4
+ def exact_OTC():
5
+ raise NotImplementedError
File without changes
@@ -0,0 +1,3 @@
1
+ """
2
+ Routines producing stationary processes on graphs
3
+ """
@@ -0,0 +1,109 @@
1
+ import numpy as np
2
+
3
+
4
+ def weight(x):
5
+ """
6
+ Normalizes a vector into a probability distribution.
7
+
8
+ Args:
9
+ x (np.ndarray): Input vector.
10
+
11
+ Returns:
12
+ np.ndarray: Normalized vector such that the sum is 1.
13
+ """
14
+ return x / np.sum(x)
15
+
16
+
17
+ def adj_to_trans(A):
18
+ """
19
+ Converts an adjacency matrix into a row-stochastic transition matrix.
20
+
21
+ Args:
22
+ A (np.ndarray): Adjacency matrix of shape (n, n).
23
+
24
+ Returns:
25
+ np.ndarray: Transition matrix of shape (n, n), where each row sums to 1.
26
+ """
27
+ nrow = A.shape[0]
28
+ T = np.copy(A).astype(float)
29
+ for i in range(nrow):
30
+ row = A[i, :]
31
+ k = np.where(row != 0)[0]
32
+ vals = weight(row[k])
33
+ for idx in range(len(k)):
34
+ T[i, k[idx]] = vals[idx]
35
+ row_sums = T.sum(axis=1)
36
+ return T / row_sums[:, np.newaxis]
37
+
38
+
39
+ def get_degree_cost(A1, A2):
40
+ """
41
+ Computes a cost matrix based on squared degree differences between nodes.
42
+
43
+ Args:
44
+ A1 (np.ndarray): First adjacency matrix of shape (n1, n1).
45
+ A2 (np.ndarray): Second adjacency matrix of shape (n2, n2).
46
+
47
+ Returns:
48
+ cost_mat (np.ndarray): Cost matrix of shape (n1, n2) with squared degree differences.
49
+ """
50
+ n1 = A1.shape[0]
51
+ n2 = A2.shape[0]
52
+ degrees1 = np.sum(A1, axis=1)
53
+ degrees2 = np.sum(A2, axis=1)
54
+ cost_mat = np.zeros((n1, n2))
55
+ for i in range(n1):
56
+ for j in range(n2):
57
+ cost_mat[i, j] = (degrees1[i] - degrees2[j]) ** 2
58
+ return cost_mat
59
+
60
+
61
+ def get_01_cost(V1, V2):
62
+ """
63
+ Computes a binary cost matrix between node features of two graphs based on inequality.
64
+
65
+ Given two vectors representing features of nodes from two graphs, this function
66
+ returns a binary cost matrix where each entry is 1 if the corresponding features differ,
67
+ and 0 otherwise.
68
+
69
+ Args:
70
+ V1 (np.ndarray): Feature vector for nodes in graph 1, of shape (n1,).
71
+ V2 (np.ndarray): Feature vector for nodes in graph 2, of shape (n2,).
72
+
73
+ Returns:
74
+ np.ndarray: Binary cost matrix of shape (n1, n2), where entry (i, j) is 1
75
+ if V1[i] != V2[j], else 0.
76
+ """
77
+
78
+ n1 = len(V1)
79
+ n2 = len(V2)
80
+ cost_mat = np.zeros((n1, n2))
81
+ for i in range(n1):
82
+ for j in range(n2):
83
+ cost_mat[i, j] = V1[i] != V2[j]
84
+ return cost_mat
85
+
86
+
87
+ def get_sq_cost(V1, V2):
88
+ """
89
+ Computes a cost matrix based on squared differences between node features of two graphs.
90
+
91
+ Given two vectors representing node features from two graphs, this function computes
92
+ a cost matrix where each entry (i, j) is the squared difference between the i-th feature
93
+ in graph 1 and the j-th feature in graph 2.
94
+
95
+ Args:
96
+ V1 (np.ndarray): Feature vector for nodes in graph 1, of shape (n1,).
97
+ V2 (np.ndarray): Feature vector for nodes in graph 2, of shape (n2,).
98
+
99
+ Returns:
100
+ np.ndarray: Cost matrix of shape (n1, n2), where entry (i, j) = (V1[i] - V2[j]) ** 2.
101
+ """
102
+
103
+ n1 = len(V1)
104
+ n2 = len(V2)
105
+ cost_mat = np.zeros((n1, n2))
106
+ for i in range(n1):
107
+ for j in range(n2):
108
+ cost_mat[i, j] = (V1[i] - V2[j]) ** 2
109
+ return cost_mat
File without changes
@@ -0,0 +1,78 @@
1
+ import numpy as np
2
+
3
+
4
+ def round_transpoly(X, r, c):
5
+ A = X.copy()
6
+ # A = copy.deepcopy(X)
7
+ n1, n2 = A.shape
8
+
9
+ r_A = np.sum(A, axis=1)
10
+ for i in range(n1):
11
+ scaling = min(1, r[i] / r_A[i])
12
+ A[i, :] *= scaling
13
+
14
+ c_A = np.sum(A, axis=0)
15
+ for j in range(n2):
16
+ scaling = min(1, c[j] / c_A[j])
17
+ A[:, j] *= scaling
18
+
19
+ r_A = np.sum(A, axis=1)
20
+ c_A = np.sum(A, axis=0)
21
+ err_r = r_A - r
22
+ err_c = c_A - c
23
+
24
+ if not np.all(err_r == 0) and not np.all(err_c == 0):
25
+ A += np.outer(err_r, err_c) / np.sum(np.abs(err_r))
26
+
27
+ return A
28
+
29
+
30
+ def logsumexp(X, axis=None):
31
+ """
32
+ Numerically stable log-sum-exp operation.
33
+
34
+ Args:
35
+ X (np.ndarray): Input array.
36
+ axis (int or tuple of ints, optional): Axis or axes over which to operate.
37
+
38
+ Returns:
39
+ np.ndarray: The result of log(sum(exp(X))) along the specified axis.
40
+ """
41
+
42
+ y = np.max(
43
+ X, axis=axis, keepdims=True
44
+ ) # use 'keepdims' to make matrix operation X-y work
45
+ s = y + np.log(np.sum(np.exp(X - y), axis=axis, keepdims=True))
46
+
47
+ return np.squeeze(s, axis=axis)
48
+
49
+
50
+ def logsinkhorn(A, r, c, T):
51
+ """
52
+ Implementation of classical Sinkhorn algorithm for matrix scaling.
53
+ Each iteration simply alternately updates (projects) all rows or
54
+ all columns to have correct marginals.
55
+
56
+ Args:
57
+ A (np.ndarray): Negative scaled cost matrix of shape (dx, dy), e.g., -xi * cost.
58
+ r (np.ndarray): desired row sums (marginals) (shape: dx,). Should sum to 1.
59
+ c (np.ndarray): desired column sums (marginals) (shape: dy,). Should sum to 1.
60
+ T (int): Number of full Sinkhorn iterations.
61
+
62
+ Returns:
63
+ np.ndarray: Final scaled matrix of shape (dx, dy).
64
+ """
65
+
66
+ dx, dy = A.shape
67
+ f = np.zeros(dx)
68
+ g = np.zeros(dy)
69
+
70
+ for t in range(T):
71
+ if t % 2 == 0:
72
+ f = np.log(r) - logsumexp(A + g, axis=1)
73
+ else:
74
+ g = np.log(c) - logsumexp(A + f[:, np.newaxis], axis=0)
75
+
76
+ P = round_transpoly(np.exp(f[:, np.newaxis] + A + g), r, c)
77
+
78
+ return P
@@ -0,0 +1,49 @@
1
+ """
2
+ Native linear programming implementation for solving optimal transport (OT) problems.
3
+ """
4
+
5
+ import numpy as np
6
+ from scipy.optimize import linprog
7
+
8
+
9
+ def computeot_lp(C, r, c):
10
+ """
11
+ Solves the optimal transport problem using linear programming (LP) with SciPy.
12
+
13
+ Given a cost matrix `C` and distributions `r` and `c`, this function computes the
14
+ optimal transport plan that minimizes the total transport cost.
15
+
16
+ Args:
17
+ C (np.ndarray): Cost matrix of shape (nx, ny), where C[i, j] represents the cost of transporting
18
+ mass from source i to target j.
19
+ r (np.ndarray): Source distribution (shape: nx,). Should sum to 1.
20
+ c (np.ndarray): Target distribution (shape: ny,). Should sum to 1.
21
+
22
+ Returns:
23
+ Tuple[np.ndarray, float]:
24
+ - lp_sol (np.ndarray): Optimal transport plan of shape (nx, ny).
25
+ - lp_val (float): Total transport cost under the optimal plan.
26
+ """
27
+ nx = r.size
28
+ ny = c.size
29
+
30
+ # setup LP
31
+ Aeq = np.zeros((nx + ny, nx * ny))
32
+ beq = np.concatenate((r.flatten(), c.flatten()))
33
+ beq = beq.reshape(-1, 1)
34
+ for row in range(nx):
35
+ for t in range(ny):
36
+ Aeq[row, (row * ny) + t] = 1
37
+ for row in range(nx, nx + ny):
38
+ for t in range(nx):
39
+ Aeq[row, t * ny + (row - nx)] = 1
40
+ cost = C.reshape(-1, 1)
41
+
42
+ # Bound
43
+ bound = [[0, None]] * (nx * ny)
44
+
45
+ # Solve OT LP using linprog
46
+ res = linprog(cost, A_eq=Aeq, b_eq=beq, bounds=bound, method="highs")
47
+ lp_sol = res.x
48
+ lp_val = res.fun
49
+ return lp_sol, lp_val
@@ -0,0 +1,51 @@
1
+ """Yuning's other other native implementation of lp ot"""
2
+
3
+ import numpy as np
4
+ from scipy.optimize import linprog
5
+ from typing import Any
6
+
7
+
8
+ def setup_rows(Aeq: np.ndarray, nx: int, ny: int) -> None:
9
+ for row in range(nx):
10
+ for t in range(ny):
11
+ Aeq[row, (row * ny) + t] = 1
12
+ return None
13
+
14
+
15
+ def setup_columns(Aeq: np.ndarray, nx: int, ny: int) -> None:
16
+ for row in range(nx):
17
+ for t in range(ny):
18
+ Aeq[row, (row * ny) + t] = 1
19
+ return None
20
+
21
+
22
+ def computeot_lp(C: np.ndarray, r: np.ndarray, c: np.ndarray) -> tuple[Any, Any]:
23
+ """Compute optimal transport mapping via LP.
24
+
25
+ Args:
26
+ C (np.ndarray): cost
27
+ r (np.ndarray): _description_
28
+ c (np.ndarray): _description_
29
+
30
+ Returns:
31
+ tuple[Any, Any]: _description_
32
+ """
33
+ nx = r.size
34
+ ny = c.size
35
+
36
+ # setup LP
37
+ Aeq = np.zeros((nx + ny, nx * ny))
38
+ beq = np.concatenate((r.flatten(), c.flatten()))
39
+ beq = beq.reshape(-1, 1)
40
+ setup_rows(Aeq, nx, ny)
41
+ setup_columns(Aeq, nx, ny)
42
+ cost = C.reshape(-1, 1)
43
+
44
+ # Bound
45
+ bound = [[0, None]] * (nx * ny)
46
+
47
+ # Solve OT LP using linprog
48
+ res = linprog(cost, A_eq=Aeq, b_eq=beq, bounds=bound, method="highs")
49
+ lp_sol = res.x
50
+ lp_val = res.fun
51
+ return lp_sol, lp_val
@@ -0,0 +1,38 @@
1
+ """
2
+ A wrapper for the Python Optimal Transport (POT) library.
3
+
4
+ This module provides a simplified interface for computing optimal transport plans
5
+ and their associated costs using the POT library.
6
+ """
7
+
8
+ import numpy as np
9
+ import ot
10
+
11
+
12
+ def computeot_pot(C, r, c):
13
+ """
14
+ Computes the optimal transport plan and its total cost using the POT library.
15
+
16
+ Given a cost matrix `C` and distributions `r` and `c`, this function computes the
17
+ optimal transport plan that minimizes the total transport cost.
18
+
19
+ Args:
20
+ C (np.ndarray): Cost matrix of shape (n, m), where C[i, j] represents the cost of transporting
21
+ mass from source i to target j.
22
+ r (np.ndarray): Source distribution (shape: n,). Should sum to 1.
23
+ c (np.ndarray): Target distribution (shape: m,). Should sum to 1.
24
+
25
+ Returns:
26
+ Tuple[np.ndarray, float]:
27
+ - lp_sol (np.ndarray): Optimal transport plan of shape (n, m).
28
+ - lp_val (float): Total transport cost under the optimal plan.
29
+ """
30
+ # Ensure r and c are numpy arrays
31
+ r = np.array(r).flatten()
32
+ c = np.array(c).flatten()
33
+
34
+ # Compute the optimal transport plan and the cost using the ot.emd function
35
+ lp_sol = ot.emd(r, c, C)
36
+ lp_val = np.sum(lp_sol * C)
37
+
38
+ return lp_sol, lp_val
File without changes
File without changes