sibi-dst 0.3.42__py3-none-any.whl → 0.3.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. sibi_dst/df_helper/_artifact_updater_multi_wrapper.py +7 -2
  2. sibi_dst/df_helper/_df_helper.py +5 -2
  3. sibi_dst/df_helper/_parquet_artifact.py +33 -3
  4. sibi_dst/df_helper/_parquet_reader.py +5 -1
  5. sibi_dst/df_helper/backends/django/_load_from_db.py +1 -0
  6. sibi_dst/df_helper/backends/parquet/_filter_handler.py +2 -1
  7. sibi_dst/df_helper/backends/parquet/_parquet_options.py +2 -3
  8. sibi_dst/df_helper/backends/sqlalchemy/_db_connection.py +1 -0
  9. sibi_dst/df_helper/backends/sqlalchemy/_io_dask.py +2 -5
  10. sibi_dst/df_helper/core/_filter_handler.py +2 -1
  11. sibi_dst/osmnx_helper/__init__.py +2 -2
  12. sibi_dst/osmnx_helper/v1/basemaps/__init__.py +0 -0
  13. sibi_dst/osmnx_helper/{basemaps → v1/basemaps}/router_plotter.py +85 -30
  14. sibi_dst/osmnx_helper/v2/__init__.py +0 -0
  15. sibi_dst/osmnx_helper/v2/base_osm_map.py +153 -0
  16. sibi_dst/osmnx_helper/v2/basemaps/__init__.py +0 -0
  17. sibi_dst/osmnx_helper/v2/basemaps/utils.py +0 -0
  18. sibi_dst/utils/data_wrapper.py +4 -368
  19. sibi_dst/utils/df_utils.py +7 -0
  20. sibi_dst/utils/log_utils.py +6 -0
  21. sibi_dst/utils/parquet_saver.py +4 -2
  22. sibi_dst/utils/storage_manager.py +14 -7
  23. sibi_dst-0.3.44.dist-info/METADATA +194 -0
  24. {sibi_dst-0.3.42.dist-info → sibi_dst-0.3.44.dist-info}/RECORD +29 -24
  25. sibi_dst-0.3.42.dist-info/METADATA +0 -62
  26. /sibi_dst/osmnx_helper/{basemaps → v1}/__init__.py +0 -0
  27. /sibi_dst/osmnx_helper/{base_osm_map.py → v1/base_osm_map.py} +0 -0
  28. /sibi_dst/osmnx_helper/{basemaps → v1/basemaps}/calendar_html.py +0 -0
  29. /sibi_dst/osmnx_helper/{utils.py → v1/utils.py} +0 -0
  30. {sibi_dst-0.3.42.dist-info → sibi_dst-0.3.44.dist-info}/WHEEL +0 -0
@@ -25,7 +25,7 @@ class ArtifactUpdaterMultiWrapper:
     def __init__(self, wrapped_classes=None, debug=False, **kwargs):
         self.wrapped_classes = wrapped_classes or {}
         self.debug = debug
-        self.logger = Logger.default_logger(logger_name=self.__class__.__name__)
+        self.logger = kwargs.setdefault('logger',Logger.default_logger(logger_name=self.__class__.__name__))
         self.logger.set_level(logging.DEBUG if debug else logging.INFO)
 
         today = datetime.datetime.today()
@@ -73,7 +73,12 @@ class ArtifactUpdaterMultiWrapper:
             raise ValueError(f"Unsupported data type: {data_type}")
 
         return [
-            artifact_class(parquet_start_date=self.parquet_start_date, parquet_end_date=self.parquet_end_date)
+            artifact_class(
+                parquet_start_date=self.parquet_start_date,
+                parquet_end_date=self.parquet_end_date,
+                logger=self.logger,
+                debug=self.debug
+            )
             for artifact_class in self.wrapped_classes[data_type]
         ]
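Taken together, these two hunks let one logger instance and one debug flag flow from the wrapper into every artifact it instantiates, instead of each object creating its own logger. A minimal, self-contained sketch of the pattern (generic stand-in names, not the sibi_dst API):

import logging

class Wrapper:
    def __init__(self, artifact_classes, debug=False, **kwargs):
        # Reuse a caller-provided logger when present; otherwise fall back to a default one.
        self.logger = kwargs.setdefault('logger', logging.getLogger(self.__class__.__name__))
        self.debug = debug
        self.logger.setLevel(logging.DEBUG if debug else logging.INFO)
        self.artifact_classes = artifact_classes

    def build_artifacts(self, start_date, end_date):
        # Forward the shared logger and debug flag to every artifact.
        return [
            cls(parquet_start_date=start_date, parquet_end_date=end_date,
                logger=self.logger, debug=self.debug)
            for cls in self.artifact_classes
        ]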
 
@@ -112,6 +112,7 @@ class DfHelper:
         :return: None
         """
         self.logger.debug(f"backend used: {self.backend}")
+        self.logger.debug(f"kwargs passed to backend plugins: {kwargs}")
         self._backend_query = self.__get_config(QueryConfig, kwargs)
         self._backend_params = self.__get_config(ParamsConfig, kwargs)
         if self.backend == 'django_db':
@@ -124,8 +125,8 @@ class DfHelper:
         elif self.backend == 'sqlalchemy':
             self.backend_sqlalchemy = self.__get_config(SqlAlchemyConnectionConfig, kwargs)
 
-    @staticmethod
-    def __get_config(model: [T], kwargs: Dict[str, Any]) -> Union[T]:
+
+    def __get_config(self, model: [T], kwargs: Dict[str, Any]) -> Union[T]:
         """
         Initializes a Pydantic model with the keys it recognizes from the kwargs,
         and removes those keys from the kwargs dictionary.
@@ -135,7 +136,9 @@ class DfHelper:
         """
         # Extract keys that the model can accept
         recognized_keys = set(model.__annotations__.keys())
+        self.logger.debug(f"recognized keys: {recognized_keys}")
         model_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in recognized_keys}
+        self.logger.debug(f"model_kwargs: {model_kwargs}")
         return model(**model_kwargs)
 
     def load_parallel(self, **options):
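The key-routing logic these new debug lines expose works as follows: each config model consumes only the kwargs whose names match its annotated fields, and pops them from the shared kwargs dict so later configs never see them. A small stand-alone illustration (plain classes rather than the actual Pydantic config models):

class QueryConfig:
    n_records: int = 0
    dt_field: str = ''

def get_config(model, kwargs):
    # Pop only the keys annotated on the model; leave the rest for later configs.
    recognized_keys = set(model.__annotations__.keys())
    return {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in recognized_keys}

options = {'n_records': 100, 'filters': {'status': 'open'}}
print(get_config(QueryConfig, options))  # {'n_records': 100}
print(options)                           # {'filters': {'status': 'open'}} -- consumed keys are gone

Note that a class's __annotations__ only lists fields annotated directly on that class, not inherited ones, which is worth keeping in mind when subclassing the config models.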
@@ -1,11 +1,12 @@
+import datetime
+import logging
 from typing import Optional, Any, Dict
 
 import dask.dataframe as dd
 import fsspec
 
 from sibi_dst.df_helper import DfHelper
-from sibi_dst.utils import DataWrapper
-from sibi_dst.utils import DateUtils
+from sibi_dst.utils import DataWrapper, DateUtils, Logger
 
 
 class ParquetArtifact(DfHelper):
@@ -82,7 +83,12 @@ class ParquetArtifact(DfHelper):
             **kwargs,
         }
         self.df: Optional[dd.DataFrame] = None
+        self.debug = self.config.setdefault('debug', False)
+        self.logger = self.config.setdefault('logger',Logger.default_logger(logger_name=f'parquet_artifact_{__class__.__name__}'))
+        self.logger.set_level(logging.DEBUG if self.debug else logging.INFO)
         self.data_wrapper_class = data_wrapper_class
+        self.class_params = self.config.setdefault('class_params', None)
+        self.load_params = self.config.setdefault('load_params', None)
         self.date_field = self.config.setdefault('date_field', None)
         if self.date_field is None:
             raise ValueError('date_field must be set')
@@ -131,7 +137,30 @@ class ParquetArtifact(DfHelper):
 
     def update_parquet(self, period: str = 'today', **kwargs) -> None:
         """Update the Parquet file with data from a specific period."""
-        kwargs.update(self.parse_parquet_period(period=period))
+
+        def itd_config():
+            try:
+                start_date = kwargs.pop('history_begins_on')
+            except KeyError:
+                raise ValueError("For period 'itd', you must provide 'history_begins_on' in kwargs.")
+            return {'parquet_start_date': start_date, 'parquet_end_date': datetime.date.today().strftime('%Y-%m-%d')}
+
+        def ytd_config():
+            return {
+                'parquet_start_date': datetime.date(datetime.date.today().year, 1, 1).strftime('%Y-%m-%d'),
+                'parquet_end_date': datetime.date.today().strftime('%Y-%m-%d')
+            }
+
+        config_map = {
+            'itd': itd_config,
+            'ytd': ytd_config
+        }
+
+        if period in config_map:
+            kwargs.update(config_map[period]())
+        else:
+            kwargs.update(self.parse_parquet_period(period=period))
+        print(kwargs)
         self.generate_parquet(**kwargs)
 
     def rebuild_parquet(self, **kwargs) -> None:
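Assuming the new period keys behave as shown in the hunk above, an artifact can now be refreshed inception-to-date or year-to-date without manually computing the window. A usage sketch, where `artifact` stands for an already-constructed ParquetArtifact and the start date is a placeholder:

artifact.update_parquet(period='ytd')                                  # Jan 1 of the current year -> today
artifact.update_parquet(period='itd', history_begins_on='2020-01-01')  # provided start date -> today
artifact.update_parquet(period='today')                                # falls back to parse_parquet_period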
@@ -150,6 +179,7 @@ class ParquetArtifact(DfHelper):
 
     def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
         """Prepare the parameters for generating the Parquet file."""
+        kwargs = {**self.config, **kwargs}
         return {
             'class_params': kwargs.pop('class_params', None),
             'date_field': kwargs.pop('date_field', self.date_field),
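The added merge line means per-call kwargs override the stored config, while config values fill in anything the caller omitted, because later entries win in a dict literal:

# Later entries win, so call-time kwargs take precedence over the stored config.
config = {'date_field': 'created_at', 'overwrite': False}
call_kwargs = {'overwrite': True}
merged = {**config, **call_kwargs}
print(merged)  # {'date_field': 'created_at', 'overwrite': True}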
@@ -1,10 +1,11 @@
+import logging
 from typing import Optional
 
 import dask.dataframe as dd
 import fsspec
 
 from sibi_dst.df_helper import DfHelper
-
+from sibi_dst.utils import Logger
 
 class ParquetReader(DfHelper):
     """
@@ -53,6 +54,9 @@ class ParquetReader(DfHelper):
             **kwargs,
         }
         self.df: Optional[dd.DataFrame] = None
+        self.debug = self.config.setdefault('debug', False)
+        self.logger = self.config.setdefault('logger', Logger.default_logger(logger_name=self.__class__.__name__))
+        self.logger.set_level(logging.DEBUG if self.debug else logging.INFO)
         self.parquet_storage_path = self.config.setdefault('parquet_storage_path', None)
         if self.parquet_storage_path is None:
             raise ValueError('parquet_storage_path must be set')
@@ -64,6 +64,7 @@ class DjangoLoadFromDb:
         self.connection_config = db_connection
         self.debug = kwargs.pop('debug', False)
         self.logger = logger or Logger.default_logger(logger_name=self.__class__.__name__)
+        self.logger.set_level(Logger.DEBUG if self.debug else Logger.INFO)
         if self.connection_config.model is None:
             if self.debug:
                 self.logger.debug('Model must be specified')
@@ -17,8 +17,9 @@ class ParquetFilterHandler(object):
     :ivar logger: Logger object to handle logging within the class. Defaults to the class-level logger.
     :type logger: Logger
     """
-    def __init__(self, logger=None):
+    def __init__(self, logger=None, debug=False):
         self.logger = logger or Logger.default_logger(logger_name=self.__class__.__name__)
+        self.logger.set_level(Logger.DEBUG if debug else Logger.INFO)
 
     @staticmethod
     def apply_filters_dask(df, filters):
@@ -62,6 +62,7 @@ class ParquetConfig(BaseModel):
     parquet_end_date: Optional[str] = None
     fs: Optional[fsspec.spec.AbstractFileSystem] = None # Your fsspec filesystem object
     logger: Optional[Logger] = None
+    debug: bool = False
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     @model_validator(mode='after')
@@ -83,9 +84,7 @@ class ParquetConfig(BaseModel):
         # Configure paths based on fsspec
         if self.logger is None:
             self.logger = Logger.default_logger(logger_name=self.__class__.__name__)
-        #self.fs = fsspec.filesystem("file") if "://" not in str(self.parquet_storage_path) else fsspec.filesystem(
-        #    str(self.parquet_storage_path).split("://")[0])
-        # Validation for parquet path
+        self.logger.set_level(Logger.DEBUG if self.debug else Logger.INFO)
 
 
         if self.parquet_storage_path is None:
@@ -63,3 +63,4 @@ class SqlAlchemyConnectionConfig(BaseModel):
                 connection.execute(text("SELECT 1"))
         except OperationalError as e:
             raise ValueError(f"Failed to connect to the database: {e}")
+
@@ -29,6 +29,7 @@ class SQLAlchemyDask:
         self.engine = create_engine(engine_url)
         self.Session = sessionmaker(bind=self.engine)
         self.logger = logger or Logger.default_logger(logger_name=self.__class__.__name__)
+        self.logger.set_level(logger.DEBUG if debug else logger.INFO)
 
     @staticmethod
     def infer_dtypes_from_model(model):
@@ -70,11 +71,7 @@ class SQLAlchemyDask:
         # Build query
         self.query = select(self.model)
         if self.filters:
-            """
-            deprecated specific filter handling to a generic one
-            #self.query = SqlAlchemyFilterHandler.apply_filters_sqlalchemy(self.query, self.model, self.filters)
-            """
-            self.query = FilterHandler(backend="sqlalchemy", logger=self.logger).apply_filters(self.query,
+            self.query = FilterHandler(backend="sqlalchemy", logger=self.logger, debug=self.debug).apply_filters(self.query,
                                                                                                model=self.model,
                                                                                                filters=self.filters)
         else:
@@ -25,7 +25,7 @@ class FilterHandler:
     :ivar backend_methods: A dictionary mapping backend-specific methods for column retrieval and operation application.
     :type backend_methods: dict
     """
-    def __init__(self, backend, logger=None):
+    def __init__(self, backend, logger=None, debug=False):
         """
         Initialize the FilterHandler.
 
@@ -36,6 +36,7 @@ class FilterHandler:
         self.backend = backend
         self.logger = logger or Logger.default_logger(
             logger_name=self.__class__.__name__) # No-op logger if none provided
+        self.logger.set_level(Logger.DEBUG if debug else Logger.INFO)
         self.backend_methods = self._get_backend_methods(backend)
 
     def apply_filters(self, query_or_df, model=None, filters=None):
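With the extra parameter, callers such as SQLAlchemyDask (see the _io_dask.py hunk above) can thread their own debug flag through when building the handler. A call sketch with placeholder names (`my_logger`, `Orders`, and `query` are assumptions, not part of the diff):

# Hypothetical usage: verbose filter handling against a SQLAlchemy model.
handler = FilterHandler(backend="sqlalchemy", logger=my_logger, debug=True)
query = handler.apply_filters(query, model=Orders, filters={"status": "open"})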
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-from .base_osm_map import BaseOsmMap
-from .utils import PBFHandler
+from .v1.base_osm_map import BaseOsmMap
+from .v1.utils import PBFHandler
 
 __all__ = [
     "BaseOsmMap",
File without changes
@@ -1,20 +1,21 @@
 from __future__ import annotations
-from sibi_dst.osmnx_helper.utils import get_distance_between_points, add_arrows
+from sibi_dst.osmnx_helper.v1.utils import get_distance_between_points, add_arrows
 from collections import defaultdict
 import folium
 from folium.plugins import AntPath
 import networkx as nx
 
 from sibi_dst.osmnx_helper import BaseOsmMap
-from sibi_dst.osmnx_helper.basemaps.calendar_html import calendar_html
+from sibi_dst.osmnx_helper.v1.basemaps.calendar_html import calendar_html
 
 class RoutePlotter(BaseOsmMap):
     def __init__(self, osmnx_graph, df, **kwargs):
         self.action_field = kwargs.pop('action_field', '')
         self.action_groups = kwargs.pop('action_groups', {})
         self.action_styles = kwargs.pop('action_styles', {})
-        self.use_ant_path = kwargs.pop('use_ant_path', True)
-        self.show_calendar = kwargs.pop('show_calendar', True)
+        self.use_ant_path = kwargs.pop('use_ant_path', False)
+        self.show_calendar = kwargs.pop('show_calendar', False)
+        self.show_order_markers = kwargs.pop('show_order_markers', False)
         self.show_map_title = kwargs.pop('show_map_title', True)
         self.sort_keys = kwargs.pop('sort_keys', None)
         self.main_route_layer = folium.FeatureGroup(name="Main Route")
@@ -24,6 +25,8 @@ class RoutePlotter(BaseOsmMap):
         self.actions = []
         self.action_group_counts = {action_group: 0 for action_group in self.action_groups}
         self.marker_count = 1
+        # Add a snapping threshold (in meters) to avoid drawing nodes/markers that are too close.
+        self.snap_distance = kwargs.pop('snap_distance', 30)
         kwargs.update({'calc_nearest_nodes': True})
         kwargs['dt_field'] = 'date_time'
         super().__init__(osmnx_graph, df, **kwargs)
@@ -36,6 +39,8 @@ class RoutePlotter(BaseOsmMap):
         self._calculate_routes()
         self._plot_routes()
         self._add_markers()
+        if self.show_order_markers:
+            self._add_order_markers()
         self.main_route_layer.add_to(self.osm_map)
         if self.show_calendar:
             self._add_calendar()
@@ -57,8 +62,8 @@ class RoutePlotter(BaseOsmMap):
                 self.route_polylines.append((polyline, color))
                 for action_group, action_markers in markers.items():
                     self.markers[action_group].extend(action_markers)
-                    self.action_group_counts[action_group] += 1
-                    self.marker_count += 1
+                    self.action_group_counts[action_group] += len(action_markers)
+                    self.marker_count += len(action_markers)
         if self.verbose:
             print("Route and marker calculation complete.")
 
@@ -70,7 +75,7 @@ class RoutePlotter(BaseOsmMap):
 
     def _calculate_route(self, i):
         if self.verbose:
-            print(f"Calculating for item:{i}")
+            print(f"Calculating for item: {i}")
        orig = self.nearest_nodes[i]
        dest = self.nearest_nodes[i + 1]
        try:
@@ -81,17 +86,31 @@ class RoutePlotter(BaseOsmMap):
             lats, lons = zip(*[(self.G.nodes[node]['y'] + offset, self.G.nodes[node]['x']) for node in route])
             color = 'blue' if i < self.max_distance_index else 'red'
             polyline = list(zip(lats, lons))
+            # Apply node snapping to the polyline to remove points that are too close.
+            polyline = self._snap_polyline(polyline)
             markers = self._calculate_markers(i)
             return polyline, color, markers
         except nx.NetworkXNoPath:
             if self.verbose:
-                print(f"Item:{i}-No path found for {orig} to {dest}")
+                print(f"Item: {i} - No path found for {orig} to {dest}")
             return None, None, {}
         except nx.NodeNotFound:
             if self.verbose:
-                print(f"Item:{i}-No path found for {orig} to {dest}")
+                print(f"Item: {i} - No path found for {orig} to {dest}")
             return None, None, {}
 
+    def _snap_polyline(self, polyline: list[tuple[float, float]]) -> list[tuple[float, float]]:
+        """
+        Returns a filtered polyline where consecutive points closer than snap_distance are removed.
+        """
+        if not polyline:
+            return polyline
+        snapped_polyline = [polyline[0]]
+        for point in polyline[1:]:
+            if get_distance_between_points(snapped_polyline[-1], point, 'm') >= self.snap_distance:
+                snapped_polyline.append(point)
+        return snapped_polyline
+
     def _calculate_markers(self, i):
         # Calculate markers for action groups
         markers = defaultdict(list)
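The new snapping helper is a plain distance-threshold filter: keep the first vertex, then keep each later vertex only if it lies at least snap_distance metres from the last kept one. A stand-alone illustration using a haversine distance (the real code uses the package's get_distance_between_points helper; haversine_m below is an assumption for the demo):

import math

def haversine_m(p1, p2):
    # Great-circle distance in metres between two (lat, lon) points.
    lat1, lon1, lat2, lon2 = map(math.radians, (*p1, *p2))
    a = math.sin((lat2 - lat1) / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2
    return 2 * 6371000 * math.asin(math.sqrt(a))

def snap_polyline(polyline, snap_distance=30):
    # Drop consecutive vertices closer than snap_distance to the last kept vertex.
    if not polyline:
        return polyline
    kept = [polyline[0]]
    for point in polyline[1:]:
        if haversine_m(kept[-1], point) >= snap_distance:
            kept.append(point)
    return kept

points = [(9.9281, -84.0907), (9.92815, -84.0907), (9.9300, -84.0920)]
print(snap_polyline(points))  # the middle point (~6 m away) is dropped

The same threshold is reused below in _plot_routes to skip action markers that would land within snap_distance of one already drawn.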
@@ -110,24 +129,26 @@ class RoutePlotter(BaseOsmMap):
     def _plot_routes(self):
         if self.verbose:
             print("Plotting routes and markers...")
-        # self.action_group_counts = {action_group: 0 for action_group in self.feature_groups.keys()}
         for polyline, color in self.route_polylines:
             if self.use_ant_path:
                 AntPath(
                     locations=polyline,
                     color=color,
-                    weight=3, # Increase line thickness
-                    opacity=10, # Increase opacity
-                    # pulse_color=color,
-                    delay=1000, # Slower animation to reduce flickering
-                    # dash_array=[20, 30] # Adjust dash pattern if needed
+                    weight=3, # Increased line thickness
+                    opacity=10, # Increased opacity
+                    delay=1000, # Slower animation to reduce flickering
                 ).add_to(self.main_route_layer)
             else:
                 folium.PolyLine(locations=polyline, color=color).add_to(self.main_route_layer)
                 self.osm_map = add_arrows(self.osm_map, polyline, color, n_arrows=3)
-        # Plot markers for action groups
+        # Plot markers for action groups with snapping to avoid drawing too many nearby markers.
         for action_group, action_markers in self.markers.items():
+            seen_positions = []
             for location, tooltip, popup_data, action_style in action_markers:
+                # Skip marker if a nearby marker (within snap_distance) has already been added.
+                if any(get_distance_between_points(location, pos, 'm') < self.snap_distance for pos in seen_positions):
+                    continue
+                seen_positions.append(location)
                 folium.Marker(
                     location=location,
                     popup=folium.Popup(popup_data, max_width=600),
@@ -145,11 +166,14 @@ class RoutePlotter(BaseOsmMap):
     def _add_markers(self):
         if self.verbose:
             print("Adding markers...")
-        # Add start marker
+        # Add a start marker
         start_popup = folium.Popup(f"Start of route at {self.dt[0]}", max_width=300)
-        folium.Marker(location=self.gps_points[0], popup=start_popup,
-                      icon=folium.Icon(icon='flag-checkered', prefix='fa')).add_to(self.osm_map)
-        # Add total distance marker at the end
+        folium.Marker(
+            location=self.gps_points[0],
+            popup=start_popup,
+            icon=folium.Icon(icon='flag-checkered', prefix='fa')
+        ).add_to(self.osm_map)
+        # Add an end marker with total distance info
         folium.Marker(
             self.gps_points[-1],
             popup=f"End of Route at {self.dt[self.max_time_index]}. Total Distance Travelled: {self.total_distance / 1000:.2f} km",
@@ -165,15 +189,15 @@ class RoutePlotter(BaseOsmMap):
     def _add_map_title(self):
         if self.map_html_title and self.show_map_title:
             title_html = f'''
-                <div style="position: fixed;
-                     top: 10px;
-                     left: 50%;
+                <div style="position: fixed;
+                     top: 10px;
+                     left: 50%;
                      transform: translate(-50%, 0%);
-                     z-index: 9999;
-                     font-size: 24px;
-                     font-weight: bold;
-                     background-color: white;
-                     padding: 10px;
+                     z-index: 9999;
+                     font-size: 24px;
+                     font-weight: bold;
+                     background-color: white;
+                     padding: 10px;
                      border: 2px solid black;
                      border-radius: 5px;">
                     {self.map_html_title}
@@ -181,6 +205,37 @@ class RoutePlotter(BaseOsmMap):
             '''
             self.osm_map.get_root().html.add_child(folium.Element(title_html))
 
+    def _add_order_markers(self):
+        """Adds numbered markers to indicate the visit order."""
+        order_feature_group = folium.FeatureGroup(name="Visit Order")
+        for idx, location in enumerate(self.gps_points):
+            # Create a DivIcon with the number (starting at 1)
+            icon = folium.DivIcon(
+                icon_size=(24, 24),
+                icon_anchor=(12, 12),
+                html=f'''
+                    <div style="
+                        font-size: 12pt;
+                        color: black;
+                        background-color: white;
+                        border: 1px solid black;
+                        border-radius: 50%;
+                        width: 24px;
+                        height: 24px;
+                        text-align: center;
+                        line-height: 24px;">
+                        {idx + 1}
+                    </div>
+                '''
+            )
+            folium.Marker(
+                location=location,
+                icon=icon,
+                tooltip=f"GPS Set No. {idx + 1}: {self.dt[idx]}"
+            ).add_to(order_feature_group)
+
+        order_feature_group.add_to(self.osm_map)
+
     def _get_data(self, index):
-        # implement in subclass to populate popups
-        ...
+        # Implement in subclass to populate popups
+        ...
File without changes
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+import html
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import folium
+import geopandas as gpd
+import numpy as np
+import osmnx as ox
+import pandas as pd
+from folium.plugins import Fullscreen
+from networkx import MultiDiGraph
+
+
+class BaseOsmMap(ABC):
+    # Define available tile options for the map
+    tile_options = {
+        "OpenStreetMap": "OpenStreetMap",
+        "CartoDB": "cartodbpositron",
+        "CartoDB Voyager": "cartodbvoyager"
+    }
+    # Default geographical bounds (Costa Rica)
+    bounds = [[8.0340, -85.9417], [11.2192, -82.5566]]
+
+    def __init__(
+            self,
+            osmnx_graph: MultiDiGraph,
+            df: pd.DataFrame,
+            lat_col: str = "latitude",
+            lon_col: str = "longitude",
+            map_html_title: str = "OSM Basemap",
+            zoom_start: int = 13,
+            fullscreen: bool = True,
+            fullscreen_position: str = "topright",
+            tiles: str = "OpenStreetMap",
+            verbose: bool = False,
+            sort_keys: Optional[list[str]] = None,
+            dt_field: Optional[str] = None,
+            calc_nearest_nodes: bool = False,
+            max_bounds: bool = False,
+    ):
+        if df.empty:
+            raise ValueError("df must not be empty")
+
+        # Store attributes
+        self.df = df.copy()
+        self.osmnx_graph = osmnx_graph
+        self.lat_col = lat_col
+        self.lon_col = lon_col
+        self.map_html_title = self._sanitize_html(map_html_title)
+        self.zoom_start = zoom_start
+        self.fullscreen = fullscreen
+        self.fullscreen_position = fullscreen_position
+        self.tiles = tiles
+        self.verbose = verbose
+        self.sort_keys = sort_keys
+        self.dt_field = dt_field
+        self.calc_nearest_nodes = calc_nearest_nodes
+        self.max_bounds = max_bounds
+        self.dt = self.df[self.dt_field].to_list() if self.dt_field else None
+        self.nearest_nodes = None
+        self.G = None
+        self.osm_map = None
+
+        self._prepare_df()
+        self._initialize_map()
+
+    def _prepare_df(self):
+        """Sort and preprocess the DataFrame."""
+        if self.sort_keys:
+            self.df.sort_values(by=self.sort_keys, inplace=True, ignore_index=True)
+        self.gps_points = self.df[[self.lat_col, self.lon_col]].to_numpy()
+
+        # Compute nearest nodes if required
+        if self.calc_nearest_nodes and not self.df.empty:
+            self.nearest_nodes = ox.distance.nearest_nodes(
+                self.osmnx_graph, X=self.df[self.lon_col], Y=self.df[self.lat_col]
+            )
+
+    def _initialize_map(self):
+        """Initialize the folium map centered around the dataset."""
+        if self.gps_points.size == 0:
+            raise ValueError("No valid GPS points available for map initialization")
+
+        center = self.gps_points.mean(axis=0).tolist()
+        if self.osm_map is None:
+            self.osm_map = folium.Map(
+                location=center, zoom_start=self.zoom_start, tiles=self.tiles, max_bounds=self.max_bounds
+            )
+        self.G = self._extract_subgraph(*self._get_bounding_box_from_points())
+
+    def _get_bounding_box_from_points(self, margin: float = 0.001) -> tuple[float, float, float, float]:
+        """Compute bounding box for the dataset with margin."""
+        latitudes, longitudes = self.gps_points[:, 0], self.gps_points[:, 1]
+        return max(latitudes) + margin, min(latitudes) - margin, max(longitudes) + margin, min(longitudes) - margin
+
+    def _extract_subgraph(self, north: float, south: float, east: float, west: float) -> MultiDiGraph:
+        """Extract a subgraph from OSM data within the bounding box."""
+        bbox_poly = gpd.GeoSeries([ox.utils_geo.bbox_to_poly((west, south, east, north))])
+        nodes_gdf = ox.graph_to_gdfs(self.osmnx_graph, nodes=True, edges=False)
+        nodes_within_bbox = gpd.sjoin(nodes_gdf, gpd.GeoDataFrame(geometry=bbox_poly), predicate="within")
+        return self.osmnx_graph.subgraph(nodes_within_bbox.index)
+
+    def _post_process_map(self):
+        """Perform final adjustments to the map."""
+        self._attach_supported_tiles()
+        self.add_tile_layer()
+        self._add_fullscreen()
+        self._add_map_title()
+        if self.max_bounds and self.bounds:
+            self.osm_map.fit_bounds(self.bounds)
+
+    def _attach_supported_tiles(self):
+        """Attach additional tile layers to the map."""
+        for name, tile in self.tile_options.items():
+            if tile.lower() != self.tiles.lower():
+                folium.TileLayer(name=name, tiles=tile, show=False).add_to(self.osm_map)
+
+    def _add_fullscreen(self):
+        """Enable fullscreen control if required."""
+        if self.fullscreen:
+            Fullscreen(position=self.fullscreen_position).add_to(self.osm_map)
+
+    def _add_map_title(self):
+        """Add a title to the map if provided."""
+        if self.map_html_title:
+            self.osm_map.get_root().html.add_child(folium.Element(self.map_html_title))
+
+    @staticmethod
+    def _sanitize_html(input_html: str) -> str:
+        """Sanitize HTML input to prevent script injection."""
+        return html.escape(input_html)
+
+    @abstractmethod
+    def process_map(self):
+        """Abstract method to define map processing logic in subclasses."""
+        pass
+
+    def pre_process_map(self):
+        """Optional preprocessing step before main processing."""
+        pass
+
+    def add_tile_layer(self):
+        """Add a layer control to the map."""
+        folium.LayerControl().add_to(self.osm_map)
+
+    def generate_map(self) -> folium.Map:
+        """Generate and return the processed map."""
+        self.pre_process_map()
+        self.process_map()
+        self._post_process_map()
+        return self.osm_map
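Because the new v2 base class is abstract, the minimum a subclass must supply is process_map(); tile layers, fullscreen control, the title, and bounds fitting are then applied by generate_map(). A hedged sketch of a subclass (the graph and DataFrame inputs are placeholders, not part of the diff):

import folium
from sibi_dst.osmnx_helper.v2.base_osm_map import BaseOsmMap

class PointsMap(BaseOsmMap):
    # Minimal subclass: drop one marker per GPS point.
    def process_map(self):
        for lat, lon in self.gps_points:
            folium.Marker(location=(lat, lon)).add_to(self.osm_map)

# Hypothetical inputs: `graph` is an osmnx MultiDiGraph covering the area and
# `df` is a DataFrame with 'latitude'/'longitude' columns.
# osm_map = PointsMap(graph, df, map_html_title="Delivery stops").generate_map()
# osm_map.save("stops.html")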
File without changes
File without changes