pyretailscience 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/PKG-INFO +1 -1
  2. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyproject.toml +1 -1
  3. pyretailscience-0.3.2/pyretailscience/assets/fonts/Poppins-Bold.ttf +0 -0
  4. pyretailscience-0.3.2/pyretailscience/assets/fonts/Poppins-LightItalic.ttf +0 -0
  5. pyretailscience-0.3.2/pyretailscience/assets/fonts/Poppins-Medium.ttf +0 -0
  6. pyretailscience-0.3.2/pyretailscience/assets/fonts/Poppins-Regular.ttf +0 -0
  7. pyretailscience-0.3.2/pyretailscience/assets/fonts/Poppins-SemiBold.ttf +0 -0
  8. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/customer.py +77 -42
  9. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/data/contracts.py +126 -113
  10. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/range_planning.py +33 -11
  11. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/segmentation.py +115 -45
  12. pyretailscience-0.3.2/pyretailscience/standard_graphs.py +313 -0
  13. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/style/graph_utils.py +47 -2
  14. pyretailscience-0.3.0/pyretailscience/standard_graphs.py +0 -96
  15. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/LICENSE +0 -0
  16. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/README.md +0 -0
  17. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/__init__.py +0 -0
  18. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/data/__init__.py +0 -0
  19. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/data/cli.py +0 -0
  20. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/data/simulation.py +0 -0
  21. {pyretailscience-0.3.0 → pyretailscience-0.3.2}/pyretailscience/style/tailwind.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pyretailscience
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Retail Data Science Tools
5
5
  License: Elastic-2.0
6
6
  Author: Murray Vanwyk
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "pyretailscience"
3
- version = "0.3.0"
3
+ version = "0.3.2"
4
4
  description = "Retail Data Science Tools"
5
5
  authors = ["Murray Vanwyk <2493311+mvanwyk@users.noreply.github.com>"]
6
6
  readme = "README.md"
@@ -6,6 +6,7 @@ from matplotlib.axes import Axes, SubplotBase
6
6
 
7
7
  from pyretailscience.data.contracts import TransactionItemLevelContract
8
8
  from pyretailscience.style.graph_utils import GraphStyles as gs
9
+ import pyretailscience.style.graph_utils as gu
9
10
  from pyretailscience.style.graph_utils import human_format, standard_graph_styles
10
11
  from pyretailscience.style.tailwind import COLORS
11
12
 
@@ -36,8 +37,8 @@ class PurchasesPerCustomer:
36
37
  percentile_line: float = 0.5,
37
38
  source_text: str | None = None,
38
39
  title: str | None = None,
39
- xlabel: str | None = None,
40
- ylabel: str | None = None,
40
+ x_label: str | None = None,
41
+ y_label: str | None = None,
41
42
  **kwargs: dict[str, any],
42
43
  ) -> SubplotBase:
43
44
  """Plot the distribution of the number of purchases per customer.
@@ -58,8 +59,8 @@ class PurchasesPerCustomer:
58
59
  if cumlative:
59
60
  density = True
60
61
 
61
- if xlabel is None:
62
- xlabel = "Number of purchases"
62
+ if x_label is None:
63
+ x_label = "Number of purchases"
63
64
 
64
65
  ax = self.cust_purchases_s.hist(
65
66
  bins=bins,
@@ -70,27 +71,38 @@ class PurchasesPerCustomer:
70
71
  **kwargs,
71
72
  )
72
73
 
73
- ax.set_xlabel(xlabel, fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE, labelpad=10)
74
+ ax.set_xlabel(
75
+ x_label,
76
+ fontproperties=gs.POPPINS_REG,
77
+ fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
78
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
79
+ )
74
80
  ax.xaxis.set_major_formatter(lambda x, pos: human_format(x, pos, decimals=0))
75
81
 
76
82
  ax = standard_graph_styles(ax)
77
83
 
78
84
  if cumlative:
79
- if title is None:
80
- title = "Number of Purchases Cumulative Distribution"
81
- if ylabel is None:
82
- ylabel = "Percentage of customers"
85
+ default_title = "Number of Purchases Cumulative Distribution"
86
+ default_y_label = "Percentage of customers"
83
87
  ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
84
88
 
85
89
  else:
86
- if title is None:
87
- title = "Number of Purchases Distribution"
88
- if ylabel is None:
89
- ylabel = "Number of customers"
90
+ default_title = "Number of Purchases Distribution"
91
+ default_y_label = "Number of customers"
90
92
  ax.yaxis.set_major_formatter(lambda x, pos: human_format(x, pos, decimals=0))
91
93
 
92
- ax.set_title(title, fontsize=gs.DEFAULT_TITLE_FONT_SIZE, pad=15)
93
- ax.set_ylabel(ylabel, fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE, labelpad=10)
94
+ ax.set_title(
95
+ gu.not_none(title, default_title),
96
+ fontproperties=gs.POPPINS_SEMI_BOLD,
97
+ fontsize=gs.DEFAULT_TITLE_FONT_SIZE,
98
+ pad=gs.DEFAULT_TITLE_PAD,
99
+ )
100
+ ax.set_ylabel(
101
+ gu.not_none(y_label, default_y_label),
102
+ fontproperties=gs.POPPINS_REG,
103
+ fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
104
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
105
+ )
94
106
 
95
107
  if draw_percentile_line:
96
108
  if percentile_line > 1 or percentile_line < 0:
@@ -112,6 +124,8 @@ class PurchasesPerCustomer:
112
124
  ha="left",
113
125
  va="center",
114
126
  fontsize=gs.DEFAULT_SOURCE_FONT_SIZE,
127
+ fontproperties=gs.POPPINS_LIGHT_ITALIC,
128
+ color="dimgray",
115
129
  )
116
130
 
117
131
  return ax
@@ -192,8 +206,8 @@ class DaysBetweenPurchases:
192
206
  draw_percentile_line: bool = False,
193
207
  percentile_line: float = 0.5,
194
208
  title: str | None = None,
195
- xlabel: str | None = None,
196
- ylabel: str | None = None,
209
+ x_label: str | None = None,
210
+ y_label: str | None = None,
197
211
  source_text: str = None,
198
212
  **kwargs: dict[str, any],
199
213
  ) -> SubplotBase:
@@ -223,29 +237,38 @@ class DaysBetweenPurchases:
223
237
  **kwargs,
224
238
  )
225
239
 
226
- if xlabel is None:
227
- xlabel = "Average Number of Days Between Purchases"
228
- ax.set_xlabel(xlabel, fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE, labelpad=10)
240
+ ax.set_xlabel(
241
+ gu.not_none(x_label, "Average Number of Days Between Purchases"),
242
+ fontproperties=gs.POPPINS_REG,
243
+ fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
244
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
245
+ )
229
246
  ax.xaxis.set_major_formatter(lambda x, pos: human_format(x, pos, decimals=0))
230
247
 
231
248
  ax = standard_graph_styles(ax)
232
249
 
233
250
  if cumlative:
234
- if title is None:
235
- title = "Average Days Between Purchases Cumulative Distribution"
236
- if ylabel is None:
237
- ylabel = "Percentage of Customers"
251
+ default_title = "Average Days Between Purchases Cumulative Distribution"
252
+ default_y_label = "Percentage of Customers"
238
253
  ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
239
254
 
240
255
  else:
241
- if title is None:
242
- title = "Average Days Between Purchases Distribution"
243
- if ylabel is None:
244
- ylabel = "Number of Customers"
256
+ default_title = "Average Days Between Purchases Distribution"
257
+ default_y_label = "Number of Customers"
245
258
  ax.yaxis.set_major_formatter(lambda x, pos: human_format(x, pos, decimals=0))
246
259
 
247
- ax.set_title(title, fontsize=gs.DEFAULT_TITLE_FONT_SIZE, pad=15)
248
- ax.set_ylabel(ylabel, fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE, labelpad=10)
260
+ ax.set_title(
261
+ gu.not_none(title, default_title),
262
+ fontproperties=gs.POPPINS_SEMI_BOLD,
263
+ fontsize=gs.DEFAULT_TITLE_FONT_SIZE,
264
+ pad=gs.DEFAULT_TITLE_PAD,
265
+ )
266
+ ax.set_ylabel(
267
+ gu.not_none(y_label, default_y_label),
268
+ fontproperties=gs.POPPINS_REG,
269
+ fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
270
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
271
+ )
249
272
 
250
273
  if draw_percentile_line:
251
274
  if percentile_line > 1 or percentile_line < 0:
@@ -268,6 +291,8 @@ class DaysBetweenPurchases:
268
291
  ha="left",
269
292
  va="center",
270
293
  fontsize=gs.DEFAULT_SOURCE_FONT_SIZE,
294
+ fontproperties=gs.POPPINS_LIGHT_ITALIC,
295
+ color="dimgray",
271
296
  )
272
297
 
273
298
  return ax
@@ -335,8 +360,8 @@ class TransactionChurn:
335
360
  cumlative: bool = False,
336
361
  ax: Axes | None = None,
337
362
  title: str | None = None,
338
- xlabel: str | None = None,
339
- ylabel: str | None = None,
363
+ x_label: str | None = None,
364
+ y_label: str | None = None,
340
365
  source_text: str = None,
341
366
  **kwargs: dict[str, any],
342
367
  ) -> SubplotBase:
@@ -366,17 +391,25 @@ class TransactionChurn:
366
391
 
367
392
  standard_graph_styles(ax)
368
393
 
369
- if title is None:
370
- title = "Churn Rate by Number of Purchases"
371
- if xlabel is None:
372
- xlabel = "Number of Purchases"
373
- if ylabel is None:
374
- ylabel = "% Churned"
375
-
376
394
  ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))
377
- ax.set_xlabel(xlabel, fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE, labelpad=10)
378
- ax.set_ylabel(ylabel, fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE, labelpad=10)
379
- ax.set_title(title, fontsize=gs.DEFAULT_TITLE_FONT_SIZE, pad=15)
395
+ ax.set_xlabel(
396
+ gu.not_none(x_label, "Number of Purchases"),
397
+ fontproperties=gs.POPPINS_REG,
398
+ fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
399
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
400
+ )
401
+ ax.set_ylabel(
402
+ gu.not_none(y_label, "% Churned"),
403
+ fontproperties=gs.POPPINS_REG,
404
+ fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
405
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
406
+ )
407
+ ax.set_title(
408
+ gu.not_none(title, "Churn Rate by Number of Purchases"),
409
+ fontproperties=gs.POPPINS_SEMI_BOLD,
410
+ fontsize=gs.DEFAULT_TITLE_FONT_SIZE,
411
+ pad=gs.DEFAULT_TITLE_PAD,
412
+ )
380
413
 
381
414
  if source_text:
382
415
  ax.annotate(
@@ -386,6 +419,8 @@ class TransactionChurn:
386
419
  ha="left",
387
420
  va="center",
388
421
  fontsize=gs.DEFAULT_SOURCE_FONT_SIZE,
422
+ fontproperties=gs.POPPINS_LIGHT_ITALIC,
423
+ color="dimgray",
389
424
  )
390
425
 
391
426
  return ax
@@ -83,6 +83,51 @@ class PyRetailSciencePandasDataset(PandasDataset):
83
83
  }
84
84
 
85
85
 
86
+ def build_expected_columns(columns: list[str]) -> list[ExpectationConfiguration]:
87
+ """A helper function that builds a list of expectations for the columns to exist.
88
+
89
+ Args:
90
+ columns (list[str]): A list of columns to build the expectations for.
91
+
92
+ Returns:
93
+ list[ExpectationConfiguration]: A list of expectations for the columns to exist.
94
+ """
95
+ return [
96
+ ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": column})
97
+ for column in columns
98
+ ]
99
+
100
+
101
+ def build_expected_unique_columns(columns: list[str]) -> list[ExpectationConfiguration]:
102
+ """A helper function that builds a list of expectations for the columns to have unique values.
103
+
104
+ Args:
105
+ columns (list[str]): A list of columns to build the expectations for.
106
+
107
+ Returns:
108
+ list[ExpectationConfiguration]: A list of expectations for the columns to have unique values.
109
+ """
110
+ return [
111
+ ExpectationConfiguration(expectation_type="expect_column_values_to_be_unique", kwargs={"column": column})
112
+ for column in columns
113
+ ]
114
+
115
+
116
+ def build_non_null_columns(columns: list[list[str]]) -> list[ExpectationConfiguration]:
117
+ """A helper function that builds a list of expectations for the columns to have no null values.
118
+
119
+ Args:
120
+ columns (list[list[str]]): A list of columns to build the expectations for.
121
+
122
+ Returns:
123
+ list[ExpectationConfiguration]: A list of expectations for the columns to have no null values.
124
+ """
125
+ return [
126
+ ExpectationConfiguration(expectation_type="expect_column_values_to_not_be_null", kwargs={"column": column})
127
+ for column in columns
128
+ ]
129
+
130
+
86
131
  class ContractBase(abc.ABC):
87
132
  """Base class for data contracts. It contains the basic and extended expectations for the data, as well as the
88
133
  validation state and the result of the last validation. It also contains a method to validate the data.
@@ -180,65 +225,29 @@ class TransactionLevelContract(ContractBase):
180
225
  validation_result (dict): The result of the last validation.
181
226
  """
182
227
 
183
- basic_expectations = [
184
- ExpectationConfiguration(
185
- expectation_type="expect_column_to_exist",
186
- kwargs={"column": "transaction_id"},
187
- ),
188
- ExpectationConfiguration(
189
- expectation_type="expect_column_to_exist",
190
- kwargs={"column": "transaction_datetime"},
191
- ),
192
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "customer_id"}),
193
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "total_price"}),
194
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "store_id"}),
195
- ]
228
+ basic_expectations = build_expected_columns(
229
+ ["transaction_id", "transaction_datetime", "customer_id", "total_price", "store_id"]
230
+ )
196
231
 
197
- extended_expectations = [
198
- ExpectationConfiguration(
199
- expectation_type="expect_column_values_to_be_unique",
200
- kwargs={"column": "transaction_id"},
201
- ),
202
- ExpectationConfiguration(
203
- expectation_type="expect_compound_columns_to_be_unique",
204
- kwargs={
205
- "column_list": [
206
- "transaction_id",
207
- "transaction_datetime",
208
- "customer_id",
209
- "store_id",
210
- ]
211
- },
212
- ),
213
- ExpectationConfiguration(
214
- expectation_type="expect_column_values_to_be_between",
215
- kwargs={
216
- "column": "transaction_datetime",
217
- "min_value": "1970-01-01",
218
- "max_value": "2029-12-31",
219
- },
220
- ),
221
- ExpectationConfiguration(
222
- expectation_type="expect_column_values_to_not_be_null",
223
- kwargs={"column": "transaction_id"},
224
- ),
225
- ExpectationConfiguration(
226
- expectation_type="expect_column_values_to_not_be_null",
227
- kwargs={"column": "transaction_datetime"},
228
- ),
229
- ExpectationConfiguration(
230
- expectation_type="expect_column_values_to_not_be_null",
231
- kwargs={"column": "customer_id"},
232
- ),
233
- ExpectationConfiguration(
234
- expectation_type="expect_column_values_to_not_be_null",
235
- kwargs={"column": "total_price"},
236
- ),
237
- ExpectationConfiguration(
238
- expectation_type="expect_column_values_to_not_be_null",
239
- kwargs={"column": "store_id"},
240
- ),
241
- ]
232
+ extended_expectations = (
233
+ build_expected_unique_columns(
234
+ [
235
+ "transaction_id",
236
+ ["transaction_datetime", "customer_id", "total_price", "store_id"],
237
+ ]
238
+ )
239
+ + [
240
+ ExpectationConfiguration(
241
+ expectation_type="expect_column_values_to_be_between",
242
+ kwargs={
243
+ "column": "transaction_datetime",
244
+ "min_value": "1970-01-01",
245
+ "max_value": "2029-12-31",
246
+ },
247
+ )
248
+ ]
249
+ + build_non_null_columns(["transaction_id", "transaction_datetime", "customer_id", "total_price", "store_id"])
250
+ )
242
251
 
243
252
 
244
253
  class TransactionItemLevelContract(ContractBase):
@@ -263,23 +272,19 @@ class TransactionItemLevelContract(ContractBase):
263
272
  that these columns are not null.
264
273
  """
265
274
 
266
- basic_expectations = [
267
- ExpectationConfiguration(
268
- expectation_type="expect_column_to_exist",
269
- kwargs={"column": "transaction_id"},
270
- ),
271
- ExpectationConfiguration(
272
- expectation_type="expect_column_to_exist",
273
- kwargs={"column": "transaction_datetime"},
274
- ),
275
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "customer_id"}),
276
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "total_price"}),
277
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "store_id"}),
278
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "product_id"}),
279
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "product_name"}),
280
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "unit_price"}),
281
- ExpectationConfiguration(expectation_type="expect_column_to_exist", kwargs={"column": "quantity"}),
282
- ]
275
+ basic_expectations = build_expected_columns(
276
+ [
277
+ "transaction_id",
278
+ "transaction_datetime",
279
+ "customer_id",
280
+ "total_price",
281
+ "store_id",
282
+ "product_id",
283
+ "product_name",
284
+ "unit_price",
285
+ "quantity",
286
+ ]
287
+ )
283
288
 
284
289
  extended_expectations = [
285
290
  ExpectationConfiguration(
@@ -297,44 +302,18 @@ class TransactionItemLevelContract(ContractBase):
297
302
  expectation_type="expect_transaction_product_quantity_sign_to_be_unique",
298
303
  kwargs={},
299
304
  ),
300
- # Null expectations
301
- ExpectationConfiguration(
302
- expectation_type="expect_column_values_to_not_be_null",
303
- kwargs={"column": "transaction_id"},
304
- ),
305
- ExpectationConfiguration(
306
- expectation_type="expect_column_values_to_not_be_null",
307
- kwargs={"column": "transaction_datetime"},
308
- ),
309
- ExpectationConfiguration(
310
- expectation_type="expect_column_values_to_not_be_null",
311
- kwargs={"column": "customer_id"},
312
- ),
313
- ExpectationConfiguration(
314
- expectation_type="expect_column_values_to_not_be_null",
315
- kwargs={"column": "total_price"},
316
- ),
317
- ExpectationConfiguration(
318
- expectation_type="expect_column_values_to_not_be_null",
319
- kwargs={"column": "store_id"},
320
- ),
321
- ExpectationConfiguration(
322
- expectation_type="expect_column_values_to_not_be_null",
323
- kwargs={"column": "product_id"},
324
- ),
325
- ExpectationConfiguration(
326
- expectation_type="expect_column_values_to_not_be_null",
327
- kwargs={"column": "product_name"},
328
- ),
329
- ExpectationConfiguration(
330
- expectation_type="expect_column_values_to_not_be_null",
331
- kwargs={"column": "unit_price"},
332
- ),
333
- ExpectationConfiguration(
334
- expectation_type="expect_column_values_to_not_be_null",
335
- kwargs={"column": "quantity"},
336
- ),
337
- ]
305
+ ] + build_non_null_columns(
306
+ [
307
+ "transaction_id",
308
+ "transaction_datetime",
309
+ "customer_id",
310
+ "store_id",
311
+ "product_id",
312
+ "product_name",
313
+ "unit_price",
314
+ "quantity",
315
+ ]
316
+ )
338
317
 
339
318
  def __init__(self, df: pd.DataFrame) -> None:
340
319
  # If category or brand columns are present, add expectations for them
@@ -393,3 +372,37 @@ class CustomerLevelContract(ContractBase):
393
372
  kwargs={"column": "customer_id"},
394
373
  ),
395
374
  ]
375
+
376
+
377
+ class CustomContract(ContractBase):
378
+ """A helper class to construct contracts for specific use cases.
379
+
380
+ Args:
381
+ df (pd.DataFrame): The input DataFrame.
382
+ basic_expectations (list[ExpectationConfiguration] | None, optional): A list of basic expectation
383
+ configurations. Defaults to None. At least one basic or extended expectation must be supplied.
384
+ extended_expectations (list[ExpectationConfiguration] | None, optional): A list of extended expectation
385
+ configurations. Defaults to None. At least one basic or extended expectation must be supplied.
386
+
387
+ Raises:
388
+ ValueError: If both basic_expectations and extended_expectations are None.
389
+
390
+ Attributes:
391
+ basic_expectations (list[ExpectationConfiguration]): A list of basic expectation configurations.
392
+ extended_expectations (list[ExpectationConfiguration]): A list of extended expectation configurations.
393
+
394
+ """
395
+
396
+ def __init__(
397
+ self,
398
+ df: pd.DataFrame,
399
+ basic_expectations: list[ExpectationConfiguration] | None = None,
400
+ extended_expectations: list[ExpectationConfiguration] | None = None,
401
+ ) -> None:
402
+ if basic_expectations is None and extended_expectations is None:
403
+ raise ValueError("At least one of basic_expectations or extended_expectations must be provided.")
404
+
405
+ self.basic_expectations = basic_expectations or []
406
+ self.extended_expectations = extended_expectations or []
407
+
408
+ super().__init__(df)
@@ -7,7 +7,7 @@ from matplotlib.axes import Axes, SubplotBase
7
7
  from scipy.cluster.hierarchy import dendrogram, linkage
8
8
 
9
9
  import pyretailscience.style.graph_utils as gu
10
- from pyretailscience.data.contracts import TransactionItemLevelContract
10
+ from pyretailscience.data.contracts import CustomContract, build_expected_columns, build_non_null_columns
11
11
  from pyretailscience.style.graph_utils import GraphStyles as gs
12
12
 
13
13
 
@@ -40,8 +40,17 @@ class CustomerDecisionHierarchy:
40
40
  ValueError: If the dataframe does not comply with the TransactionItemLevelContract.
41
41
 
42
42
  """
43
- if TransactionItemLevelContract(df).validate() is False:
44
- raise ValueError("The dataframe does not comply with the TransactionItemLevelContract")
43
+ cdh_contract = CustomContract(
44
+ df,
45
+ basic_expectations=build_expected_columns(columns=["customer_id", "transaction_id", "product_name"]),
46
+ extended_expectations=build_non_null_columns(columns=["customer_id", "transaction_id", "product_name"]),
47
+ )
48
+
49
+ if cdh_contract.validate() is False:
50
+ raise ValueError(
51
+ "The dataframe requires the columns 'customer_id', 'transaction_id', and 'product_name' and they must "
52
+ "be non-null"
53
+ )
45
54
 
46
55
  self.random_state = random_state
47
56
  self.pairs_df = self._get_pairs(df, exclude_same_transaction_products)
@@ -254,18 +263,21 @@ class CustomerDecisionHierarchy:
254
263
 
255
264
  ax.set_title(
256
265
  title,
266
+ fontproperties=gs.POPPINS_SEMI_BOLD,
257
267
  fontsize=gs.DEFAULT_TITLE_FONT_SIZE,
258
- pad=15,
268
+ pad=gs.DEFAULT_TITLE_PAD + 5,
259
269
  )
260
270
  ax.set_xlabel(
261
- gu.not_none(y_label, default_x_label),
271
+ gu.not_none(x_label, default_x_label),
272
+ fontproperties=gs.POPPINS_REG,
262
273
  fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
263
- labelpad=10,
274
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
264
275
  )
265
276
  ax.set_ylabel(
266
- gu.not_none(x_label, default_y_label),
277
+ gu.not_none(y_label, default_y_label),
278
+ fontproperties=gs.POPPINS_REG,
267
279
  fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
268
- labelpad=10,
280
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
269
281
  )
270
282
 
271
283
  # Set the y label to be on the right side of the plot
@@ -300,11 +312,21 @@ class CustomerDecisionHierarchy:
300
312
  ha="left",
301
313
  va="center",
302
314
  fontsize=gs.DEFAULT_SOURCE_FONT_SIZE,
315
+ fontproperties=gs.POPPINS_LIGHT_ITALIC,
316
+ color="dimgray",
303
317
  )
304
318
 
319
+ ax.xaxis.set_tick_params(labelsize=gs.DEFAULT_TICK_LABEL_FONT_SIZE)
320
+ ax.yaxis.set_tick_params(labelsize=gs.DEFAULT_TICK_LABEL_FONT_SIZE)
321
+
322
+ # Rotate the x-axis labels if they are too long
305
323
  if orientation in ["top", "bottom"]:
306
- ax.xaxis.set_tick_params(labelsize=gs.DEFAULT_TICK_LABEL_FONT_SIZE)
307
- else:
308
- ax.yaxis.set_tick_params(labelsize=gs.DEFAULT_TICK_LABEL_FONT_SIZE)
324
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
325
+
326
+ # Set the font properties for the tick labels
327
+ for tick in ax.get_xticklabels():
328
+ tick.set_fontproperties(gs.POPPINS_REG)
329
+ for tick in ax.get_yticklabels():
330
+ tick.set_fontproperties(gs.POPPINS_REG)
309
331
 
310
332
  return ax
@@ -4,37 +4,108 @@ import pandas as pd
4
4
  from matplotlib.axes import Axes, SubplotBase
5
5
 
6
6
  import pyretailscience.style.graph_utils as gu
7
- from pyretailscience.data.contracts import TransactionItemLevelContract, TransactionLevelContract
7
+ from pyretailscience.data.contracts import (
8
+ TransactionItemLevelContract,
9
+ TransactionLevelContract,
10
+ CustomContract,
11
+ build_expected_columns,
12
+ build_non_null_columns,
13
+ build_expected_unique_columns,
14
+ )
15
+ from pyretailscience.style.graph_utils import GraphStyles as gs
8
16
  from pyretailscience.style.tailwind import COLORS
9
17
 
10
18
 
11
- class HMLSegmentation:
12
- def __init__(self, df: pd.DataFrame, value_col: str = "total_price") -> None:
19
+ class BaseSegmentation:
20
+ def add_segment(self, df: pd.DataFrame) -> pd.DataFrame:
21
+ """
22
+ Adds the segment to the dataframe based on the customer_id column.
23
+
24
+ Args:
25
+ df (pd.DataFrame): The dataframe to add the segment to. The dataframe must have a customer_id column.
26
+
27
+ Returns:
28
+ pd.DataFrame: The dataframe with the segment added.
29
+
30
+ Raises:
31
+ ValueError: If the number of rows before and after the merge do not match.
32
+ """
33
+ rows_before = len(df)
34
+ df = df.merge(self.df[["segment_name", "segment_id"]], how="left", left_on="customer_id", right_index=True)
35
+ rows_after = len(df)
36
+ if rows_before != rows_after:
37
+ raise ValueError("The number of rows before and after the merge do not match. This should not happen.")
38
+
39
+ return df
40
+
41
+
42
+ class ExistingSegmentation(BaseSegmentation):
43
+ def __init__(self, df: pd.DataFrame) -> None:
44
+ """
45
+ Segments customers based on an existing segment in the dataframe.
46
+
47
+ Args:
48
+
49
+ df (pd.DataFrame): A dataframe with the customer_id, segment_name and segment_id columns.
50
+
51
+ Raises:
52
+ ValueError: If the dataframe does not have the columns customer_id, segment_name and segment_id.
53
+ """
54
+ required_cols = "customer_id", "segment_name", "segment_id"
55
+ contract = CustomContract(
56
+ df,
57
+ basic_expectations=build_expected_columns(columns=required_cols),
58
+ extended_expectations=build_non_null_columns(columns=required_cols)
59
+ + build_expected_unique_columns(columns=[required_cols]),
60
+ )
61
+
62
+ if contract.validate() is False:
63
+ raise ValueError(
64
+ f"The dataframe requires the columns {required_cols} and they must be non-null and unique."
65
+ )
66
+
67
+ self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id")
68
+
69
+
70
+ class HMLSegmentation(BaseSegmentation):
71
+ def __init__(
72
+ self,
73
+ df: pd.DataFrame,
74
+ value_col: str = "total_price",
75
+ zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment",
76
+ ) -> None:
13
77
  """
14
78
  Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend.
15
79
 
16
80
  Args:
17
- df (pd.DataFrame): A dataframe with the transaction data. The dataframe must comply with the
18
- TransactionItemLevelContract or the TransactionLevelContract.
81
+ df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column.
19
82
  value_col (str, optional): The column to use for the segmentation. Defaults to "total_price".
20
83
 
21
84
  Raises:
22
- ValueError: If the dataframe does not comply with the TransactionItemLevelContract or
23
- TransactionLevelContract.
85
+ ValueError: If the dataframe is missing the columns "customer_id" or `value_col`, or these columns contain
86
+ null values.
24
87
  """
88
+ required_cols = ["customer_id", value_col]
89
+ contract = CustomContract(
90
+ df,
91
+ basic_expectations=build_expected_columns(columns=required_cols),
92
+ extended_expectations=build_non_null_columns(columns=required_cols),
93
+ )
25
94
 
26
- if TransactionItemLevelContract(df).validate() is False and TransactionLevelContract(df).validate() is False:
27
- raise ValueError("The dataframe does not comply with the TransactionItemLevelContract")
95
+ if contract.validate() is False:
96
+ raise ValueError(f"The dataframe requires the columns {required_cols} and they must be non-null")
28
97
 
29
98
  # Group by customer_id and calculate total_spend
30
99
  grouped_df = df.groupby("customer_id")[value_col].sum().to_frame(value_col)
31
100
 
32
101
  # Separate customers with zero spend
33
- zero_idx = grouped_df[value_col] == 0
34
- zero_cust_df = grouped_df[zero_idx]
35
- zero_cust_df["segment_name"] = "Zero"
102
+ hml_df = grouped_df
103
+ if zero_value_customers in ["separate_segment", "exclude"]:
104
+ zero_idx = grouped_df[value_col] == 0
105
+ zero_cust_df = grouped_df[zero_idx]
106
+ zero_cust_df["segment_name"] = "Zero"
36
107
 
37
- hml_df = grouped_df[~zero_idx]
108
+ hml_df = grouped_df[~zero_idx]
38
109
 
39
110
  # Create a new column 'segment' based on the total_spend
40
111
  hml_df["segment_name"] = pd.qcut(
@@ -43,38 +114,14 @@ class HMLSegmentation:
43
114
  labels=["Light", "Medium", "Heavy"],
44
115
  )
45
116
 
46
- hml_df = pd.concat([hml_df, zero_cust_df])
117
+ if zero_value_customers == "separate_segment":
118
+ hml_df = pd.concat([hml_df, zero_cust_df])
47
119
 
48
120
  segment_code_map = {"Light": "L", "Medium": "M", "Heavy": "H", "Zero": "Z"}
49
121
 
50
122
  hml_df["segment_id"] = hml_df["segment_name"].map(segment_code_map)
51
123
 
52
- self.df = grouped_df
53
-
54
- def add_segment(self, df: pd.DataFrame) -> pd.DataFrame:
55
- """
56
- Adds the segment to the dataframe based on the customer_id column.
57
-
58
- Args:
59
- df (pd.DataFrame): The dataframe to add the segment to. The dataframe must have a customer_id column.
60
-
61
- Returns:
62
- pd.DataFrame: The dataframe with the segment added.
63
-
64
- Raises:
65
- ValueError: If the number of rows before and after the merge do not match.
66
- """
67
-
68
- # TODO: Add a contract that ensures there's a customer ID column or matches one or more of a set of contracts
69
- # efficently - Eg checks all the quick validations and then tries the extended validations
70
-
71
- rows_before = len(df)
72
- df = df.merge(self.df[["segment_name", "segment_id"]], how="left", left_on="customer_id", right_index=True)
73
- rows_after = len(df)
74
- if rows_before != rows_after:
75
- raise ValueError("The number of rows before and after the merge do not match. This should not happen.")
76
-
77
- return df
124
+ self.df = hml_df
78
125
 
79
126
 
80
127
  class SegTransactionStats:
@@ -195,18 +242,41 @@ class SegTransactionStats:
195
242
  decimals = gu.get_decimals(ax.get_xlim(), ax.get_xticks())
196
243
  ax.xaxis.set_major_formatter(lambda x, pos: gu.human_format(x, pos, decimals=decimals))
197
244
 
198
- ax.set_title(gu.not_none(title, default_title))
199
- ax.set_ylabel(plot_y_label)
200
- ax.set_xlabel(plot_x_label)
245
+ ax.set_title(
246
+ gu.not_none(title, default_title),
247
+ fontproperties=gs.POPPINS_SEMI_BOLD,
248
+ fontsize=gs.DEFAULT_TITLE_FONT_SIZE,
249
+ pad=gs.DEFAULT_TITLE_PAD,
250
+ )
251
+ ax.set_ylabel(
252
+ plot_y_label,
253
+ fontproperties=gs.POPPINS_REG,
254
+ fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
255
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
256
+ )
257
+ ax.set_xlabel(
258
+ plot_x_label,
259
+ fontproperties=gs.POPPINS_REG,
260
+ fontsize=gs.DEFAULT_AXIS_LABEL_FONT_SIZE,
261
+ labelpad=gs.DEFAULT_AXIS_LABEL_PAD,
262
+ )
201
263
 
202
264
  if source_text is not None:
203
265
  ax.annotate(
204
266
  source_text,
205
- xy=(-0.1, -0.2),
267
+ xy=(-0.1, -0.15),
206
268
  xycoords="axes fraction",
207
269
  ha="left",
208
270
  va="center",
209
- fontsize=10,
271
+ fontsize=gs.DEFAULT_SOURCE_FONT_SIZE,
272
+ fontproperties=gs.POPPINS_LIGHT_ITALIC,
273
+ color="dimgray",
210
274
  )
211
275
 
276
+ # Set the font properties for the tick labels
277
+ for tick in ax.get_xticklabels():
278
+ tick.set_fontproperties(gs.POPPINS_REG)
279
+ for tick in ax.get_yticklabels():
280
+ tick.set_fontproperties(gs.POPPINS_REG)
281
+
212
282
  return ax
@@ -0,0 +1,313 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from matplotlib.axes import Axes, SubplotBase
6
+ from pandas.tseries.offsets import BaseOffset
7
+
8
+ import pyretailscience.style.graph_utils as gu
9
+ from pyretailscience.style.graph_utils import GraphStyles as gs
10
+ from pyretailscience.style.tailwind import COLORS, get_linear_cmap
11
+
12
+ # TODO: Consider simplifying this by reducing the color range in the get_linear_cmap function.
13
+ COLORMAP_MIN = 0.25
14
+ COLORMAP_MAX = 0.75
15
+
16
+
17
def time_plot(
    df: pd.DataFrame,
    value_col: str,
    period: str | BaseOffset = "D",
    agg_func: str = "sum",
    group_col: str | None = None,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    legend_title: str | None = None,
    ax: Axes | None = None,
    source_text: str | None = None,
    **kwargs: dict[str, any],
) -> SubplotBase:
    """Plot an aggregated value column over time as a line chart.

    The rows of `df` are bucketed into periods derived from its `transaction_datetime` column and
    `value_col` is aggregated per period (and per `group_col` value when one is given, producing one
    line per group). The input dataframe is not modified.

    Args:
        df (pd.DataFrame): The dataframe to plot. Must contain a datetime-like `transaction_datetime`
            column as well as `value_col` (and `group_col` when given).
        value_col (str): The column to aggregate and plot.
        period (str | BaseOffset): The period to bucket `transaction_datetime` into. Defaults to "D".
        agg_func (str, optional): The aggregation function to apply to the value_col. Defaults to "sum".
        group_col (str, optional): The column to split the data into separate lines by. Defaults to None.
        title (str, optional): The title of the plot. Defaults to None. When None the title is set to
            "Total Sales" if `group_col` is None, otherwise to
            `f"{value_col.title()} by {group_col.title()}"`.
        x_label (str, optional): The x-axis label. Defaults to None. When None the x-axis label is set to blank.
        y_label (str, optional): The y-axis label. Defaults to None. When None the y-axis label is set to the
            title case of `value_col`.
        legend_title (str, optional): The title of the legend. Defaults to None. When None the legend title is
            set to the title case of `group_col`. Only used when `group_col` is given.
        ax (Axes, optional): The matplotlib axes object to plot on. Defaults to None.
        source_text (str, optional): The source text to add below the plot. Defaults to None.
        **kwargs: Additional keyword arguments to pass to the Pandas plot function.

    Returns:
        SubplotBase: The matplotlib axes object.
    """
    # Build the period key as a standalone Series rather than writing a new column into `df`,
    # so the caller's dataframe is left untouched.
    period_key = df["transaction_datetime"].dt.to_period(period).rename("transaction_period")

    if group_col is None:
        colors = COLORS["green"][500]
        plot_data = df.groupby(period_key)[value_col].agg(agg_func)
        default_title = "Total Sales"
        show_legend = False
    else:
        # One colour per distinct group, sampled from the middle of the green colormap
        colors = get_linear_cmap("green")(np.linspace(COLORMAP_MIN, COLORMAP_MAX, df[group_col].nunique()))
        plot_data = (
            df.groupby([group_col, period_key])[value_col]
            .agg(agg_func)
            .reset_index()
            .pivot(index="transaction_period", columns=group_col, values=value_col)
        )
        default_title = f"{value_col.title()} by {group_col.title()}"
        show_legend = True

    ax = plot_data.plot(
        linewidth=3,
        color=colors,
        legend=show_legend,
        ax=ax,
        **kwargs,
    )
    ax = gu.standard_graph_styles(
        ax,
        title=gu.not_none(title, default_title),
        x_label=gu.not_none(x_label, ""),
        y_label=gu.not_none(y_label, value_col.title()),
    )

    # Format the y-axis ticks in human readable form (K, M, ...) with just enough decimals
    decimals = gu.get_decimals(ax.get_ylim(), ax.get_yticks())
    ax.yaxis.set_major_formatter(lambda x, pos: gu.human_format(x, pos, decimals=decimals))

    if show_legend:
        legend = ax.legend(
            title=gu.not_none(legend_title, group_col.title()),
            frameon=True,
        )
        legend.get_frame().set_facecolor("white")
        legend.get_frame().set_edgecolor("white")

    if source_text is not None:
        ax.annotate(
            source_text,
            xy=(-0.1, -0.2),
            xycoords="axes fraction",
            ha="left",
            va="center",
            fontsize=gs.DEFAULT_SOURCE_FONT_SIZE,
            fontproperties=gs.POPPINS_LIGHT_ITALIC,
            color="dimgray",
        )

    # Set the font properties for the tick labels
    for tick in ax.get_xticklabels():
        tick.set_fontproperties(gs.POPPINS_REG)
    for tick in ax.get_yticklabels():
        tick.set_fontproperties(gs.POPPINS_REG)

    return ax
116
+
117
+
118
+ def get_indexes(
119
+ df: pd.DataFrame,
120
+ df_index_filter: list[bool],
121
+ index_col: str,
122
+ value_col: str,
123
+ index_subgroup_col: str | None = None,
124
+ agg_func: str = "sum",
125
+ offset: int = 0,
126
+ ) -> pd.DataFrame:
127
+ """
128
+ Calculates the index of the value_col for the subset of a dataframe defined by df_index_filter.
129
+
130
+ Args:
131
+ df (pd.DataFrame): The dataframe to calculate the index on.
132
+ df_index_filter (list[bool]): The boolean index to filter the data by.
133
+ grp_cols (list[str]): The columns to group the data by.
134
+ value_col (str): The column to calculate the index on.
135
+ agg_func (str): The aggregation function to apply to the value_col.
136
+ offset (int, optional): The offset to subtract from the index. Defaults to 0.
137
+
138
+ Returns:
139
+ pd.Series: The index of the value_col for the subset of data defined by filter_index.
140
+ """
141
+ if all(df_index_filter) or not any(df_index_filter):
142
+ raise ValueError("The df_index_filter cannot be all True or all False.")
143
+
144
+ grp_cols = [index_col] if index_subgroup_col is None else [index_subgroup_col, index_col]
145
+
146
+ overall_df = df.groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col)
147
+ if index_subgroup_col is None:
148
+ overall_total = overall_df[value_col].sum()
149
+ else:
150
+ overall_total = overall_df.groupby(index_subgroup_col)[value_col].sum()
151
+ overall_s = overall_df[value_col] / overall_total
152
+
153
+ subset_df = df[df_index_filter].groupby(grp_cols)[value_col].agg(agg_func).to_frame(value_col)
154
+ if index_subgroup_col is None:
155
+ subset_total = subset_df[value_col].sum()
156
+ else:
157
+ subset_total = subset_df.groupby(index_subgroup_col)[value_col].sum()
158
+ subset_s = subset_df[value_col] / subset_total
159
+
160
+ index_df = ((subset_s / overall_s * 100) - offset).to_frame("index").reset_index()
161
+
162
+ return index_df
163
+
164
+
165
def index_plot(
    df: pd.DataFrame,
    df_index_filter: list[bool],
    value_col: str,
    group_col: str,
    agg_func: str = "sum",
    series_col: str | None = None,
    title: str | None = None,
    x_label: str = "Index",
    y_label: str | None = None,
    legend_title: str | None = None,
    highlight_range: Literal["default"] | tuple[float, float] | None = "default",
    sort_by: Literal["group", "value"] | None = "group",
    sort_order: Literal["ascending", "descending"] = "ascending",
    ax: Axes | None = None,
    source_text: str | None = None,
    exclude_groups: list[any] | None = None,
    include_only_groups: list[any] | None = None,
    **kwargs: dict[str, any],
) -> SubplotBase:
    """Plot the index of value_col for each group as a horizontal bar chart centred on 100.

    The index of each `group_col` value is calculated with `get_indexes` for the subset of `df` selected
    by `df_index_filter`. Bars start at 100 and extend to the group's index, so bars to the right of the
    100 line are over-indexed and bars to the left are under-indexed.

    Args:
        df (pd.DataFrame): The dataframe to plot.
        df_index_filter (list[bool]): The filter selecting the subset of the dataframe to index.
        value_col (str): The column to calculate the index on.
        group_col (str): The column to group the data by. Each group becomes one bar position.
        agg_func (str, optional): The aggregation function to apply to the value_col. Defaults to "sum".
        series_col (str, optional): The column to use as the series. Defaults to None. When given, one bar
            per series value is drawn for each group and the index is calculated within each series.
        title (str, optional): The title of the plot. Defaults to None. When None the title is set to
            `f"{value_col.title()} by {group_col.title()}"`
        x_label (str, optional): The x-axis label. Defaults to "Index".
        y_label (str, optional): The y-axis label. Defaults to None. When None the y-axis label is set to the title
            case of `group_col`
        legend_title (str, optional): The title of the legend. Defaults to None. When None the legend title is set to
            the title case of `series_col`. Only used when `series_col` is given.
        highlight_range (Literal["default"] | tuple[float, float] | None, optional): The range to highlight. Defaults
            to "default". When "default" the range is set to (80, 120). When None no range is highlighted.
        sort_by (Literal["group", "value"] | None, optional): The column to sort by. Defaults to "group". When None the
            data is not sorted. When "group" the data is sorted by group_col. When "value" the data is sorted by
            the index value; this is only applied when series_col is None and is ignored otherwise.
        sort_order (Literal["ascending", "descending"], optional): The order to sort the data. Defaults to "ascending".
        ax (Axes, optional): The matplotlib axes object to plot on. Defaults to None.
        source_text (str, optional): The source text to add to the plot. Defaults to None.
        exclude_groups (list[any], optional): The groups to exclude from the plot. Defaults to None.
        include_only_groups (list[any], optional): The groups to include in the plot. Defaults to None. When None all
            groups are included. When not None only the groups in the list are included. Cannot be used with
            exclude_groups.
        **kwargs: Additional keyword arguments to pass to the Pandas plot function.

    Returns:
        SubplotBase: The matplotlib axes object.

    Raises:
        ValueError: If sort_by is not either "group" or "value" or None.
        ValueError: If sort_order is not either "ascending" or "descending".
        ValueError: If exclude_groups and include_only_groups are used together.
    """
    if sort_by is not None and sort_by not in ["group", "value"]:
        raise ValueError("sort_by must be either 'group' or 'value' or None")
    if sort_order not in ["ascending", "descending"]:
        raise ValueError("sort_order must be either 'ascending' or 'descending'")
    if exclude_groups is not None and include_only_groups is not None:
        raise ValueError("exclude_groups and include_only_groups cannot be used together.")

    # offset=100 turns the index into a signed distance from 100, which is the bar length
    # when the bars are drawn starting at x=100 (see `left=100` below).
    index_df = get_indexes(
        df=df,
        df_index_filter=df_index_filter,
        index_col=group_col,
        index_subgroup_col=series_col,
        value_col=value_col,
        agg_func=agg_func,
        offset=100,
    )

    if exclude_groups is not None:
        index_df = index_df[~index_df[group_col].isin(exclude_groups)]
    if include_only_groups is not None:
        index_df = index_df[index_df[group_col].isin(include_only_groups)]

    ascending = sort_order == "ascending"

    if series_col is None:
        colors = COLORS["green"][500]
        show_legend = False
        index_df = index_df[[group_col, "index"]].set_index(group_col)
        if sort_by == "group":
            # group_col is the index at this point, so sorting by group is an index sort
            index_df = index_df.sort_index(ascending=ascending)
        elif sort_by == "value":
            index_df = index_df.sort_values(by="index", ascending=ascending)
    else:
        show_legend = True
        # One colour per distinct series, sampled from the middle of the green colormap
        colors = get_linear_cmap("green")(np.linspace(COLORMAP_MIN, COLORMAP_MAX, df[series_col].nunique()))

        if sort_by == "group":
            index_df = index_df.sort_values(by=[group_col, series_col], ascending=ascending)
        # sort=False keeps the row order established above
        index_df = index_df.pivot_table(index=group_col, columns=series_col, values="index", sort=False)

    ax = index_df.plot.barh(
        left=100,
        legend=show_legend,
        ax=ax,
        color=colors,
        width=gs.DEFAULT_BAR_WIDTH,
        zorder=2,
        **kwargs,
    )

    # Reference line at the neutral index of 100
    ax.axvline(100, color="black", linewidth=1, alpha=0.5)
    if highlight_range == "default":
        highlight_range = (80, 120)
    if highlight_range is not None:
        ax.axvline(highlight_range[0], color="black", linewidth=0.25, alpha=0.1, zorder=-1)
        ax.axvline(highlight_range[1], color="black", linewidth=0.25, alpha=0.1, zorder=-1)
        ax.axvspan(highlight_range[0], highlight_range[1], color="black", alpha=0.1, zorder=-1)

    default_title = f"{value_col.title()} by {group_col.title()}"

    ax = gu.standard_graph_styles(
        ax=ax,
        title=gu.not_none(title, default_title),
        x_label=gu.not_none(x_label, "Index"),
        y_label=gu.not_none(y_label, group_col.title()),
    )

    if show_legend:
        legend = ax.legend(title=gu.not_none(legend_title, series_col.title()), frameon=True)
        legend.get_frame().set_facecolor("white")
        legend.get_frame().set_edgecolor("white")

    if source_text is not None:
        ax.annotate(
            source_text,
            xy=(-0.1, -0.2),
            xycoords="axes fraction",
            ha="left",
            va="center",
            fontsize=gs.DEFAULT_SOURCE_FONT_SIZE,
            fontproperties=gs.POPPINS_LIGHT_ITALIC,
            color="dimgray",
        )

    # Set the font properties for the tick labels
    for tick in ax.get_xticklabels():
        tick.set_fontproperties(gs.POPPINS_REG)
    for tick in ax.get_yticklabels():
        tick.set_fontproperties(gs.POPPINS_REG)

    return ax
@@ -1,14 +1,29 @@
1
+ import matplotlib.font_manager as fm
1
2
  from matplotlib.axes import Axes
3
+ import importlib.resources as pkg_resources
4
+
5
+ ASSETS_PATH = pkg_resources.files("pyretailscience").joinpath("assets")
2
6
 
3
7
 
4
8
  class GraphStyles:
5
9
  """A class to hold the styles for a graph."""
6
10
 
7
- DEFAULT_TITLE_FONT_SIZE = 16
11
+ POPPINS_BOLD = fm.FontProperties(fname=f"{ASSETS_PATH}/fonts/Poppins-Bold.ttf")
12
+ POPPINS_SEMI_BOLD = fm.FontProperties(fname=f"{ASSETS_PATH}/fonts/Poppins-SemiBold.ttf")
13
+ POPPINS_REG = fm.FontProperties(fname=f"{ASSETS_PATH}/fonts/Poppins-Regular.ttf")
14
+ POPPINS_MED = fm.FontProperties(fname=f"{ASSETS_PATH}/fonts/Poppins-Medium.ttf")
15
+ POPPINS_LIGHT_ITALIC = fm.FontProperties(fname=f"{ASSETS_PATH}/fonts/Poppins-LightItalic.ttf")
16
+
17
+ DEFAULT_TITLE_FONT_SIZE = 20
8
18
  DEFAULT_SOURCE_FONT_SIZE = 10
9
19
  DEFAULT_AXIS_LABEL_FONT_SIZE = 12
10
20
  DEFAULT_TICK_LABEL_FONT_SIZE = 10
11
21
 
22
+ DEFAULT_AXIS_LABEL_PAD = 10
23
+ DEFAULT_TITLE_PAD = 10
24
+
25
+ DEFAULT_BAR_WIDTH = 0.8
26
+
12
27
 
13
28
  def human_format(num, pos=None, decimals=0, prefix="") -> str:
14
29
  """Format a number in a human readable format for Matplotlib.
@@ -31,7 +46,12 @@ def human_format(num, pos=None, decimals=0, prefix="") -> str:
31
46
  return f"{prefix}%.{decimals}f%s" % (num, ["", "K", "M", "G", "T", "P"][magnitude])
32
47
 
33
48
 
34
- def standard_graph_styles(ax: Axes) -> Axes:
49
+ def standard_graph_styles(
50
+ ax: Axes,
51
+ title: str | None = None,
52
+ x_label: str | None = None,
53
+ y_label: str | None = None,
54
+ ) -> Axes:
35
55
  """Apply standard styles to a Matplotlib graph.
36
56
 
37
57
  Args:
@@ -43,6 +63,31 @@ def standard_graph_styles(ax: Axes) -> Axes:
43
63
  ax.spines[["top", "right"]].set_visible(False)
44
64
  ax.grid(which="major", axis="x", color="#DAD8D7", alpha=0.5, zorder=1)
45
65
  ax.grid(which="major", axis="y", color="#DAD8D7", alpha=0.5, zorder=1)
66
+
67
+ if title is not None:
68
+ ax.set_title(
69
+ title,
70
+ fontproperties=GraphStyles.POPPINS_SEMI_BOLD,
71
+ fontsize=GraphStyles.DEFAULT_TITLE_FONT_SIZE,
72
+ pad=GraphStyles.DEFAULT_TITLE_PAD,
73
+ )
74
+
75
+ if x_label is not None:
76
+ ax.set_xlabel(
77
+ x_label,
78
+ fontproperties=GraphStyles.POPPINS_REG,
79
+ fontsize=GraphStyles.DEFAULT_AXIS_LABEL_FONT_SIZE,
80
+ labelpad=GraphStyles.DEFAULT_AXIS_LABEL_PAD,
81
+ )
82
+
83
+ if y_label is not None:
84
+ ax.set_ylabel(
85
+ y_label,
86
+ fontproperties=GraphStyles.POPPINS_REG,
87
+ fontsize=GraphStyles.DEFAULT_AXIS_LABEL_FONT_SIZE,
88
+ labelpad=GraphStyles.DEFAULT_AXIS_LABEL_PAD,
89
+ )
90
+
46
91
  return ax
47
92
 
48
93
 
@@ -1,96 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- from matplotlib.axes import Axes, SubplotBase
4
- from pandas.tseries.offsets import BaseOffset
5
-
6
- import pyretailscience.style.graph_utils as gu
7
- from pyretailscience.style.tailwind import COLORS, get_linear_cmap
8
-
9
- # TODO: Consider simplifying this by reducing the color range in the get_linear_cmap function.
10
- COLORMAP_MIN = 0.25
11
- COLORMAP_MAX = 0.75
12
-
13
-
14
- def time_plot(
15
- df: pd.DataFrame,
16
- value_col: str,
17
- period: str | BaseOffset = "D",
18
- agg_func: str = "sum",
19
- group_col: str | None = None,
20
- title: str | None = None,
21
- x_label: str | None = None,
22
- y_label: str | None = None,
23
- ax: Axes | None = None,
24
- source_text: str = None,
25
- **kwargs: dict[str, any],
26
- ) -> SubplotBase:
27
- """
28
- Plots the value_col over time.
29
-
30
- Args:
31
- df (pd.DataFrame): The dataframe to plot.
32
- value_col (str): The column to plot.
33
- period (str | BaseOffset): The period to group the data by.
34
- agg_func (str, optional): The aggregation function to apply to the value_col. Defaults to "sum".
35
- group_col (str, optional): The column to group the data by. Defaults to None.
36
- title (str, optional): The title of the plot. Defaults to None. When None the title is set to
37
- `f"{value_col.title()} by {group_col.title()}"`
38
- x_label (str, optional): The x-axis label. Defaults to None. When None the x-axis label is set to blank
39
- y_label (str, optional): The y-axis label. Defaults to None. When None the y-axis label is set to the title
40
- case of `value_col`
41
- ax (Axes, optional): The matplotlib axes object to plot on. Defaults to None.
42
- source_text (str, optional): The source text to add to the plot. Defaults to None.
43
- **kwargs: Additional keyword arguments to pass to the Pandas plot function.
44
-
45
- Returns:
46
- SubplotBase: The matplotlib axes object.
47
- """
48
- df["transaction_period"] = df["transaction_datetime"].dt.to_period(period)
49
-
50
- if group_col is None:
51
- colors = COLORS["green"][500]
52
- df = df.groupby("transaction_period")[value_col].agg(agg_func)
53
- default_title = "Total Sales"
54
- show_legend = False
55
- else:
56
- colors = get_linear_cmap("green")(np.linspace(COLORMAP_MIN, COLORMAP_MAX, df[group_col].nunique()))
57
- df = (
58
- df.groupby([group_col, "transaction_period"])[value_col]
59
- .agg(agg_func)
60
- .reset_index()
61
- .pivot(index="transaction_period", columns=group_col, values=value_col)
62
- )
63
- default_title = f"{value_col.title()} by {group_col.title()}"
64
- show_legend = True
65
-
66
- ax = df.plot(
67
- linewidth=3,
68
- color=colors,
69
- legend=show_legend,
70
- ax=ax,
71
- **kwargs,
72
- )
73
- ax = gu.standard_graph_styles(ax)
74
- ax.set_ylabel(gu.not_none(y_label, value_col.title()))
75
- ax.set_title(gu.not_none(title, default_title))
76
- ax.set_xlabel(gu.not_none(x_label, ""))
77
-
78
- decimals = gu.get_decimals(ax.get_ylim(), ax.get_yticks())
79
- ax.yaxis.set_major_formatter(lambda x, pos: gu.human_format(x, pos, decimals=decimals))
80
-
81
- if show_legend:
82
- legend = ax.legend(title="Segment", frameon=True)
83
- legend.get_frame().set_facecolor("white")
84
- legend.get_frame().set_edgecolor("white")
85
-
86
- if source_text is not None:
87
- ax.annotate(
88
- source_text,
89
- xy=(-0.1, -0.2),
90
- xycoords="axes fraction",
91
- ha="left",
92
- va="center",
93
- fontsize=10,
94
- )
95
-
96
- return ax
File without changes