gbert 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. gbert-0.1.0/MANIFEST.in +3 -0
  2. gbert-0.1.0/PKG-INFO +162 -0
  3. gbert-0.1.0/README.md +276 -0
  4. gbert-0.1.0/README_PYPI.md +133 -0
  5. gbert-0.1.0/gbert/__init__.py +11 -0
  6. gbert-0.1.0/gbert/__main__.py +5 -0
  7. gbert-0.1.0/gbert/add/battle/API_VC.BTL.DETH_DS2_en_csv_v2_9816.csv +271 -0
  8. gbert-0.1.0/gbert/add/battle/Metadata_Country_API_VC.BTL.DETH_DS2_en_csv_v2_9816.csv +277 -0
  9. gbert-0.1.0/gbert/add/battle/Metadata_Indicator_API_VC.BTL.DETH_DS2_en_csv_v2_9816.csv +2 -0
  10. gbert-0.1.0/gbert/add/debt/API_GC.DOD.TOTL.GD.ZS_DS2_en_csv_v2_2484.csv +271 -0
  11. gbert-0.1.0/gbert/add/debt/Metadata_Country_API_GC.DOD.TOTL.GD.ZS_DS2_en_csv_v2_2484.csv +277 -0
  12. gbert-0.1.0/gbert/add/debt/Metadata_Indicator_API_GC.DOD.TOTL.GD.ZS_DS2_en_csv_v2_2484.csv +2 -0
  13. gbert-0.1.0/gbert/add/gdp/API_NY.GDP.MKTP.KD.ZG_DS2_en_csv_v2_2509.csv +271 -0
  14. gbert-0.1.0/gbert/add/gdp/Metadata_Country_API_NY.GDP.MKTP.KD.ZG_DS2_en_csv_v2_2509.csv +277 -0
  15. gbert-0.1.0/gbert/add/gdp/Metadata_Indicator_API_NY.GDP.MKTP.KD.ZG_DS2_en_csv_v2_2509.csv +4 -0
  16. gbert-0.1.0/gbert/add/inflation/API_FP.CPI.TOTL.ZG_DS2_en_csv_v2_2479.csv +271 -0
  17. gbert-0.1.0/gbert/add/inflation/Metadata_Country_API_FP.CPI.TOTL.ZG_DS2_en_csv_v2_2479.csv +277 -0
  18. gbert-0.1.0/gbert/add/inflation/Metadata_Indicator_API_FP.CPI.TOTL.ZG_DS2_en_csv_v2_2479.csv +2 -0
  19. gbert-0.1.0/gbert/add/trade/API_NE.TRD.GNFS.ZS_DS2_en_csv_v2_2650.csv +271 -0
  20. gbert-0.1.0/gbert/add/trade/Metadata_Country_API_NE.TRD.GNFS.ZS_DS2_en_csv_v2_2650.csv +277 -0
  21. gbert-0.1.0/gbert/add/trade/Metadata_Indicator_API_NE.TRD.GNFS.ZS_DS2_en_csv_v2_2650.csv +4 -0
  22. gbert-0.1.0/gbert/add/unemployment/API_SL.UEM.TOTL.ZS_DS2_en_csv_v2_2821.csv +271 -0
  23. gbert-0.1.0/gbert/add/unemployment/Metadata_Country_API_SL.UEM.TOTL.ZS_DS2_en_csv_v2_2821.csv +277 -0
  24. gbert-0.1.0/gbert/add/unemployment/Metadata_Indicator_API_SL.UEM.TOTL.ZS_DS2_en_csv_v2_2821.csv +2 -0
  25. gbert-0.1.0/gbert/cli.py +42 -0
  26. gbert-0.1.0/gbert/configuration_causal_interpretable.py +55 -0
  27. gbert-0.1.0/gbert/modeling_causal_interpretable.py +281 -0
  28. gbert-0.1.0/gbert/preprocess_meta.joblib +0 -0
  29. gbert-0.1.0/gbert/service.py +549 -0
  30. gbert-0.1.0/gbert.egg-info/PKG-INFO +162 -0
  31. gbert-0.1.0/gbert.egg-info/SOURCES.txt +35 -0
  32. gbert-0.1.0/gbert.egg-info/dependency_links.txt +1 -0
  33. gbert-0.1.0/gbert.egg-info/entry_points.txt +2 -0
  34. gbert-0.1.0/gbert.egg-info/requires.txt +11 -0
  35. gbert-0.1.0/gbert.egg-info/top_level.txt +1 -0
  36. gbert-0.1.0/pyproject.toml +56 -0
  37. gbert-0.1.0/setup.cfg +4 -0
@@ -0,0 +1,3 @@
1
+ include README_PYPI.md
2
+ recursive-include gbert/add *.csv
3
+ include gbert/preprocess_meta.joblib
gbert-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,162 @@
1
+ Metadata-Version: 2.1
2
+ Name: gbert
3
+ Version: 0.1.0
4
+ Summary: Simple multilingual policy text analysis with GBERT, optimized for notebooks and Kaggle.
5
+ Project-URL: Homepage, https://pypi.org/project/gbert/
6
+ Keywords: nlp,bert,text-classification,kaggle,policy-analysis
7
+ Classifier: Development Status :: 3 - Alpha
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3 :: Only
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
+ Classifier: Topic :: Text Processing :: Linguistic
17
+ Requires-Python: >=3.10
18
+ Description-Content-Type: text/markdown
19
+ Requires-Dist: huggingface_hub<2,>=0.30
20
+ Requires-Dist: joblib<2,>=1.4
21
+ Requires-Dist: numpy<3,>=1.26
22
+ Requires-Dist: pandas<3,>=2.2
23
+ Requires-Dist: scikit-learn<2,>=1.5
24
+ Requires-Dist: torch<2.6,>=2.5
25
+ Requires-Dist: transformers<5,>=4.46
26
+ Provides-Extra: dev
27
+ Requires-Dist: build>=1.2; extra == "dev"
28
+ Requires-Dist: twine<7,>=5; extra == "dev"
29
+
30
+ # gbert
31
+
32
+ `gbert` 是一个尽量简单的多语言文本分析包,适合:
33
+
34
+ - Kaggle Notebook
35
+ - 单条文本快速预测
36
+ - 大批量文本批处理
37
+
38
+ 它默认不包含网页、Flask 或服务端代码,只保留推理必需内容。
39
+
40
+ ## 安装
41
+
42
+ ```bash
43
+ pip install gbert
44
+ ```
45
+
46
+ 如果你在 Kaggle 里使用,推荐先安装再直接在 Notebook 里跑:
47
+
48
+ ```python
49
+ !pip -q install gbert
50
+ ```
51
+
52
+ ## 最简单用法
53
+
54
+ ```python
55
+ from gbert import GbertClassifier
56
+
57
+ model = GbertClassifier(
58
+ model_repo_id="your-hf-model-repo",
59
+ hf_token="your_hf_token_if_needed",
60
+ )
61
+
62
+ result = model.predict(
63
+ "The government will expand industrial policy and labour training.",
64
+ country="Japan",
65
+ year=2026,
66
+ )
67
+
68
+ result["predictions"][:3]
69
+ ```
70
+
71
+ ## 批量分析
72
+
73
+ ```python
74
+ from gbert import GbertClassifier
75
+
76
+ model = GbertClassifier(model_repo_id="your-hf-model-repo")
77
+
78
+ results = model.predict_batch(
79
+ texts=[
80
+ "We will invest in industry.",
81
+ "Healthcare access must improve.",
82
+ "Tax reform should support growth.",
83
+ ],
84
+ country="Japan",
85
+ year=2026,
86
+ batch_size=16,
87
+ )
88
+ ```
89
+
90
+ 如果你更喜欢 `pandas.DataFrame`:
91
+
92
+ ```python
93
+ df = model.predict_batch(
94
+ texts=["text a", "text b"],
95
+ country=["Japan", "Germany"],
96
+ year=[2026, 2024],
97
+ return_df=True,
98
+ )
99
+ ```
100
+
101
+ ## API
102
+
103
+ ### `GbertClassifier`
104
+
105
+ 常用参数:
106
+
107
+ - `model_path`: 本地模型权重路径
108
+ - `model_repo_id`: Hugging Face Hub 模型仓库
109
+ - `model_filename`: 默认是 `causal_nam_best.pt`
110
+ - `hf_token`: 私有仓库时可传
111
+ - `device`: `cpu` / `cuda` / 自动检测
112
+ - `batch_size`: 批量推理默认 batch 大小
113
+
114
+ ### 方法
115
+
116
+ - `predict(text, country, year, top_k=5)`
117
+ - `predict_batch(texts, country, year, top_k=5, batch_size=None, return_df=False)`
118
+ - `list_countries()`
119
+ - `list_years(country)`
120
+ - `runtime_info()`
121
+
122
+ ## 设计目标
123
+
124
+ - Notebook first
125
+ - Kaggle friendly
126
+ - 单次加载,多次复用
127
+ - 单条和批量接口统一
128
+
129
+ ## 模型说明
130
+
131
+ 这个包会打包:
132
+
133
+ - `preprocess_meta.joblib`
134
+ - `add/` 下的宏观变量 CSV
135
+
136
+ 它不会打包大模型权重文件。运行时会按以下顺序寻找模型:
137
+
138
+ 1. 你显式传入的 `model_path`
139
+ 2. 环境变量 `GBERT_MODEL_PATH`
140
+ 3. 包目录下的 `causal_nam_best.pt`
141
+ 4. Hugging Face Hub (`model_repo_id`)
142
+
143
+ 如果你想通过环境变量控制,也支持:
144
+
145
+ - `GBERT_MODEL_PATH`
146
+ - `GBERT_MODEL_REPO_ID`
147
+ - `GBERT_MODEL_FILENAME`
148
+ - `HF_TOKEN`
149
+ - `TEXT_MODEL_NAME_OR_PATH`
150
+ - `TORCH_NUM_THREADS`
151
+
152
+ ## 命令行
153
+
154
+ ```bash
155
+ gbert --text "Industrial policy matters." --country Japan --year 2026
156
+ ```
157
+
158
+ ## 备注
159
+
160
+ - 当前宏观数据可支持到 `2026`
161
+ - 对超出训练年份的输入,年份 embedding 会自动回退到模型训练期的最后一年
162
+ - 批量分析建议复用同一个 `GbertClassifier` 实例,不要每条文本重新初始化模型
gbert-0.1.0/README.md ADDED
@@ -0,0 +1,276 @@
1
+ ---
2
+ title: Manifesto Model Demo
3
+ sdk: docker
4
+ app_port: 7860
5
+ ---
6
+
7
+ # Manifesto Model Web Demo
8
+
9
+ 这个项目现在已经补上了一个可直接启动的网站版本,访问者可以:
10
+
11
+ - 输入一段政策文本
12
+ - 选择国家和年份
13
+ - 由系统自动从 `add/` 目录读取 6 个宏观变量
14
+ - 调用训练好的 `causal_nam_best.pt` 模型返回 Top-5 预测类别
15
+
16
+ ## 新增文件
17
+
18
+ - `app.py`: Flask 网站入口
19
+ - `inference_service.py`: 推理服务,负责读取 `preprocess_meta.joblib`、`add/` 和模型
20
+ - `templates/index.html`: 页面模板
21
+ - `static/style.css`: 页面样式
22
+ - `static/app.js`: 前端交互
23
+ - `render.yaml`: Render 部署配置
24
+ - `Procfile`: Railway / 通用 PaaS 启动命令
25
+ - `.env.example`: 环境变量模板
26
+
27
+ ## 启动方式
28
+
29
+ 建议先准备一个可正常运行 `torch` 和 `transformers` 的 Python 环境,再执行:
30
+
31
+ ```bash
32
+ pip install -r requirements.txt
33
+ python app.py
34
+ ```
35
+
36
+ 启动后访问:
37
+
38
+ ```text
39
+ http://127.0.0.1:8000
40
+ ```
41
+
42
+ 本地环境变量可以参考 `.env.example`。
43
+
44
+ ## 作为 Python 包使用
45
+
46
+ 现在仓库也可以直接作为 Python 包安装:
47
+
48
+ ```bash
49
+ pip install .
50
+ ```
51
+
52
+ 安装后可以在 Python 里直接调用:
53
+
54
+ ```python
55
+ from manifesto_model import create_service, predict
56
+
57
+ service = create_service()
58
+ result = service.predict(
59
+ text="The government will expand industrial policy and labour training.",
60
+ country="Japan",
61
+ year=2026,
62
+ top_k=5,
63
+ )
64
+
65
+ # 或者直接:
66
+ result = predict(
67
+ text="The government will expand industrial policy and labour training.",
68
+ country="Japan",
69
+ year=2026,
70
+ )
71
+ ```
72
+
73
+ 如果你想作为 Flask API 使用,也可以:
74
+
75
+ ```python
76
+ from manifesto_model.web import create_app
77
+
78
+ app = create_app()
79
+ ```
80
+
81
+ ## 发布到 PyPI
82
+
83
+ 这个仓库现在已经补成标准 `pyproject.toml` 包结构,可以直接发布到 PyPI。
84
+
85
+ 建议流程:
86
+
87
+ 1. 先确认 PyPI 上还没有同名包 `manifesto-model`
88
+ 2. 安装构建与上传工具:
89
+
90
+ ```bash
91
+ pip install build twine
92
+ ```
93
+
94
+ 3. 构建发布文件:
95
+
96
+ ```bash
97
+ python -m build
98
+ ```
99
+
100
+ 4. 先上传到 TestPyPI 测试:
101
+
102
+ ```bash
103
+ python -m twine upload --repository testpypi dist/*
104
+ ```
105
+
106
+ 5. 测试没问题后,再上传到正式 PyPI:
107
+
108
+ ```bash
109
+ python -m twine upload dist/*
110
+ ```
111
+
112
+ 6. 上传成功后,其他人就可以直接安装:
113
+
114
+ ```bash
115
+ pip install manifesto-model
116
+ ```
117
+
118
+ 当前包发布时不会包含 `causal_nam_best.pt` 大模型权重;运行时需要通过 `MODEL_PATH` 或 Hugging Face Hub 环境变量提供模型文件。
119
+
120
+ ## 让其他用户访问
121
+
122
+ 现在最直接的做法是把这个项目部署到一个公开 URL。这个仓库已经补了 `Dockerfile`,适合直接部署到支持 Docker 的平台。
123
+
124
+ ## GitHub Pages 前端 + 独立后端 API
125
+
126
+ 现在项目已经拆成了两层:
127
+
128
+ 1. `docs/` 目录下是纯静态前端,可以直接发布到 GitHub Pages
129
+ 2. Flask 只负责独立后端 API,主要接口是 `/api/health`、`/api/options`、`/api/predict`
130
+
131
+ 前端不会依赖 Flask 模板,也不会依赖同域部署。访问者第一次打开 GitHub Pages 页面时,只需要填写一次后端地址,浏览器会自动保存。
132
+ 如果你已经有固定的线上后端地址,也可以直接把它写进 `docs/assets/app.js`,这样普通用户打开页面就能直接使用。
133
+
134
+ ### GitHub Pages 部署
135
+
136
+ 1. 把代码推到 GitHub 仓库
137
+ 2. 在仓库设置里打开 Pages
138
+ 3. 选择从 `main` 分支的 `/docs` 目录发布
139
+ 4. 等 GitHub 生成公开网址
140
+
141
+ ### 后端 API 部署
142
+
143
+ 后端继续部署到 Render、Railway 或 Hugging Face Spaces 都可以。由于前端是跨域调用,后端已经默认加好了 CORS。
144
+
145
+ 如果你想限制只允许自己的 GitHub Pages 域名访问,可以配置:
146
+
147
+ ```text
148
+ CORS_ALLOW_ORIGINS=https://your-name.github.io
149
+ ```
150
+
151
+ 如果不配置,当前默认允许公开前端调用。
152
+
153
+ ## 上线前必须注意
154
+
155
+ `causal_nam_best.pt` 现在大约 683MB。GitHub 官方对普通仓库单文件有 100MB 限制,所以它不能作为普通文件直接推送到 GitHub。
156
+
157
+ 现在代码支持两种上线方式:
158
+
159
+ 1. 使用 Git LFS 管理模型文件
160
+ 2. 更推荐:把网站代码和模型文件分开,代码仓库只放网站,模型放到 Hugging Face 模型仓库,部署时通过环境变量自动下载
161
+
162
+ 如果你选择“模型仓库分离”方案,需要配置:
163
+
164
+ ```text
165
+ MODEL_REPO_ID=your-username/your-model-repo
166
+ MODEL_FILENAME=causal_nam_best.pt
167
+ HF_TOKEN=
168
+ TEXT_MODEL_NAME_OR_PATH=bert-base-multilingual-cased
169
+ ```
170
+
171
+ 如果模型文件就在部署机器本地,只需要:
172
+
173
+ ```text
174
+ MODEL_PATH=./causal_nam_best.pt
175
+ ```
176
+
177
+ ### 方案 1:Hugging Face Spaces(更适合模型演示)
178
+
179
+ - 官方说明:Hugging Face 支持 Docker Spaces,可以直接运行自定义 `Dockerfile`
180
+ - 本项目已经按这种方式准备好了,默认暴露端口 `7860`
181
+ - 适合做公开 demo,后续如果 CPU 不够,也可以升级硬件
182
+
183
+ 基本流程:
184
+
185
+ 1. 新建一个 Hugging Face Space
186
+ 2. 选择 `Docker` 作为 SDK
187
+ 3. 把当前项目代码推上去
188
+ 4. 在 Space 里配置环境变量;如果模型不在代码仓库中,就配置 `MODEL_REPO_ID` 和 `MODEL_FILENAME`
189
+ 5. 等待构建完成后,平台会给你一个公开链接
190
+
191
+ ### 方案 2:Render / Railway(更像常规网站部署)
192
+
193
+ - 两个平台都支持从 GitHub 仓库直接部署 Python/Flask 应用
194
+ - Render 官方要求服务监听 `0.0.0.0` 和平台端口
195
+ - Railway 官方的 Flask 指南建议使用 `gunicorn`
196
+ - 这些要求我已经在代码里处理好了
197
+
198
+ 如果你走这条路,通常只需要:
199
+
200
+ 1. 把项目推到 GitHub
201
+ 2. 在平台里选择该仓库
202
+ 3. 配置环境变量,例如:
203
+
204
+ ```text
205
+ MODEL_REPO_ID=your-username/your-model-repo
206
+ MODEL_FILENAME=causal_nam_best.pt
207
+ TEXT_MODEL_NAME_OR_PATH=bert-base-multilingual-cased
208
+ ```
209
+
210
+ 4. 使用 Docker 部署,或直接使用:
211
+
212
+ ```text
213
+ Build: pip install -r requirements.txt
214
+ Start: gunicorn --bind 0.0.0.0:$PORT app:app
215
+ ```
216
+
217
+ ### 当前更推荐哪个
218
+
219
+ - 如果你的目标是“让别人在线试这个模型”,我更推荐 Hugging Face Spaces
220
+ - 如果你的目标是“做成普通网站,后面还想加账号、数据库、API 管理”,我更推荐 Render 或 Railway
221
+
222
+ ## 推荐上线结构
223
+
224
+ 建议按下面的结构上线:
225
+
226
+ 1. GitHub 仓库:只放网站代码
227
+ 2. Hugging Face 模型仓库:只放 `causal_nam_best.pt`
228
+ 3. 部署平台:Hugging Face Spaces 或 Render
229
+
230
+ 这样能避开 GitHub 大文件限制,也更方便后续更新模型。
231
+
232
+ ## 速度和内存
233
+
234
+ 如果你在线上看到 `Ran out of memory (used over 512MB)`,根因通常不是页面,而是模型本体太大。
235
+
236
+ 当前项目已经做了几项部署优化:
237
+
238
+ - Docker 默认安装 CPU 版 PyTorch,不再拉取 CUDA 大包
239
+ - Gunicorn 改成 `1 worker + 1 thread`
240
+ - 推理进程默认 `TORCH_NUM_THREADS=1`
241
+ - 模型 backbone 用更省内存的加载方式
242
+ - 权重加载优先使用 `weights_only` / `mmap`
243
+
244
+ 即便这样,`683MB` 权重加上 `bert-base-multilingual-cased` 在 `512MB RAM` 上依然很容易不稳定。
245
+ 如果你要“速度明显更快且稳定”,最有效的方法通常是:
246
+
247
+ 1. 把实例内存升到至少 `2GB`
248
+ 2. 后续改成模型常驻内存
249
+ 3. 或者换更小的模型
250
+
251
+ ## 说明
252
+
253
+ - 页面不会让用户手填宏观变量,推理时会自动从 `add/` 的 World Bank CSV 中按 `国家 + 年份` 读取。
254
+ - `add/` 中的 6 个宏观变量文件现在已经扩展到 `2026`;其中 `2025`、`2026` 当前采用 `2024` 数值前向延续,便于网站先支持未来年份推理。
255
+ - 少数国家名称和 World Bank 的命名不完全一致,代码里已经做了别名映射,例如:
256
+ - `South Korea -> Korea, Rep.`
257
+ - `Turkey -> Turkiye`
258
+ - `Czech Republic -> Czechia`
259
+ - 如果某个国家年份在 `add/` 里该指标缺失,当前实现会按训练阶段的做法补 `0`,然后再按 `preprocess_meta.joblib` 里的均值和标准差做标准化。
260
+ - 模型训练时的年份编码目前只到 `2023`,因此当页面选择 `2024-2026` 时,宏观变量会读取对应年份的新列,但模型内部年份 embedding 仍兼容回退到 `2023`。
261
+
262
+ ## 当前环境提醒
263
+
264
+ 我在这台机器上检查时,当前 Anaconda 环境里的 `torch` 启动就报了底层 OpenMP 错误,因此还不能在这里完成真实推理验证。
265
+
266
+ 网页和服务代码已经接好;只要换到一个可正常运行 `torch` 的环境,或者修复当前 Python 环境后,就可以直接启动并调用模型。
267
+
268
+ ## 参考资料
269
+
270
+ - [Hugging Face Spaces 概览](https://huggingface.co/docs/hub/en/spaces)
271
+ - [Hugging Face Docker Spaces](https://huggingface.co/docs/hub/spaces-sdks-docker)
272
+ - [Hugging Face Hub 文件下载](https://huggingface.co/docs/huggingface_hub/v0.30.2/en/guides/download)
273
+ - [GitHub 大文件限制](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-large-files-on-github)
274
+ - [Render Web Services 文档](https://render.com/docs/web-services)
275
+ - [Render Blueprints / render.yaml](https://render.com/docs/infrastructure-as-code)
276
+ - [Railway Flask 部署指南](https://docs.railway.com/guides/flask)
@@ -0,0 +1,133 @@
1
+ # gbert
2
+
3
+ `gbert` 是一个尽量简单的多语言文本分析包,适合:
4
+
5
+ - Kaggle Notebook
6
+ - 单条文本快速预测
7
+ - 大批量文本批处理
8
+
9
+ 它默认不包含网页、Flask 或服务端代码,只保留推理必需内容。
10
+
11
+ ## 安装
12
+
13
+ ```bash
14
+ pip install gbert
15
+ ```
16
+
17
+ 如果你在 Kaggle 里使用,推荐先安装再直接在 Notebook 里跑:
18
+
19
+ ```python
20
+ !pip -q install gbert
21
+ ```
22
+
23
+ ## 最简单用法
24
+
25
+ ```python
26
+ from gbert import GbertClassifier
27
+
28
+ model = GbertClassifier(
29
+ model_repo_id="your-hf-model-repo",
30
+ hf_token="your_hf_token_if_needed",
31
+ )
32
+
33
+ result = model.predict(
34
+ "The government will expand industrial policy and labour training.",
35
+ country="Japan",
36
+ year=2026,
37
+ )
38
+
39
+ result["predictions"][:3]
40
+ ```
41
+
42
+ ## 批量分析
43
+
44
+ ```python
45
+ from gbert import GbertClassifier
46
+
47
+ model = GbertClassifier(model_repo_id="your-hf-model-repo")
48
+
49
+ results = model.predict_batch(
50
+ texts=[
51
+ "We will invest in industry.",
52
+ "Healthcare access must improve.",
53
+ "Tax reform should support growth.",
54
+ ],
55
+ country="Japan",
56
+ year=2026,
57
+ batch_size=16,
58
+ )
59
+ ```
60
+
61
+ 如果你更喜欢 `pandas.DataFrame`:
62
+
63
+ ```python
64
+ df = model.predict_batch(
65
+ texts=["text a", "text b"],
66
+ country=["Japan", "Germany"],
67
+ year=[2026, 2024],
68
+ return_df=True,
69
+ )
70
+ ```
71
+
72
+ ## API
73
+
74
+ ### `GbertClassifier`
75
+
76
+ 常用参数:
77
+
78
+ - `model_path`: 本地模型权重路径
79
+ - `model_repo_id`: Hugging Face Hub 模型仓库
80
+ - `model_filename`: 默认是 `causal_nam_best.pt`
81
+ - `hf_token`: 私有仓库时可传
82
+ - `device`: `cpu` / `cuda` / 自动检测
83
+ - `batch_size`: 批量推理默认 batch 大小
84
+
85
+ ### 方法
86
+
87
+ - `predict(text, country, year, top_k=5)`
88
+ - `predict_batch(texts, country, year, top_k=5, batch_size=None, return_df=False)`
89
+ - `list_countries()`
90
+ - `list_years(country)`
91
+ - `runtime_info()`
92
+
93
+ ## 设计目标
94
+
95
+ - Notebook first
96
+ - Kaggle friendly
97
+ - 单次加载,多次复用
98
+ - 单条和批量接口统一
99
+
100
+ ## 模型说明
101
+
102
+ 这个包会打包:
103
+
104
+ - `preprocess_meta.joblib`
105
+ - `add/` 下的宏观变量 CSV
106
+
107
+ 它不会打包大模型权重文件。运行时会按以下顺序寻找模型:
108
+
109
+ 1. 你显式传入的 `model_path`
110
+ 2. 环境变量 `GBERT_MODEL_PATH`
111
+ 3. 包目录下的 `causal_nam_best.pt`
112
+ 4. Hugging Face Hub (`model_repo_id`)
113
+
114
+ 如果你想通过环境变量控制,也支持:
115
+
116
+ - `GBERT_MODEL_PATH`
117
+ - `GBERT_MODEL_REPO_ID`
118
+ - `GBERT_MODEL_FILENAME`
119
+ - `HF_TOKEN`
120
+ - `TEXT_MODEL_NAME_OR_PATH`
121
+ - `TORCH_NUM_THREADS`
122
+
123
+ ## 命令行
124
+
125
+ ```bash
126
+ gbert --text "Industrial policy matters." --country Japan --year 2026
127
+ ```
128
+
129
+ ## 备注
130
+
131
+ - 当前宏观数据可支持到 `2026`
132
+ - 对超出训练年份的输入,年份 embedding 会自动回退到模型训练期的最后一年
133
+ - 批量分析建议复用同一个 `GbertClassifier` 实例,不要每条文本重新初始化模型
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ from .service import GbertClassifier, load_default_model, predict, predict_batch
4
+
5
+
6
+ __all__ = [
7
+ "GbertClassifier",
8
+ "load_default_model",
9
+ "predict",
10
+ "predict_batch",
11
+ ]
@@ -0,0 +1,5 @@
1
+ from .cli import main
2
+
3
+
4
+ if __name__ == "__main__":
5
+ main()