google-ngrams 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_ngrams/__init__.py +19 -0
- google_ngrams/data/__init__.py +14 -0
- google_ngrams/data/googlebooks_eng_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/data/googlebooks_eng_gb_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/data/googlebooks_eng_us_all_totalcounts_20120701.parquet +0 -0
- google_ngrams/ngrams.py +341 -0
- google_ngrams/scatter_helpers.py +187 -0
- google_ngrams/vnc.py +518 -0
- google_ngrams/vnc_helpers.py +809 -0
- google_ngrams-0.2.0.dist-info/METADATA +144 -0
- google_ngrams-0.2.0.dist-info/RECORD +14 -0
- google_ngrams-0.2.0.dist-info/WHEEL +5 -0
- google_ngrams-0.2.0.dist-info/licenses/LICENSE +162 -0
- google_ngrams-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,19 @@
|
|
1
|
+
# flake8: noqa
|
2
|
+
|
3
|
+
# Set version ----

from importlib.metadata import version as _v, PackageNotFoundError as _PNF

try:
    __version__ = _v("google_ngrams")
except _PNF:  # Fallback when running from source without installed metadata
    __version__ = "0.0.0"

# Drop both private aliases so neither leaks into the package namespace.
del _v, _PNF
|
12
|
+
|
13
|
+
# Imports ----

from .ngrams import google_ngram
from .vnc import TimeSeries

# Public names exported by ``from google_ngrams import *``.
__all__ = [
    "google_ngram",
    "TimeSeries",
]
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# flake8: noqa
|
2
|
+
|
3
|
+
from importlib.resources import files as _files
|
4
|
+
|
5
|
+
# Root of the bundled data directory; each path segment is joined separately
# so every importlib.resources Traversable implementation resolves it.
_DATA_DIR = _files("google_ngrams") / "data"

# Bundled yearly total-count tables (parquet), keyed by corpus variety.
sources = {
    "eng_all": _DATA_DIR / "googlebooks_eng_all_totalcounts_20120701.parquet",
    "gb_all": _DATA_DIR / "googlebooks_eng_gb_all_totalcounts_20120701.parquet",
    "us_all": _DATA_DIR / "googlebooks_eng_us_all_totalcounts_20120701.parquet",
}


def __getattr__(name):
    """Expose each ``sources`` key as a module attribute (PEP 562).

    ``__dir__`` advertises these keys, so they must also be resolvable via
    attribute access; anything else raises ``AttributeError`` as usual.
    """
    try:
        return sources[name]
    except KeyError:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None


def __dir__():
    """List the dataset keys plus the ``sources`` mapping itself."""
    return sorted([*sources, "sources"])
|
14
|
+
|
Binary file
|
Binary file
|
Binary file
|
google_ngrams/ngrams.py
ADDED
@@ -0,0 +1,341 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
import polars as pl
|
4
|
+
import warnings
|
5
|
+
import logging
|
6
|
+
from textwrap import dedent
|
7
|
+
from typing import List
|
8
|
+
from .data import sources
|
9
|
+
|
10
|
+
|
11
|
+
_logger = logging.getLogger(__name__)

_VARIETY_TYPES = ["eng", "gb", "us"]
_BY_TYPES = ["year", "decade"]
_MAX_TOKENS = 5

_SPELLING_ERROR = """Check spelling.
Word forms should be lemmas of the same word
(e.g. 'teenager' and 'teenagers'
or 'walk', 'walks' and 'walked'
"""

# Leading characters whose n-grams are looked up in the "other" file.
# U+0133 (the "ij" ligature) is written as an escape so the class can never
# degrade into the ASCII letters "i" and "j", which would misroute every
# ordinary word starting with those letters.
_OTHER_LEAD = re.compile("^[ßæðøłœıƒþȥəħŋªºɣđ\u0133ɔȝⅰʊʌʔɛȡɋⅱʃɇɑⅲ]")

# Regex metacharacters escaped before building the polars filter pattern.
# The backslash is included so user input cannot produce a broken pattern.
_REGEX_META = re.compile(r'([\\.?$^(){}\[\]*+|])')


def _prepare_word_forms(word_forms):
    """Normalize word forms and return ``(forms, n_tokens)``.

    Hyphens between alphanumerics are padded with spaces ("x-ray" becomes
    "x - ray"), then surrounding whitespace is stripped. A bare string is
    treated as a single word form.

    Raises
    ------
    ValueError
        If no word forms are given, if they are blank, if they differ in
        token count, or if they exceed five tokens.
    """
    if isinstance(word_forms, str):
        # Iterating a bare string would split it into characters.
        word_forms = [word_forms]
    forms = [
        re.sub(r'([a-zA-Z0-9])-([a-zA-Z0-9])', r'\1 - \2', wf).strip()
        for wf in word_forms
    ]
    if not forms:
        raise ValueError("word_forms must contain at least one word form.")
    token_counts = {len(re.findall(r'\S+', wf)) for wf in forms}
    if len(token_counts) > 1:
        raise ValueError(_SPELLING_ERROR)
    n_tokens = token_counts.pop()
    if n_tokens == 0:
        raise ValueError("word_forms must not be empty or whitespace-only.")
    if n_tokens > _MAX_TOKENS:
        raise ValueError("""Ngrams can be a maximum of 5 tokens.
Hyphenated words are split and include the hyphen,
so 'x-ray' would count as 3 tokens.
""")
    return forms, n_tokens


def _ngram_file_code(forms, n_tokens):
    """Return the file-name suffix used in the ngram download URL.

    The suffix is the lower-cased first character (1-grams) or first two
    characters (longer n-grams), remapped to "_" padding, a single digit,
    "punctuation" or "other" as appropriate.

    Raises
    ------
    ValueError
        If the word forms do not share the same leading characters.
    """
    prefix_len = 2 if n_tokens > 1 else 1
    prefixes = list({wf[:prefix_len].lower() for wf in forms})
    if len(prefixes) > 1:
        raise ValueError(_SPELLING_ERROR)
    code = prefixes[0]
    if re.match(r'^[a-z][^a-z]', code):
        code = re.sub(r'[^a-z]', '_', code)
    if re.match(r'^[0-9]', code):
        code = code[:1]
    if re.match(r'^\W', code):
        code = "punctuation"
    if _OTHER_LEAD.match(code):
        code = "other"
    return code.encode('latin-1', 'replace').decode('latin-1')


def _repo_url(variety, n_tokens, code):
    """Build the HTTPS URL of the gzipped 2012-07-01 ngram file."""
    corpus = "eng" if variety == "eng" else f"eng-{variety}"
    return (
        "https://storage.googleapis.com/books/ngrams/books/"
        f"googlebooks-{corpus}-all-{n_tokens}gram-20120701-{code}.gz"
    )


def _token_pattern(forms):
    """Build a case-insensitive, fully anchored alternation of ``forms``."""
    escaped = [_REGEX_META.sub(r'\\\1', wf) for wf in forms]
    return "(?i)" + "|".join(f"^{wf}$" for wf in escaped)


def _select_columns(frame):
    """Project the raw columns to Token/Year/AF (lazy or eager frames)."""
    return frame.select([
        pl.col("column_1").alias("Token"),
        pl.col("column_2").alias("Year"),
        pl.col("column_3").alias("AF"),
    ])


def _scan_ngrams(repo):
    """Lazily scan the tab-separated ngram file at ``repo``."""
    schema = {"column_1": pl.String,
              "column_2": pl.Int64,
              "column_3": pl.Int64,
              "column_4": pl.Int64}
    try:
        return pl.scan_csv(
            repo,
            separator='\t',
            has_header=False,
            schema=schema,
            truncate_ragged_lines=True,
            low_memory=True,
            quote_char=None,
            ignore_errors=True,
        )
    except TypeError:
        # Fallback for environments/tests that monkeypatch scan_csv with a
        # limited signature. Use minimal, widely-supported args.
        return pl.scan_csv(repo, separator='\t', has_header=False,
                           schema=schema)


def _configure_streaming_chunk_size():
    """Apply POLARS_STREAMING_CHUNK_SIZE when set (best effort only)."""
    chunk_sz = os.environ.get("POLARS_STREAMING_CHUNK_SIZE")
    if not chunk_sz:
        return
    try:
        pl.Config.set_streaming_chunk_size(int(chunk_sz))
    except Exception as exc:  # bad value or API absent in this polars
        _logger.debug(
            "Ignoring POLARS_STREAMING_CHUNK_SIZE=%r: %s", chunk_sz, exc
        )


def _read_batched(repo, tokens_exact):
    """Last-resort reader: stream the CSV in batches, filtering each one.

    Raises
    ------
    RuntimeError
        If the batched reader itself fails.
    """
    batch_sz = int(os.environ.get("POLARS_CSV_BATCH_SIZE", "200000"))
    # Case-insensitive equality so this path keeps the same rows as the
    # "(?i)^...$" regex used by the lazy path.
    wanted = sorted({token.lower() for token in tokens_exact})

    def _filter_batch(batch):
        return _select_columns(
            batch.filter(
                pl.col("column_1").cast(pl.String)
                .str.to_lowercase().is_in(wanted)
            )
        )

    try:
        reader = pl.read_csv_batched(
            repo,
            separator='\t',
            has_header=False,
            ignore_errors=True,
            low_memory=True,
            batch_size=batch_sz,
            quote_char=None,  # match the lazy scan: tokens may contain '"'
        )
        filtered_batches = []
        try:
            for batch in reader:  # type: ignore[assignment]
                fb = _filter_batch(batch)
                if fb.height:
                    filtered_batches.append(fb)
        except TypeError:
            # Fallback for reader objects that are not iterable.
            while True:
                try:
                    batches = reader.next_batches(1)
                except AttributeError:
                    break
                if not batches:
                    break
                fb = _filter_batch(batches[0])
                if fb.height:
                    filtered_batches.append(fb)

        if filtered_batches:
            return pl.concat(filtered_batches)
        return pl.DataFrame({
            "Token": pl.Series([], dtype=pl.String),
            "Year": pl.Series([], dtype=pl.Int64),
            "AF": pl.Series([], dtype=pl.Int64),
        })
    except Exception as e:
        # If batched reader is unavailable, re-raise with guidance
        raise RuntimeError(
            "Polars batched CSV reader fallback failed; consider "
            "upgrading Polars or disabling this code path via "
            "environment if necessary."
        ) from e


def _collect_ngrams(lazy_frame, repo, tokens_exact):
    """Collect ``lazy_frame``, trying progressively more conservative paths."""
    try:
        _logger.debug("Collecting with engine='streaming'.")
        return lazy_frame.collect(engine="streaming")
    except Exception as exc:
        _logger.debug("Streaming engine failed: %s", exc)
    try:
        # Older streaming path (deprecated in newer Polars)
        _logger.debug("Collecting with deprecated streaming=True path.")
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                category=DeprecationWarning,
                message=r"the `streaming` parameter was deprecated.*",
            )
            return lazy_frame.collect(  # type: ignore[arg-type]
                streaming=True
            )
    except Exception as exc:
        _logger.debug("Deprecated streaming path failed: %s", exc)
    try:
        _logger.debug("Collecting with in-memory engine (no streaming).")
        return lazy_frame.collect()
    except Exception as exc:
        _logger.debug("In-memory collect failed: %s", exc)
    _logger.debug("Falling back to batched CSV reader + per-batch filter.")
    return _read_batched(repo, tokens_exact)


def _yearly_totals(variety):
    """Load bundled per-year token totals for ``variety``.

    Returns a frame with a datetime ``Year`` (one row per year, missing
    years filled with zero) and the summed ``Total``.
    """
    total_counts = pl.read_parquet(sources[f"{variety}_all"]).cast({
        "Year": pl.UInt32,
        "Total": pl.UInt64,
        "Pages": pl.UInt64,
        "Volumes": pl.UInt64,
    })
    return (
        total_counts
        .with_columns(pl.col("Year").cast(pl.String).str.to_datetime("%Y"))
        .sort("Year")
        .upsample(time_column="Year", every="1y")
        .with_columns(
            pl.col(["Total", "Pages", "Volumes"]).fill_null(strategy="zero")
        )
        .group_by_dynamic("Year", every="1y")
        .agg(pl.col("Total").sum())
    )


def google_ngram(
    word_forms: List[str],
    variety="eng",
    by="decade"
) -> "pl.DataFrame":
    """
    Fetch Google Books Ngram frequencies for the given word forms.

    The matching 2012-07-01 ngram file is read from Google Storage, rows
    whose token equals one of ``word_forms`` (case-insensitively) are kept,
    absolute frequencies are summed per year and normalized against the
    bundled yearly token totals.

    Parameters
    ----------
    word_forms : list of str (a single str is also accepted)
        Word forms to search for. All must have the same number of tokens
        and share the same leading characters (e.g. 'walk', 'walks').
    variety : str
        Variety of English ('eng', 'gb', 'us').
    by : str
        Aggregation level ('year' or 'decade').

    Returns
    -------
    pl.DataFrame
        Columns ``Year`` (``Decade`` when ``by='decade'``), ``Token`` (the
        searched word forms), ``AF`` (absolute frequency) and ``RF``
        (frequency per million tokens).

    Raises
    ------
    ValueError
        For an invalid ``variety``/``by`` or inconsistent, empty or
        over-long word forms.
    RuntimeError
        If every collection strategy, including the batched reader, fails.
    """
    if variety not in _VARIETY_TYPES:
        raise ValueError(
            f"Invalid variety type. Expected one of: {_VARIETY_TYPES}"
        )
    if by not in _BY_TYPES:
        raise ValueError(f"Invalid by type. Expected one of: {_BY_TYPES}")

    tokens_exact, n_tokens = _prepare_word_forms(word_forms)
    repo = _repo_url(variety, n_tokens,
                     _ngram_file_code(tokens_exact, n_tokens))

    _logger.info(dedent(
        """
        Accessing repository. For larger ones
        (e.g., ngrams containing 2 or more words).
        This may take a few minutes...
        """
    ))

    # Push down filter and projection before collection to minimize memory.
    filtered_df = _select_columns(
        _scan_ngrams(repo).filter(
            pl.col("column_1").str.contains(_token_pattern(tokens_exact))
        )
    )
    _configure_streaming_chunk_size()
    all_grams = _collect_ngrams(filtered_df, repo, tokens_exact)

    # Sum hits per year, fill missing years, and align with corpus totals.
    sum_tokens = (
        all_grams
        .group_by("Year", maintain_order=True)
        .agg(pl.col("AF").sum())
        .with_columns(pl.col("Year").cast(pl.String).str.to_datetime("%Y"))
        .sort("Year")
        .upsample(time_column="Year", every="1y")
        .with_columns(pl.col("AF").fill_null(strategy="zero"))
        # Right join keeps every year present in the totals table.
        .join(_yearly_totals(variety), on="Year", how="right")
        # Years with no token hits come back null from the join.
        .with_columns(pl.col("AF").fill_null(strategy="zero"))
    )

    if by == "decade":
        sum_tokens = (
            sum_tokens
            .group_by_dynamic("Year", every="10y")
            .agg(pl.col(["AF", "Total"]).sum())
        )

    # Relative frequency per million tokens; 0/0 (NaN) becomes 0.
    sum_tokens = (
        sum_tokens
        .with_columns(RF=pl.col("AF").truediv("Total").mul(1000000))
        .with_columns(pl.col("RF").fill_nan(0))
    )
    # Label with the user's forms, not the regex-escaped pattern pieces.
    sum_tokens.insert_column(1, pl.lit(tokens_exact).alias("Token"))
    sum_tokens = (
        sum_tokens
        .with_columns(pl.col("Year").dt.year().alias("Year"))
        .drop("Total")
    )

    if by == "decade":
        # Avoid .rename to prevent potential segfaults
        sum_tokens = (
            sum_tokens
            .with_columns(pl.col("Year").alias("Decade"))
            .drop("Year")
        )

    return sum_tokens
|
@@ -0,0 +1,187 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
class SmoothResult:
    """Output of the spline smoother: a fitted curve and optional band.

    The band arrays are ``None`` when no confidence interval was requested.
    """

    x: np.ndarray  # x positions the fit was evaluated at
    y_fit: np.ndarray  # fitted value at each x
    y_lower: Optional[np.ndarray]  # lower band edge, or None
    y_upper: Optional[np.ndarray]  # upper band edge, or None
|
15
|
+
|
16
|
+
|
17
|
+
def _df_from_ui(n: int, smoothing: int) -> int:
|
18
|
+
"""Map UI smoothing (1..9) to an effective df like original code.
|
19
|
+
|
20
|
+
Original mapping: df = (10 - smoothing) * 10.
|
21
|
+
Clamp to a feasible range given the sample size and cubic degree.
|
22
|
+
"""
|
23
|
+
s_param = int(smoothing)
|
24
|
+
if s_param < 1:
|
25
|
+
s_param = 1
|
26
|
+
elif s_param > 9:
|
27
|
+
s_param = 9
|
28
|
+
df = (10 - s_param) * 10 # 1->90 (flexible), 9->10 (smooth)
|
29
|
+
# For cubic regression spline, basis dimension = 4 + t (t=interior knots).
|
30
|
+
max_t = max(0, n - 5)
|
31
|
+
max_df = max_t + 4
|
32
|
+
df = max(6, min(df, max_df))
|
33
|
+
return df
|
34
|
+
|
35
|
+
|
36
|
+
def _crs_design(x: np.ndarray, knots: np.ndarray) -> np.ndarray:
|
37
|
+
"""Cubic regression spline design matrix using truncated power basis.
|
38
|
+
|
39
|
+
Columns: [1, x, x^2, x^3, (x - t1)_+^3, ..., (x - tm)_+^3]
|
40
|
+
Where (a)_+ = max(a, 0).
|
41
|
+
"""
|
42
|
+
x = np.asarray(x, dtype=float)
|
43
|
+
X = [
|
44
|
+
np.ones_like(x),
|
45
|
+
x,
|
46
|
+
x * x,
|
47
|
+
x * x * x,
|
48
|
+
]
|
49
|
+
for t in knots:
|
50
|
+
z = x - float(t)
|
51
|
+
z[z < 0.0] = 0.0
|
52
|
+
X.append(z * z * z)
|
53
|
+
return np.column_stack(X)
|
54
|
+
|
55
|
+
|
56
|
+
def _fit_ridge(X: np.ndarray, y: np.ndarray, lam: float) -> np.ndarray:
|
57
|
+
"""Solve ridge regression (X^T X + lam I) beta = X^T y."""
|
58
|
+
XtX = X.T @ X
|
59
|
+
n_feat = XtX.shape[0]
|
60
|
+
A = XtX + lam * np.eye(n_feat)
|
61
|
+
Xty = X.T @ y
|
62
|
+
try:
|
63
|
+
beta = np.linalg.solve(A, Xty)
|
64
|
+
except np.linalg.LinAlgError:
|
65
|
+
beta = np.linalg.lstsq(A, Xty, rcond=None)[0]
|
66
|
+
return beta
|
67
|
+
|
68
|
+
|
69
|
+
def gam_smoother(
    x: np.ndarray,
    y: np.ndarray,
    *,
    smoothing: int = 7,
    ci: bool = True,
    ci_level: float = 0.95,
    n_boot: int = 200,
    random_state: Optional[int] = None,
) -> SmoothResult:
    """Cubic regression spline (NumPy only) with optional bootstrap CIs.

    Uses a truncated power basis with interior knots at quantiles of x and
    ridge regularization mapped from the UI smoothing parameter. Returns
    predictions in the original x order, clipped to be non-negative.

    Parameters
    ----------
    x, y : array-like
        1D inputs of equal length (converted to float).
    smoothing : int
        UI smoothing level, clamped to 1..9; higher is smoother.
    ci : bool
        Whether to compute a residual-bootstrap confidence band.
    ci_level : float
        Nominal coverage of the band (e.g. 0.95).
    n_boot : int
        Number of bootstrap refits; must be >= 1 when ``ci`` is True.
    random_state : int, optional
        Seed passed to ``numpy.random.default_rng``.

    Returns
    -------
    SmoothResult
        ``y_lower``/``y_upper`` are None when ``ci`` is False.

    Raises
    ------
    ValueError
        If x and y are not 1D arrays of the same length, or if ``ci`` is
        requested with ``n_boot < 1``.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or y.ndim != 1 or len(x) != len(y):
        raise ValueError("x and y must be 1D arrays of the same length")
    if ci and n_boot < 1:
        raise ValueError("n_boot must be at least 1 when ci=True")

    n = len(x)
    if n == 0:
        # Nothing to fit; min()/max() below would raise on empty arrays.
        return SmoothResult(
            x=x,
            y_fit=np.empty(0, dtype=float),
            y_lower=np.empty(0, dtype=float) if ci else None,
            y_upper=np.empty(0, dtype=float) if ci else None,
        )

    order = np.argsort(x)
    x_sorted = x[order]
    y_sorted = y[order]

    # Handle duplicate x by adding tiny jitter
    if np.any(np.diff(x_sorted) == 0):
        eps = 1e-9 * max(1.0, (x_sorted.max() - x_sorted.min()))
        counts = {}
        x_use = x_sorted.copy()
        for i, val in enumerate(x_sorted):
            c = counts.get(val, 0)
            if c > 0:
                x_use[i] = val + c * eps
            counts[val] = c + 1
    else:
        x_use = x_sorted

    # Normalize x to [0,1] for numerical stability
    xmin = float(x_use.min())
    span = float(x_use.max()) - xmin
    if span <= 0:
        span = 1.0
    x0 = (x_use - xmin) / span

    # Determine number of interior knots from df mapping
    df = _df_from_ui(n, smoothing)
    k = 3  # cubic
    t_count = max(0, min(df - (k + 1), n - (k + 2)))
    if t_count > 0:
        qs = np.linspace(0, 1, t_count + 2)[1:-1]
        knots = np.quantile(x0, qs)
    else:
        knots = np.array([], dtype=float)

    # Standardize every column except the intercept (col 0) so the ridge
    # penalty acts on comparable scales; constant columns keep scale 1.
    X = _crs_design(x0, knots)
    Xs = X.copy()
    col_mean = X[:, 1:].mean(axis=0)
    col_std = X[:, 1:].std(axis=0)
    col_std[col_std <= 0.0] = 1.0
    Xs[:, 1:] = (X[:, 1:] - col_mean) / col_std

    # Map smoothing (1..9) -> ridge lambda on a small scale
    # smoothing=1 -> ~1e-9 (very flexible), smoothing=9 -> ~1e-4 (smoother)
    s_param = int(np.clip(smoothing, 1, 9))
    exp_min, exp_max = -9.0, -4.0
    exponent = exp_min + (s_param - 1) * (exp_max - exp_min) / 8.0
    lam = 10.0 ** exponent

    beta = _fit_ridge(Xs, y_sorted, lam)
    y_fit_sorted = Xs @ beta

    inv = np.argsort(order)
    y_fit = np.asarray(y_fit_sorted[inv], dtype=float)
    y_fit[y_fit < 0] = 0.0

    y_lower = None
    y_upper = None
    if ci:
        rng = np.random.default_rng(random_state)
        resid_sorted = y_sorted - y_fit_sorted
        boot_preds = np.empty((n_boot, n), dtype=float)
        for b in range(n_boot):
            resampled = rng.choice(resid_sorted, size=n, replace=True)
            # Refit with the same design and penalty
            beta_b = _fit_ridge(Xs, y_fit_sorted + resampled, lam)
            boot_preds[b, :] = Xs @ beta_b
        # Symmetric, fit-centered half-width: the ci_level quantile of the
        # absolute deviations covers ci_level of the bootstrap fits. (The
        # 1 - alpha/2 quantile of |deviation| would cover 1 - alpha/2,
        # i.e. a band wider than requested.)
        hw_sorted = np.percentile(
            np.abs(boot_preds - y_fit_sorted), 100.0 * float(ci_level), axis=0
        )
        # Map back to original x order
        y_lower = (y_fit_sorted - hw_sorted)[inv]
        y_upper = (y_fit_sorted + hw_sorted)[inv]
        # Clip to non-negative domain
        y_lower[y_lower < 0] = 0.0
        y_upper[y_upper < 0] = 0.0

    return SmoothResult(x=x, y_fit=y_fit, y_lower=y_lower, y_upper=y_upper)
|