max-div 0.0.3__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
max_div/_cli.py ADDED
@@ -0,0 +1,99 @@
1
+ """Command-line interface for max-div."""
2
+
3
+ import click
4
+
5
+ from max_div.benchmark import benchmark_randint as _benchmark_randint
6
+ from max_div.benchmark import benchmark_randint_constrained as _benchmark_randint_constrained
7
+
8
+
9
+ # -------------------------------------------------------------------------
10
+ # Main CLI Group
11
+ # -------------------------------------------------------------------------
12
@click.group()
def cli():
    """max-div: Flexible Solver for Maximum Diversity Problems with Fairness Constraints."""
    # Intentionally empty: this group is only the root that subcommands attach to.
16
+
17
+
18
+ # -------------------------------------------------------------------------
19
+ # Benchmarking Commands
20
+ # -------------------------------------------------------------------------
21
@cli.group()
@click.option(
    "--turbo",
    is_flag=True,
    default=False,
    help="Run shorter, less accurate benchmark; identical to --speed=1.0; intended for testing purposes.",
)
@click.option(
    "--speed",
    default=0.0,
    help="Values closer to 1.0 result in shorter, less accurate benchmark; Overridden by --turbo when provided.",
)
@click.option(
    "--markdown",
    is_flag=True,
    default=False,
    help="Output benchmark results in Markdown table format.",
)
@click.pass_context
def benchmark(ctx, turbo: bool, speed: float, markdown: bool):
    """Benchmarking commands."""
    # The group-level flags are parked on the click context so that every
    # subcommand can read them; --turbo takes precedence over --speed.
    ctx.ensure_object(dict)
    ctx.obj.update(
        speed=1.0 if turbo else speed,
        markdown=markdown,
    )
49
+
50
+
51
@benchmark.command(name="randint")
@click.pass_context
def randint(ctx):
    """Benchmarks the `randint` function from `max_div.sampling.uncon`."""
    # Settings were stored on the context by the parent `benchmark` group.
    settings = ctx.obj
    _benchmark_randint(speed=settings["speed"], markdown=settings["markdown"])
58
+
59
+
60
@benchmark.command(name="randint_constrained")
@click.pass_context
def randint_constrained(ctx):
    """Benchmarks the `randint_constrained` function from `max_div.sampling.con`."""
    # Settings were stored on the context by the parent `benchmark` group.
    settings = ctx.obj
    _benchmark_randint_constrained(speed=settings["speed"], markdown=settings["markdown"])
67
+
68
+
69
+ # -------------------------------------------------------------------------
70
+ # Misc Commands
71
+ # -------------------------------------------------------------------------
72
@cli.command()
def numba_status():
    """Show Numba version, llvmlite version, and configuration including SVML status."""
    # Imported lazily so the rest of the CLI does not pay the import cost.
    import llvmlite
    import numba

    for label, module in (("Numba version", numba), ("llvmlite version", llvmlite)):
        click.echo(f"{label} : {module.__version__}")

    # Show key configuration settings
    from numba import config

    separator = "-" * 50
    settings = (
        ("SVML enabled", "USING_SVML"),
        ("Threading layer", "THREADING_LAYER"),
        ("Number of threads", "NUMBA_NUM_THREADS"),
        ("Optimization level", "OPT"),
        ("Debug mode", "DEBUG"),
        ("Disable JIT", "DISABLE_JIT"),
    )

    click.echo("\nNumba Configuration:")
    click.echo(separator)
    for label, attr_name in settings:
        # Each attribute is read right before it is printed, one line at a time.
        click.echo(f"{label} : {getattr(config, attr_name)}")
    click.echo(separator)
93
+
94
+
95
+ # -------------------------------------------------------------------------
96
+ # Entrypoint
97
+ # -------------------------------------------------------------------------
98
# Invoke the root click group only when this module is executed as a script.
if __name__ == "__main__":
    cli()
@@ -1 +1,2 @@
1
- from .sample_int import benchmark_sample_int
1
+ from .randint import benchmark_randint
2
+ from .randint_constrained import benchmark_randint_constrained
@@ -0,0 +1,218 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Literal
4
+
5
+ from max_div.internal.benchmarking import BenchmarkResult
6
+ from max_div.internal.formatting import md_bold, md_colored, md_table
7
+
8
+
9
+ # =================================================================================================
10
+ # Helper classes / types
11
+ # =================================================================================================
12
@dataclass
class Percentage:
    """A fraction that is displayed as a percentage with a fixed number of decimals."""

    frac: float  # fraction between 0.0 and 1.0
    decimals: int = 1  # number of decimals to display

    def __str__(self) -> str:
        """Return `frac` scaled by 100, formatted with `decimals` decimals and a trailing '%'."""
        scaled = self.frac * 100
        return format(scaled, f".{self.decimals}f") + "%"
19
+
20
+
21
# A single table cell: plain text, a benchmark result, or a percentage.
CellContent = str | BenchmarkResult | Percentage
22
+
23
+
24
+ # =================================================================================================
25
+ # Aggregation
26
+ # =================================================================================================
27
def extend_table_with_aggregate_row(
    data: list[list[CellContent]],
    agg: Literal["mean", "geomean", "sum"],
    include_benchmark_result: bool = True,
    include_percentage: bool = True,
) -> list[list[CellContent]]:
    """
    Return a new table consisting of `data` plus one extra row with per-column aggregates.

    A column is "aggregatable" if at least one of its cells is a BenchmarkResult or a Percentage.

    - BenchmarkResult cells of a column are combined with `BenchmarkResult.aggregate(..., method=agg)`.
    - Percentage cells of a column are always combined with the arithmetic mean of their fractions
      (independent of `agg`); the result is shown with one decimal more than the maximum number of
      decimals observed in that column.
    - All other cells of the aggregate row are empty strings.

    The column directly before the first aggregatable column (if there is one) receives the label
    for the aggregate row: `agg`, capitalized, followed by a colon (e.g. "Geomean:").

    The number of columns is taken from the first row of `data`.

    :param data: table rows; each row is a list of cells.
    :param agg: aggregation method passed to `BenchmarkResult.aggregate` and used for the label.
    :param include_benchmark_result: if False, BenchmarkResult columns are not aggregated.
    :param include_percentage: if False, Percentage columns are not aggregated.
    :return: new list with the rows of `data` followed by the aggregate row; `data` itself is not
             modified.  For empty `data` a new empty list is returned, since there are no columns
             to build an aggregate row for.
    """
    # Without rows there are no columns to aggregate (and `data[0]` would raise IndexError).
    if not data:
        return list(data)

    n_cols = len(data[0])
    aggregatable_types = (BenchmarkResult, Percentage)

    # Identify which columns contain aggregatable objects
    has_aggregatable = [False] * n_cols
    for row in data:
        for col_idx, cell in enumerate(row):
            if isinstance(cell, aggregatable_types):
                has_aggregatable[col_idx] = True

    # First aggregatable column, or None if the table holds no aggregatable cells at all
    first_aggregatable_col = next(
        (col_idx for col_idx, flag in enumerate(has_aggregatable) if flag),
        None,
    )

    # Create the aggregate row
    agg_row: list[CellContent] = [""] * n_cols

    # Every column left of the first aggregatable one is non-aggregatable by construction, so the
    # label simply goes in the column directly before it.  No label is written when there is no
    # aggregatable column (previously a TypeError) or when it is the very first column.
    if first_aggregatable_col is not None and first_aggregatable_col > 0:
        agg_row[first_aggregatable_col - 1] = agg.capitalize() + ":"

    # Compute aggregates column by column
    for col_idx in range(n_cols):
        if include_benchmark_result:
            results = [row[col_idx] for row in data if isinstance(row[col_idx], BenchmarkResult)]
            if results:  # Only compute if we have values
                agg_row[col_idx] = BenchmarkResult.aggregate(results, method=agg)

        if include_percentage:
            percentages = [row[col_idx] for row in data if isinstance(row[col_idx], Percentage)]
            if percentages:  # Only compute if we have values
                # Average fraction, shown with one more decimal than the most precise input
                avg_frac = sum(p.frac for p in percentages) / len(percentages)
                max_decimals = max(p.decimals for p in percentages)
                agg_row[col_idx] = Percentage(frac=avg_frac, decimals=max_decimals + 1)

    # Return data with the aggregate row appended
    return data + [agg_row]
97
+
98
+
99
+ # =================================================================================================
100
+ # Markdown highlighters
101
+ # =================================================================================================
102
class HighLighter(ABC):
    """Abstract base for objects that transform a table row before it is rendered."""

    @abstractmethod
    def process_row(self, row: list[CellContent]) -> list[CellContent]:
        """Return the (possibly modified) row that should be rendered instead of `row`."""
        raise NotImplementedError
106
+
107
+
108
class FastestBenchmark(HighLighter):
    """Highlights the fastest BenchmarkResult cells of a row using Markdown bold and/or color."""

    def __init__(self, bold: bool = True, color: str = "#00aa00"):
        self.bold = bold
        self.color = color

    def process_row(self, row: list[CellContent]) -> list[CellContent]:
        """
        Replace highlighted BenchmarkResult cells by their formatted string representation.

        A BenchmarkResult is highlighted when its q25 time does not exceed the smallest median
        (q50) time found in the row.  Other BenchmarkResult cells are replaced by their plain
        `str()`; non-BenchmarkResult cells are passed through untouched.  A row without any
        BenchmarkResult is returned as-is.
        """
        medians = [cell.t_sec_q_50 for cell in row if isinstance(cell, BenchmarkResult)]
        if not medians:
            return row
        fastest_median = min(medians)

        def render(cell: CellContent) -> CellContent:
            if not isinstance(cell, BenchmarkResult):
                return cell
            text = str(cell)
            if cell.t_sec_q_25 <= fastest_median:
                if self.bold:
                    text = md_bold(text)
                text = md_colored(text, self.color)
            return text

        return [render(cell) for cell in row]
133
+
134
+
135
class HighestPercentage(HighLighter):
    """Highlights the highest Percentage cells of a row using Markdown bold and/or color."""

    def __init__(self, bold: bool = True, color: str = "#00aa00"):
        self.bold = bold
        self.color = color

    def process_row(self, row: list[CellContent]) -> list[CellContent]:
        """
        Replace Percentage cells by strings, highlighting those that display the maximum value.

        The comparison is done on the string representation, so every Percentage that renders
        identically to the one with the largest `frac` is highlighted.  Non-Percentage cells are
        passed through untouched; a row without any Percentage is returned as-is.
        """
        percentages = [cell for cell in row if isinstance(cell, Percentage)]
        if not percentages:
            return row
        top_text = str(max(percentages, key=lambda pct: pct.frac))

        def render(cell: CellContent) -> CellContent:
            if not isinstance(cell, Percentage):
                return cell
            text = str(cell)
            if text == top_text:  # highlight everything that *displays* as the maximum
                if self.bold:
                    text = md_bold(text)
                text = md_colored(text, self.color)
            return text

        return [render(cell) for cell in row]
160
+
161
+
162
class BoldLabels(HighLighter):
    """Renders label cells (strings ending in ':') in Markdown bold."""

    def process_row(self, row: list[CellContent]) -> list[CellContent]:
        """Return a new row in which every string cell ending with ':' is wrapped by `md_bold`."""
        return [
            md_bold(cell) if isinstance(cell, str) and cell.endswith(":") else cell
            for cell in row
        ]
171
+
172
+
173
+ # =================================================================================================
174
+ # Formatting
175
+ # =================================================================================================
176
def format_as_markdown(
    headers: list[str], data: list[list[CellContent]], highlighters: list[HighLighter] | None = None
) -> list[str]:
    """
    Format benchmark data as a Markdown table.

    Each data row is first passed through all highlighters (in the order given), after which
    every cell is converted with `str()`.  The headers form the first row of the table that is
    handed to `md_table`.

    :param headers: List of column headers
    :param data: 2D list where each row contains strings, BenchmarkResult and/or Percentage objects
    :param highlighters: Optional list of HighLighter objects to apply to each row
    :return: List of strings representing the Markdown table lines
    """
    active_highlighters = highlighters or []

    table: list[list[str]] = [headers]
    for row in data:
        processed = row
        for highlighter in active_highlighters:
            processed = highlighter.process_row(processed)
        table.append([str(cell) for cell in processed])

    return md_table(table)
205
+
206
+
207
def format_for_console(headers: list[str], data: list[list[CellContent]]) -> list[str]:
    """
    Similar to `format_as_markdown`, but without highlighting, to keep it readable without rendering.

    BenchmarkResult cells are shown via their `t_sec_with_uncertainty_str`; every other cell is
    converted with `str()`.

    :param headers: List of column headers
    :param data: 2D list of table cells
    :return: List of strings representing the table lines, as produced by `md_table`
    """

    def as_text(cell: CellContent) -> str:
        if isinstance(cell, BenchmarkResult):
            return cell.t_sec_with_uncertainty_str
        return str(cell)

    body = [[as_text(cell) for cell in row] for row in data]
    return md_table([headers] + body)
@@ -0,0 +1,104 @@
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+
4
+ from max_div.internal.benchmarking import BenchmarkResult, benchmark
5
+ from max_div.sampling.uncon import randint_numba, randint_numpy
6
+
7
+ from ._formatting import (
8
+ BoldLabels,
9
+ CellContent,
10
+ FastestBenchmark,
11
+ extend_table_with_aggregate_row,
12
+ format_as_markdown,
13
+ format_for_console,
14
+ )
15
+
16
+
17
def benchmark_randint(speed: float = 0.0, markdown: bool = False) -> None:
    """
    Benchmarks the `randint` function from `max_div.sampling.uncon`.

    Different scenarios are tested:

     * with & without replacement
     * uniform & non-uniform sampling
     * the numpy (`randint_numpy`) and numba (`randint_numba`) implementations
     * different sizes of (`n`, `k`):
        * `n` is varied across [10, 100, 1000, 10000]
        * `k` is varied across [1, 10, 100, 1000, 10000]
        * all valid combinations are tested (if `replace==False` we don't test `k`>`n`)

    For each (`n`, `k`) combination both implementations are timed with the same probability
    vector `p` (a random, normalized float32 vector for custom probabilities; an empty float32
    array for uniform sampling).  Each scenario is printed as a table that ends with a geomean row.

    :param speed: value in [0.0, 1.0] (default=0.0); 0.0=accurate but slow; 1.0=fast but less accurate
    :param markdown: If `True`, outputs the results as a Markdown table.
    """

    print("Benchmarking `randint`...")
    print()

    # --- create headers ------------------------------
    # Identical for every scenario; Markdown output wraps each header in backticks.
    headers = ["k", "n", "randint_numpy", "randint_numba"]
    if markdown:
        headers = [f"`{header}`" for header in headers]

    scenarios = [
        (True, False, "A. WITH replacement, UNIFORM probabilities"),
        (False, False, "B. WITHOUT replacement, UNIFORM probabilities"),
        (True, True, "C. WITH replacement, CUSTOM probabilities"),
        (False, True, "D. WITHOUT replacement, CUSTOM probabilities"),
    ]

    for replace, use_p, desc in scenarios:
        print(f"## {desc}" if markdown else f"{desc}:")

        # --- benchmark ------------------------------------
        data: list[list[CellContent]] = []
        n_k_values = [(n, k) for n in [10, 100, 1000, 10000] for k in [1, 10, 100, 1000, 10000] if replace or (k <= n)]
        for n, k in tqdm(n_k_values, leave=False):
            # Build the probability vector once per (n, k), so that both implementations are
            # benchmarked on exactly the same input.
            if use_p:
                p = np.random.rand(n)
                p /= p.sum()
            else:
                p = np.zeros(0)
            p = p.astype(np.float32)

            data_row: list[CellContent] = [str(k), str(n)]
            for sampler in (randint_numpy, randint_numba):

                # The closure is executed by `benchmark` within this iteration, so it sees the
                # current values of `sampler`, `n`, `k`, `replace` and `p`.
                def func_to_benchmark():
                    sampler(n=n, k=k, replace=replace, p=p)

                data_row.append(
                    benchmark(
                        f=func_to_benchmark,
                        # higher speed -> shorter runs, fewer warmup & benchmark repetitions
                        t_per_run=0.05 / (1000.0**speed),
                        n_warmup=int(8 - 5 * speed),
                        n_benchmark=int(25 - 22 * speed),
                        silent=True,
                    )
                )

            data.append(data_row)

        # --- show results -----------------------------------------
        data = extend_table_with_aggregate_row(data, agg="geomean")
        if markdown:
            display_data = format_as_markdown(headers, data, highlighters=[FastestBenchmark(), BoldLabels()])
        else:
            display_data = format_for_console(headers, data)

        print()
        for line in display_data:
            print(line)
        print()