allocator 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- allocator/__init__.py +154 -0
- allocator/api/__init__.py +32 -0
- allocator/api/cluster.py +126 -0
- allocator/api/distance.py +225 -0
- allocator/api/route.py +256 -0
- allocator/api/types.py +52 -0
- allocator/cli/__init__.py +1 -0
- allocator/cli/cluster_cmd.py +104 -0
- allocator/cli/main.py +170 -0
- allocator/cli/route_cmd.py +164 -0
- allocator/core/__init__.py +1 -0
- allocator/core/algorithms.py +200 -0
- allocator/core/routing.py +242 -0
- allocator/distances/__init__.py +17 -0
- allocator/distances/euclidean.py +80 -0
- allocator/distances/external_apis.py +165 -0
- allocator/distances/factory.py +66 -0
- allocator/distances/haversine.py +43 -0
- allocator/io/__init__.py +1 -0
- allocator/io/data_handler.py +174 -0
- allocator/py.typed +2 -0
- allocator/utils.py +37 -0
- allocator/viz/__init__.py +17 -0
- allocator/viz/plotting.py +206 -0
- allocator-1.0.0.dist-info/METADATA +132 -0
- allocator-1.0.0.dist-info/RECORD +28 -0
- allocator-1.0.0.dist-info/WHEEL +4 -0
- allocator-1.0.0.dist-info/entry_points.txt +3 -0
allocator/api/route.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Modern routing API for allocator package.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from ..io.data_handler import DataHandler
|
|
11
|
+
from .types import RouteResult
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _solve_tsp_nearest_neighbor(distance_matrix: np.ndarray) -> tuple[float, list[int]]:
|
|
15
|
+
"""
|
|
16
|
+
Solve TSP using nearest neighbor heuristic.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
distance_matrix: Symmetric distance matrix
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
Tuple of (total_distance, route)
|
|
23
|
+
"""
|
|
24
|
+
n = len(distance_matrix)
|
|
25
|
+
if n == 0:
|
|
26
|
+
return 0.0, []
|
|
27
|
+
if n == 1:
|
|
28
|
+
return 0.0, [0]
|
|
29
|
+
|
|
30
|
+
# Start from node 0
|
|
31
|
+
unvisited = set(range(1, n))
|
|
32
|
+
route = [0]
|
|
33
|
+
current = 0
|
|
34
|
+
total_distance = 0.0
|
|
35
|
+
|
|
36
|
+
while unvisited:
|
|
37
|
+
# Find nearest unvisited node
|
|
38
|
+
nearest = min(unvisited, key=lambda x: distance_matrix[current, x])
|
|
39
|
+
total_distance += distance_matrix[current, nearest]
|
|
40
|
+
route.append(nearest)
|
|
41
|
+
unvisited.remove(nearest)
|
|
42
|
+
current = nearest
|
|
43
|
+
|
|
44
|
+
# Return to start
|
|
45
|
+
total_distance += distance_matrix[current, 0]
|
|
46
|
+
route.append(0)
|
|
47
|
+
|
|
48
|
+
return total_distance, route
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def shortest_path(
    data: str | pd.DataFrame | np.ndarray | list,
    method: str = "christofides",
    distance: str = "euclidean",
    **kwargs,
) -> RouteResult:
    """
    Find a short route through geographic points (TSP).

    Args:
        data: Input data (file path, DataFrame, numpy array, or list)
        method: TSP solving method ('christofides', 'ortools', 'osrm', 'google')
        distance: Distance metric ('euclidean', 'haversine', 'osrm', 'google');
            only forwarded to the 'christofides' and 'ortools' solvers
        **kwargs: Additional method-specific arguments (for example ``api_key``
            for 'google' or ``osrm_base_url`` for 'osrm')

    Returns:
        RouteResult with the visiting order and total distance

    Raises:
        ValueError: If ``method`` is not one of the supported names.

    Example:
        >>> result = shortest_path('points.csv', method='ortools')
        >>> print(result.route)  # Visiting order
        >>> print(result.total_distance)  # Total route distance
    """
    frame = DataHandler.load_data(data)

    # Hand off to the dedicated solver. Only the Christofides and OR-Tools
    # solvers take a distance metric; the road-network ones bring their own.
    if method == "christofides":
        return tsp_christofides(frame, distance=distance, **kwargs)
    if method == "ortools":
        return tsp_ortools(frame, distance=distance, **kwargs)
    if method == "osrm":
        return tsp_osrm(frame, **kwargs)
    if method == "google":
        return tsp_google(frame, **kwargs)
    raise ValueError(f"Unknown routing method: {method}")
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def tsp_christofides(
    data: pd.DataFrame | np.ndarray, distance: str = "euclidean", **kwargs
) -> RouteResult:
    """
    Solve TSP with the Christofides algorithm (approximate).

    Args:
        data: Input data as DataFrame or numpy array
        distance: Distance metric, forwarded to the solver as ``distance_method``
        **kwargs: Extra keyword arguments forwarded to the solver

    Returns:
        RouteResult whose ``data`` holds the input rows reordered along the
        route, plus a ``route_order`` column numbering the stops

    Raises:
        ValueError: If the loaded data lacks 'longitude' or 'latitude' columns.
    """
    from ..core.routing import solve_tsp_christofides

    frame = DataHandler.load_data(data)

    missing_coords = "longitude" not in frame.columns or "latitude" not in frame.columns
    if missing_coords:
        raise ValueError("Data must contain 'longitude' and 'latitude' columns")
    coords = frame[["longitude", "latitude"]].values

    total_distance, route = solve_tsp_christofides(coords, distance_method=distance, **kwargs)

    # Reorder the rows to follow the route and number each stop.
    ordered = frame.iloc[route].copy()
    ordered["route_order"] = range(len(route))

    metadata = {"method": "christofides", "distance": distance, "n_points": len(coords)}
    return RouteResult(route=route, total_distance=total_distance, data=ordered, metadata=metadata)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def tsp_ortools(
    data: pd.DataFrame | np.ndarray, distance: str = "euclidean", **kwargs
) -> RouteResult:
    """
    Solve TSP using Google OR-Tools.

    Args:
        data: Input data as DataFrame or numpy array
        distance: Distance metric, forwarded to the solver as ``distance_method``
        **kwargs: Extra keyword arguments forwarded to the solver

    Returns:
        RouteResult whose ``data`` holds the input rows reordered along the
        route, plus a ``route_order`` column numbering the stops

    Raises:
        ValueError: If the loaded data lacks 'longitude' or 'latitude' columns.
    """
    from ..core.routing import solve_tsp_ortools

    frame = DataHandler.load_data(data)

    missing_coords = "longitude" not in frame.columns or "latitude" not in frame.columns
    if missing_coords:
        raise ValueError("Data must contain 'longitude' and 'latitude' columns")
    coords = frame[["longitude", "latitude"]].values

    total_distance, route = solve_tsp_ortools(coords, distance_method=distance, **kwargs)

    # Reorder the rows to follow the route and number each stop.
    ordered = frame.iloc[route].copy()
    ordered["route_order"] = range(len(route))

    metadata = {"method": "ortools", "distance": distance, "n_points": len(coords)}
    return RouteResult(route=route, total_distance=total_distance, data=ordered, metadata=metadata)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def tsp_osrm(
|
|
170
|
+
data: pd.DataFrame | np.ndarray, osrm_base_url: str | None = None, **kwargs
|
|
171
|
+
) -> RouteResult:
|
|
172
|
+
"""
|
|
173
|
+
Solve TSP using OSRM distance matrix and nearest neighbor heuristic.
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
data: Input data as DataFrame or numpy array
|
|
177
|
+
osrm_base_url: Custom OSRM server URL
|
|
178
|
+
**kwargs: Additional arguments
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
RouteResult with route using real road network
|
|
182
|
+
"""
|
|
183
|
+
points = DataHandler.load_data(data)
|
|
184
|
+
|
|
185
|
+
if len(points) == 0:
|
|
186
|
+
raise ValueError("Cannot solve TSP with empty data")
|
|
187
|
+
if len(points) == 1:
|
|
188
|
+
route = [0]
|
|
189
|
+
total_distance = 0.0
|
|
190
|
+
else:
|
|
191
|
+
# Use OSRM distance matrix to solve TSP with nearest neighbor heuristic
|
|
192
|
+
from ..distances import osrm_distance_matrix
|
|
193
|
+
|
|
194
|
+
distances = osrm_distance_matrix(
|
|
195
|
+
points[["longitude", "latitude"]].values, osrm_base_url=osrm_base_url
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
total_distance, route = _solve_tsp_nearest_neighbor(distances)
|
|
199
|
+
|
|
200
|
+
# Create result DataFrame
|
|
201
|
+
result_df = points.copy()
|
|
202
|
+
result_df["route_order"] = [route.index(i) if i in route else -1 for i in range(len(points))]
|
|
203
|
+
|
|
204
|
+
return RouteResult(
|
|
205
|
+
route=route,
|
|
206
|
+
total_distance=total_distance,
|
|
207
|
+
data=result_df,
|
|
208
|
+
metadata={"method": "osrm", "osrm_base_url": osrm_base_url, "n_points": len(points)},
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def tsp_google(data: pd.DataFrame | np.ndarray, api_key: str, **kwargs) -> RouteResult:
|
|
213
|
+
"""
|
|
214
|
+
Solve TSP using Google Maps distance matrix and nearest neighbor heuristic.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
data: Input data as DataFrame or numpy array
|
|
218
|
+
api_key: Google Maps API key
|
|
219
|
+
**kwargs: Additional arguments
|
|
220
|
+
|
|
221
|
+
Returns:
|
|
222
|
+
RouteResult with route using Google's road network
|
|
223
|
+
"""
|
|
224
|
+
points = DataHandler.load_data(data)
|
|
225
|
+
|
|
226
|
+
if len(points) == 0:
|
|
227
|
+
raise ValueError("Cannot solve TSP with empty data")
|
|
228
|
+
if len(points) == 1:
|
|
229
|
+
route = [0]
|
|
230
|
+
total_distance = 0.0
|
|
231
|
+
else:
|
|
232
|
+
# Use Google Maps distance matrix to solve TSP with nearest neighbor heuristic
|
|
233
|
+
from ..distances import google_distance_matrix
|
|
234
|
+
|
|
235
|
+
distances = google_distance_matrix(
|
|
236
|
+
points[["longitude", "latitude"]].values,
|
|
237
|
+
api_key=api_key,
|
|
238
|
+
duration=False, # Get distance, not duration
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
total_distance, route = _solve_tsp_nearest_neighbor(distances)
|
|
242
|
+
|
|
243
|
+
# Create result DataFrame
|
|
244
|
+
result_df = points.copy()
|
|
245
|
+
result_df["route_order"] = [route.index(i) if i in route else -1 for i in range(len(points))]
|
|
246
|
+
|
|
247
|
+
return RouteResult(
|
|
248
|
+
route=route,
|
|
249
|
+
total_distance=total_distance,
|
|
250
|
+
data=result_df,
|
|
251
|
+
metadata={
|
|
252
|
+
"method": "google",
|
|
253
|
+
"api_key": "***" if api_key else None, # Don't log actual API key
|
|
254
|
+
"n_points": len(points),
|
|
255
|
+
},
|
|
256
|
+
)
|
allocator/api/types.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Type definitions and dataclasses for allocator API.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass()
class ClusterResult:
    """
    Outcome of a clustering run.

    Fields are declared in positional-constructor order.
    """

    # Cluster label per point (presumably one entry per row of ``data``).
    labels: np.ndarray
    # Cluster centres; the CLI writes them out as longitude/latitude columns.
    centroids: np.ndarray
    # Number of iterations performed.
    n_iter: int
    # Clustering inertia, or None when not available.
    inertia: float | None
    # Clustered data as a DataFrame.
    data: pd.DataFrame
    # Whether the algorithm reported convergence.
    converged: bool
    # Free-form extra information about the run.
    metadata: dict[str, Any]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass()
class SortResult:
    """
    Outcome of a sort-by-distance operation.

    Fields are declared in positional-constructor order.
    """

    # Resulting rows (the CLI reports their count as "assignments").
    data: pd.DataFrame
    # Distance matrix used for sorting, if one was kept.
    distance_matrix: np.ndarray | None
    # Free-form extra information about the operation.
    metadata: dict[str, Any]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass()
class RouteResult:
    """
    Outcome of a shortest-path (TSP) computation.

    Fields are declared in positional-constructor order.
    """

    # Visiting order as a list of point indices.
    route: list[int]
    # Total length of the route.
    total_distance: float
    # Route data; the routing API adds a ``route_order`` column.
    data: pd.DataFrame
    # Free-form extra information (method, point count, ...).
    metadata: dict[str, Any]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass()
class ComparisonResult:
    """
    Outcome of comparing several clustering algorithms.

    Fields are declared in positional-constructor order.
    """

    # Per-algorithm clustering results, keyed by algorithm name.
    results: dict[str, ClusterResult]
    # Summary statistics table for the comparison.
    statistics: pd.DataFrame
    # Free-form extra information about the comparison.
    metadata: dict[str, Any]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Command-line interface."""
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Clustering CLI commands.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
|
|
8
|
+
# Module-level Rich console used by the ``kmeans`` command below for styled output.
console = Console()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@click.command()
@click.argument("input_file", type=click.Path(exists=True))
@click.option("--n-clusters", "-n", type=int, required=True, help="Number of clusters")
@click.option(
    "--distance",
    "-d",
    default="euclidean",
    type=click.Choice(["euclidean", "haversine", "osrm", "google"]),
    help="Distance metric to use",
)
@click.option("--max-iter", type=int, default=300, help="Maximum number of iterations")
@click.option("--random-state", type=int, help="Random seed for reproducibility")
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option("--centroids", "-c", type=click.Path(), help="Centroids output file path")
@click.option("--plot", is_flag=True, help="Show clustering plot")
@click.option("--save-plot", type=click.Path(), help="Save plot to file")
@click.option(
    "--format",
    "output_format",
    default="csv",
    type=click.Choice(["csv", "json"]),
    help="Output format",
)
@click.pass_context
def kmeans(
    ctx,
    input_file,
    n_clusters,
    distance,
    max_iter,
    random_state,
    output,
    centroids,
    plot,
    save_plot,
    output_format,
):
    """K-means clustering of geographic data."""
    # (The docstring above doubles as the command's --help text.)
    from rich.markup import escape

    from ..api import kmeans as kmeans_func
    from ..io.data_handler import DataHandler
    from ..viz.plotting import plot_clusters

    # ctx.obj may be None when the command is invoked outside a group that
    # initialises it, so don't call .get() on it unguarded.
    verbose = bool((ctx.obj or {}).get("verbose"))

    try:
        # Run clustering
        result = kmeans_func(
            input_file,
            n_clusters=n_clusters,
            distance=distance,
            max_iter=max_iter,
            random_state=random_state,
        )

        if verbose:
            console.print(f"[green]K-means converged: {result.converged}[/green]")
            console.print(f"Iterations: {result.n_iter}")
            # Compare against None so a legitimate inertia of 0.0 is still reported.
            if result.inertia is not None:
                console.print(f"Inertia: {result.inertia:.2f}")

        # Save results
        if output:
            DataHandler.save_results(result, output, format=output_format)
            console.print(f"[green]Results saved to {output}[/green]")

        # Save centroids
        if centroids:
            import pandas as pd

            centroids_df = pd.DataFrame(result.centroids, columns=["longitude", "latitude"])
            if output_format == "csv":
                centroids_df.to_csv(centroids, index=False)
            else:
                centroids_df.to_json(centroids, orient="records", indent=2)
            console.print(f"[green]Centroids saved to {centroids}[/green]")

        # Plotting
        if plot or save_plot:
            plot_clusters(
                result.data,
                result.labels,
                result.centroids,
                title=f"K-means Clustering (n={n_clusters})",
                save_path=save_plot,
                show=plot,
            )
            if save_plot:
                console.print(f"[green]Plot saved to {save_plot}[/green]")

        if not output:
            console.print("\nFirst 10 results:")
            # Cell values are arbitrary user data; don't interpret them as Rich markup.
            console.print(result.data.head(10).to_string(), markup=False)

    except Exception as e:
        # Escape the message so brackets in it are not parsed as Rich markup tags
        # (which could raise a MarkupError and mask the original error).
        console.print(f"[red]Error: {escape(str(e))}[/red]")
        raise click.Abort() from e
|
allocator/cli/main.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Modern CLI interface for allocator package using Click.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.table import Table
|
|
8
|
+
|
|
9
|
+
from .. import __version__
|
|
10
|
+
from .cluster_cmd import kmeans
|
|
11
|
+
from .route_cmd import christofides, ortools, tsp
|
|
12
|
+
|
|
13
|
+
# Module-level Rich console used by the commands in this module for styled output.
console = Console()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@click.group()
@click.version_option(version=__version__)
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose output")
@click.pass_context
def cli(ctx, verbose):
    """
    Allocator v1.0 - Modern geographic task allocation, clustering, and routing.

    \b
    Examples:
        allocator cluster kmeans data.csv -n 5
        allocator route tsp points.csv --method ortools
        allocator sort points.csv --workers workers.csv
    """
    # The docstring is Click's --help text. The "\b" line stops Click from
    # rewrapping the examples paragraph into one long line.
    # Shared state for subcommands, which read ctx.obj["verbose"].
    ctx.ensure_object(dict)
    ctx.obj["verbose"] = verbose
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@cli.group()
def cluster():
    """Cluster geographic data points."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@cli.group()
def route():
    """Find shortest paths through points (TSP)."""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Attach the subcommands imported from cluster_cmd and route_cmd to their groups.
cluster.add_command(kmeans)
for _route_command in (tsp, christofides, ortools):
    route.add_command(_route_command)
del _route_command
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@cli.command()
@click.argument("points", type=click.Path(exists=True))
@click.option("--workers", "-w", type=click.Path(exists=True), help="Worker locations file")
@click.option("--by-worker", is_flag=True, help="Sort points by worker instead of workers by point")
@click.option(
    "--distance",
    "-d",
    default="euclidean",
    type=click.Choice(["euclidean", "haversine", "osrm", "google"]),
    help="Distance metric to use",
)
@click.option("--output", "-o", type=click.Path(), help="Output file path")
@click.option(
    "--format",
    "output_format",
    default="csv",
    type=click.Choice(["csv", "json"]),
    help="Output format",
)
@click.pass_context
def sort(ctx, points, workers, by_worker, distance, output, output_format):
    """Sort points by distance to workers or assign to closest."""
    # (The docstring above doubles as the command's --help text.)
    from rich.markup import escape

    from ..api import sort_by_distance
    from ..io.data_handler import DataHandler

    # Validate before entering the try block. A click.Abort raised inside it
    # would be caught by the generic handler below and reported a second time
    # as an empty "Error:" line.
    if not workers:
        console.print("[red]Error: --workers option is required[/red]")
        raise click.Abort()

    try:
        result = sort_by_distance(points, workers, by_worker=by_worker, distance=distance)

        # Save results
        if output:
            DataHandler.save_results(result, output, format=output_format)
            console.print(f"[green]Results saved to {output}[/green]")
        else:
            console.print(result.data.head())

        # ctx.obj may be None (or lack the key) outside the ``cli`` group.
        if (ctx.obj or {}).get("verbose"):
            console.print(f"Processed {len(result.data)} assignments")

    except Exception as e:
        # Escape the message so brackets in it are not parsed as Rich markup.
        console.print(f"[red]Error: {escape(str(e))}[/red]")
        raise click.Abort() from e
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@cli.command()
@click.argument("input_file", type=click.Path(exists=True))
@click.option(
    "--algorithms",
    "-a",
    default="kmeans",
    help="Comma-separated list of algorithms to compare",
)
@click.option("--n-clusters", "-n", type=int, required=True, help="Number of clusters")
@click.option(
    "--distance",
    "-d",
    default="euclidean",
    type=click.Choice(["euclidean", "haversine", "osrm"]),
    help="Distance metric to use",
)
@click.option("--output", "-o", type=click.Path(), help="Output file for comparison results")
@click.pass_context
def compare(ctx, input_file, algorithms, n_clusters, distance, output):
    """Compare clustering algorithms."""
    # (The docstring above doubles as the command's --help text.)
    import pandas as pd
    from rich.markup import escape

    from ..api import cluster

    # Algorithm names this command knows how to run.
    supported = ("kmeans",)

    try:
        algos = [algo.strip() for algo in algorithms.split(",")]
        results = {}

        for algo in algos:
            if algo in supported:
                console.print(f"Running {algo} clustering...")
                results[algo] = cluster(
                    input_file, n_clusters=n_clusters, method=algo, distance=distance
                )
            else:
                console.print(
                    f"[yellow]Warning: Unknown algorithm '{escape(algo)}', skipping[/yellow]"
                )

        if not results:
            # Nothing ran: fail clearly instead of printing an empty table and
            # then crashing inside pd.concat([]) when --output is given.
            console.print("[red]Error: No valid algorithms to compare[/red]")
            raise click.Abort()

        # Create comparison table
        table = Table(title="Clustering Comparison")
        table.add_column("Algorithm", style="cyan")
        table.add_column("Converged", style="green")
        table.add_column("Iterations", style="magenta")
        table.add_column("Inertia", style="yellow")

        for algo, result in results.items():
            converged = "Yes" if result.converged else "No"
            iterations = str(result.n_iter)
            # Compare against None so an inertia of exactly 0.0 is shown, not "N/A".
            inertia = f"{result.inertia:.2f}" if result.inertia is not None else "N/A"
            table.add_row(algo, converged, iterations, inertia)

        console.print(table)

        if output:
            # Save detailed comparison: every algorithm's rows, tagged by algorithm.
            comparison_data = []
            for algo, result in results.items():
                df = result.data.copy()
                df["algorithm"] = algo
                comparison_data.append(df)

            combined_df = pd.concat(comparison_data, ignore_index=True)
            combined_df.to_csv(output, index=False)
            console.print(f"[green]Detailed results saved to {output}[/green]")

    except click.Abort:
        # Already reported above; don't print a second, empty "Error:" line.
        raise
    except Exception as e:
        # Escape the message so brackets in it are not parsed as Rich markup.
        console.print(f"[red]Error: {escape(str(e))}[/red]")
        raise click.Abort() from e
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# Allow running this module directly (e.g. ``python -m allocator.cli.main``).
if __name__ == "__main__":
    cli()
|