@nahisaho/satori 0.23.0 → 0.25.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +61 -29
- package/package.json +1 -1
- package/src/.github/skills/scientific-adaptive-experiments/SKILL.md +287 -0
- package/src/.github/skills/scientific-anomaly-detection/SKILL.md +296 -0
- package/src/.github/skills/scientific-causal-ml/SKILL.md +240 -0
- package/src/.github/skills/scientific-data-profiling/SKILL.md +247 -0
- package/src/.github/skills/scientific-federated-learning/SKILL.md +241 -0
- package/src/.github/skills/scientific-geospatial-analysis/SKILL.md +274 -0
- package/src/.github/skills/scientific-model-monitoring/SKILL.md +247 -0
- package/src/.github/skills/scientific-multi-task-learning/SKILL.md +238 -0
- package/src/.github/skills/scientific-network-visualization/SKILL.md +278 -0
- package/src/.github/skills/scientific-neural-architecture-search/SKILL.md +206 -0
- package/src/.github/skills/scientific-radiology-ai/SKILL.md +285 -0
- package/src/.github/skills/scientific-reproducible-reporting/SKILL.md +330 -0
- package/src/.github/skills/scientific-semi-supervised-learning/SKILL.md +210 -0
- package/src/.github/skills/scientific-statistical-simulation/SKILL.md +227 -0
- package/src/.github/skills/scientific-streaming-analytics/SKILL.md +221 -0
- package/src/.github/skills/scientific-time-series-forecasting/SKILL.md +246 -0
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: scientific-data-profiling
|
|
3
|
+
description: |
|
|
4
|
+
データプロファイリング・品質スキル。ydata-profiling 自動 EDA ・
|
|
5
|
+
Great Expectations データバリデーション・データ品質スコア・
|
|
6
|
+
型推論・相関検出・外れ値フラグ・データカタログ生成。
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
# Scientific Data Profiling
|
|
10
|
+
|
|
11
|
+
データセットの包括的プロファイリング・品質評価・
|
|
12
|
+
自動 EDA レポートパイプラインを提供する。
|
|
13
|
+
|
|
14
|
+
## When to Use
|
|
15
|
+
|
|
16
|
+
- 新しいデータセットの全体像を素早く把握するとき
|
|
17
|
+
- データ品質スコアを算出して品質基準をチェックするとき
|
|
18
|
+
- ydata-profiling で自動 EDA レポートを生成するとき
|
|
19
|
+
- Great Expectations でデータバリデーションルールを定義するとき
|
|
20
|
+
- データカタログ (辞書) を自動生成するとき
|
|
21
|
+
- 相関・外れ値・欠損を一括診断するとき
|
|
22
|
+
|
|
23
|
+
---
|
|
24
|
+
|
|
25
|
+
## Quick Start
|
|
26
|
+
|
|
27
|
+
### 1. ydata-profiling 自動 EDA
|
|
28
|
+
|
|
29
|
+
```python
|
|
30
|
+
import numpy as np
|
|
31
|
+
import pandas as pd
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def auto_profile_report(df, title="Data Profile Report",
                        minimal=False, output="profile_report.html"):
    """
    Generate an automated EDA report with ydata-profiling.

    The HTML report is written to ``output``; a lightweight summary is
    computed directly from ``df`` with pandas and returned alongside it.

    Parameters:
        df: pd.DataFrame — input data
        title: str — report title
        minimal: bool — lightweight (minimal) profiling mode
        output: str — output HTML path

    Returns:
        dict with ``report_path`` (str) and ``summary`` (dict with row/column
        counts, numeric/categorical column counts, missing-cell total and
        percentage, and duplicate-row count).
    """
    from ydata_profiling import ProfileReport

    profile = ProfileReport(
        df, title=title, minimal=minimal,
        correlations={"pearson": {"calculate": True},
                      "spearman": {"calculate": True},
                      "kendall": {"calculate": True}},
        missing_diagrams={"bar": True, "matrix": True, "heatmap": True})

    profile.to_file(output)

    # Summary is derived from df itself; the profile's description object is
    # not needed for these statistics.
    n_cells = df.size
    total_missing = int(df.isnull().sum().sum())
    summary = {
        "n_rows": len(df),
        "n_cols": len(df.columns),
        "n_numeric": len(df.select_dtypes(include=[np.number]).columns),
        "n_categorical": len(df.select_dtypes(include=["object", "category"]).columns),
        "total_missing": total_missing,
        # An empty frame has no cells; report 0% instead of dividing by zero.
        "missing_pct": float(total_missing / n_cells * 100) if n_cells else 0.0,
        "n_duplicates": int(df.duplicated().sum()),
    }

    print(f"Profile Report → {output}")
    print(f"  {summary['n_rows']} rows × {summary['n_cols']} cols, "
          f"{summary['missing_pct']:.1f}% missing, "
          f"{summary['n_duplicates']} duplicates")
    return {"report_path": output, "summary": summary}
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
### 2. データ品質スコア
|
|
76
|
+
|
|
77
|
+
```python
|
|
78
|
+
def data_quality_score(df, rules=None):
    """
    Compute a data quality score on a 0-100 scale.

    Five dimension scores in [0, 1] are combined with fixed weights
    (completeness 0.3, uniqueness 0.2, consistency 0.2, timeliness 0.1,
    validity 0.2) and scaled to 0-100. Dimensions that cannot be measured
    because the input is empty score 1.0 instead of producing NaN.

    Parameters:
        df: pd.DataFrame — input data
        rules: dict | None — custom rules mapping a rule name to a callable
            ``rule_fn(df)``; its return value is stored as ``passed``. A rule
            that raises is recorded as failed with the error message.

    Returns:
        dict with ``total_score`` (float), ``dimension_scores`` (dict of
        floats keyed by dimension) and ``rule_results`` (list of dicts).
    """
    scores = {}
    n_rows, n_cols = df.shape
    n_cells = n_rows * n_cols

    # 1. Completeness — fraction of non-missing cells.
    if n_cells > 0:
        scores["completeness"] = float(1.0 - df.isnull().sum().sum() / n_cells)
    else:
        scores["completeness"] = 1.0

    # 2. Uniqueness — fraction of rows that do not duplicate an earlier row.
    if n_rows > 0:
        scores["uniqueness"] = float(1.0 - df.duplicated().sum() / n_rows)
    else:
        scores["uniqueness"] = 1.0

    # 3. Consistency — fraction of columns whose non-null values share one
    #    inferred type. All-null columns count as consistent; a column whose
    #    type cannot be inferred counts as inconsistent (best effort).
    type_consistent = 0
    for col in df.columns:
        non_null = df[col].dropna()
        if len(non_null) == 0:
            type_consistent += 1
            continue
        try:
            inferred = pd.api.types.infer_dtype(non_null, skipna=True)
        except Exception:
            continue
        if inferred not in ("mixed", "mixed-integer"):
            type_consistent += 1
    scores["consistency"] = type_consistent / n_cols if n_cols > 0 else 1.0

    # 4. Timeliness — placeholder, always 1.0. A real freshness measure needs a
    #    reference date and a maximum acceptable age, which are not inputs here.
    scores["timeliness"] = 1.0

    # 5. Validity — per numeric column, the fraction of non-null values that
    #    are finite (not ±inf), averaged over columns that have any non-null
    #    value. All-null columns would yield NaN, so they are excluded.
    validity = 1.0
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    if len(numeric_cols) > 0:
        per_col = df[numeric_cols].apply(
            lambda s: np.isfinite(s.dropna()).mean()).dropna()
        if len(per_col) > 0:
            validity = float(per_col.mean())
    scores["validity"] = validity

    # Weighted total
    weights = {"completeness": 0.3, "uniqueness": 0.2,
               "consistency": 0.2, "timeliness": 0.1, "validity": 0.2}
    total_score = sum(scores[k] * weights[k] for k in weights) * 100

    # Custom rules
    rule_results = []
    if rules:
        for rule_name, rule_fn in rules.items():
            try:
                passed = rule_fn(df)
                rule_results.append({"rule": rule_name, "passed": passed})
            except Exception as e:
                rule_results.append({"rule": rule_name, "passed": False,
                                     "error": str(e)})

    print(f"Data Quality Score: {total_score:.1f}/100")
    for k, v in scores.items():
        print(f"  {k}: {v:.3f}")

    return {"total_score": total_score, "dimension_scores": scores,
            "rule_results": rule_results}
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
### 3. Great Expectations バリデーション
|
|
154
|
+
|
|
155
|
+
```python
|
|
156
|
+
def great_expectations_validate(df, expectations=None):
    """
    Validate a DataFrame against Great Expectations-style rules.

    This is a lightweight pandas re-implementation of a subset of GE
    expectation types; the great_expectations package is not used.

    Parameters:
        df: pd.DataFrame — input data
        expectations: list[dict] | None — rules; each dict has ``type`` and
            optionally ``column`` and ``kwargs``. When None, rules are
            inferred with ``_auto_generate_expectations(df)``.

    Returns:
        pd.DataFrame with columns ``expectation``, ``column`` and ``success``
        (bool, or None for an unsupported expectation type), plus ``error``
        when any rule raised. An empty list of expectations yields an empty
        frame with those three columns.
    """
    if expectations is None:
        expectations = _auto_generate_expectations(df)

    results = []
    for exp in expectations:
        exp_type = exp["type"]
        col = exp.get("column")
        kwargs = exp.get("kwargs", {})

        try:
            if exp_type == "expect_column_to_exist":
                success = col in df.columns
            elif exp_type == "expect_column_values_to_not_be_null":
                # "mostly" is the minimum acceptable fraction of non-null values.
                min_fraction = kwargs.get("mostly", 1.0)
                success = df[col].notnull().mean() >= min_fraction
            elif exp_type == "expect_column_values_to_be_between":
                min_val, max_val = kwargs["min_value"], kwargs["max_value"]
                vals = df[col].dropna()
                success = (vals >= min_val).all() and (vals <= max_val).all()
            elif exp_type == "expect_column_values_to_be_unique":
                # Nulls are ignored, as in GE column-map expectations.
                success = not df[col].dropna().duplicated().any()
            elif exp_type == "expect_column_values_to_be_in_set":
                valid_set = set(kwargs["value_set"])
                success = df[col].dropna().isin(valid_set).all()
            elif exp_type == "expect_table_row_count_to_be_between":
                success = kwargs["min_value"] <= len(df) <= kwargs["max_value"]
            else:
                success = None  # unsupported expectation type

            # Normalize numpy booleans to plain bool so pass counting is exact.
            if success is not None:
                success = bool(success)
            results.append({"expectation": exp_type, "column": col,
                            "success": success})
        except Exception as e:
            results.append({"expectation": exp_type, "column": col,
                            "success": False, "error": str(e)})

    if results:
        results_df = pd.DataFrame(results)
    else:
        results_df = pd.DataFrame(columns=["expectation", "column", "success"])

    # Only genuine passes count; None (unsupported) is neither pass nor fail.
    n_pass = sum(1 for r in results if r["success"] is True)
    n_total = len(results)
    pct = f"{n_pass / n_total * 100:.0f}%" if n_total else "n/a"
    print(f"Validation: {n_pass}/{n_total} expectations passed ({pct})")
    return results_df
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _auto_generate_expectations(df):
|
|
209
|
+
"""自動でバリデーションルールを推論。"""
|
|
210
|
+
expectations = []
|
|
211
|
+
for col in df.columns:
|
|
212
|
+
expectations.append({"type": "expect_column_to_exist", "column": col})
|
|
213
|
+
expectations.append({
|
|
214
|
+
"type": "expect_column_values_to_not_be_null",
|
|
215
|
+
"column": col,
|
|
216
|
+
"kwargs": {"mostly": 0.9}})
|
|
217
|
+
|
|
218
|
+
if df[col].dtype in [np.float64, np.int64]:
|
|
219
|
+
q1, q3 = df[col].quantile([0.01, 0.99])
|
|
220
|
+
iqr = q3 - q1
|
|
221
|
+
expectations.append({
|
|
222
|
+
"type": "expect_column_values_to_be_between",
|
|
223
|
+
"column": col,
|
|
224
|
+
"kwargs": {"min_value": float(q1 - 3 * iqr),
|
|
225
|
+
"max_value": float(q3 + 3 * iqr)}})
|
|
226
|
+
return expectations
|
|
227
|
+
```
|
|
228
|
+
|
|
229
|
+
---
|
|
230
|
+
|
|
231
|
+
## パイプライン統合
|
|
232
|
+
|
|
233
|
+
```
|
|
234
|
+
[データ取得] → data-profiling → eda-correlation
|
|
235
|
+
(品質診断) (探索的解析)
|
|
236
|
+
│ ↓
|
|
237
|
+
missing-data-analysis anomaly-detection
|
|
238
|
+
(欠損補完) (異常検知)
|
|
239
|
+
```
|
|
240
|
+
|
|
241
|
+
## パイプライン出力
|
|
242
|
+
|
|
243
|
+
| ファイル | 説明 | 次スキル |
|
|
244
|
+
|---------|------|---------|
|
|
245
|
+
| `profile_report.html` | ydata-profiling レポート | → EDA |
|
|
246
|
+
| `quality_score.json` | データ品質スコア | → 品質管理 |
|
|
247
|
+
| `validation_results.csv` | バリデーション結果 | → データ修正 |
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: scientific-federated-learning
|
|
3
|
+
description: |
|
|
4
|
+
連合学習スキル。Flower フレームワークによる FL パイプライン・
|
|
5
|
+
FedAvg/FedProx/FedOpt 集約戦略・差分プライバシー (DP-SGD)・
|
|
6
|
+
非 IID データ分割・通信効率化。
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
# Scientific Federated Learning
|
|
10
|
+
|
|
11
|
+
プライバシー保護型分散機械学習を実現する連合学習パイプラインを提供する。
|
|
12
|
+
|
|
13
|
+
## When to Use
|
|
14
|
+
|
|
15
|
+
- 複数施設・組織のデータを集約せずにモデル学習するとき
|
|
16
|
+
- 医療データ・個人情報を含むデータで ML を行うとき
|
|
17
|
+
- 差分プライバシーを適用した学習が必要なとき
|
|
18
|
+
- 非 IID データ分割下での連合学習を設計するとき
|
|
19
|
+
- 通信効率を考慮した分散学習を構築するとき
|
|
20
|
+
|
|
21
|
+
---
|
|
22
|
+
|
|
23
|
+
## Quick Start
|
|
24
|
+
|
|
25
|
+
### 1. Flower 連合学習パイプライン
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
import flwr as fl
|
|
29
|
+
import numpy as np
|
|
30
|
+
from typing import Dict, List, Tuple, Optional
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def create_fl_client(model, train_loader, val_loader,
                     device="cpu"):
    """
    Build a Flower NumPyClient that trains and evaluates a PyTorch classifier.

    The model's trainable parameters (``model.parameters()``) are exchanged
    with the server as a list of NumPy arrays. Training uses Adam with
    cross-entropy loss. The server-supplied ``config`` may set
    ``local_epochs`` (default 1) and ``lr`` (default 1e-3) for ``fit``.

    Parameters:
        model: nn.Module — PyTorch model (moved to ``device`` here)
        train_loader: DataLoader — training data yielding (X, y) batches
        val_loader: DataLoader — validation data yielding (X, y) batches
        device: str — "cpu" / "cuda"

    Returns:
        An ``fl.client.NumPyClient`` instance.
    """
    import torch

    # Move the whole module (parameters and buffers) once, so inputs moved to
    # `device` in fit/evaluate always meet a model on the same device.
    model.to(device)

    class SatoriFlClient(fl.client.NumPyClient):
        def get_parameters(self, config):
            # detach(): calling .numpy() on a tensor that requires grad raises.
            return [val.detach().cpu().numpy() for val in model.parameters()]

        def set_parameters(self, parameters):
            params = list(model.parameters())
            if len(parameters) != len(params):
                # zip() would silently truncate and leave weights stale.
                raise ValueError(
                    f"expected {len(params)} parameter arrays, "
                    f"got {len(parameters)}")
            with torch.no_grad():
                for param, new_val in zip(params, parameters):
                    param.copy_(torch.as_tensor(new_val, dtype=param.dtype,
                                                device=param.device))

        def fit(self, parameters, config):
            self.set_parameters(parameters)
            model.train()
            lr = float(config.get("lr", 1e-3))
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            criterion = torch.nn.CrossEntropyLoss()

            epochs = int(config.get("local_epochs", 1))
            for _ in range(epochs):
                for X, y in train_loader:
                    X, y = X.to(device), y.to(device)
                    optimizer.zero_grad()
                    loss = criterion(model(X), y)
                    loss.backward()
                    optimizer.step()

            return self.get_parameters(config), len(train_loader.dataset), {}

        def evaluate(self, parameters, config):
            self.set_parameters(parameters)
            model.eval()
            criterion = torch.nn.CrossEntropyLoss()
            total_loss, correct, total = 0.0, 0, 0

            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(device), y.to(device)
                    preds = model(X)
                    total_loss += criterion(preds, y).item() * len(y)
                    correct += (preds.argmax(1) == y).sum().item()
                    total += len(y)

            if total == 0:
                # Empty validation set: avoid dividing by zero.
                return 0.0, 0, {"accuracy": 0.0}
            return total_loss / total, total, {"accuracy": correct / total}

    return SatoriFlClient()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def create_fl_strategy(algorithm="fedavg", min_clients=2,
                       fraction_fit=1.0, fraction_evaluate=1.0,
                       proximal_mu=0.1, initial_parameters=None):
    """
    Select and build a federated-learning aggregation strategy.

    Parameters:
        algorithm: str — "fedavg" / "fedprox" / "fedopt" / "fedadam"
            (case-insensitive; "fedopt" is built as FedAdam). Unknown names
            fall back to "fedavg" with a printed notice.
        min_clients: int — minimum clients for fit, evaluate and availability
        fraction_fit: float — fraction of clients sampled for training
        fraction_evaluate: float — fraction of clients sampled for evaluation
        proximal_mu: float — FedProx proximal-term strength
        initial_parameters: list[np.ndarray] | fl.common.Parameters | None —
            initial global weights. Required for "fedadam"/"fedopt"; optional
            for the other strategies.

    Returns:
        A Flower strategy instance.

    Raises:
        ValueError: if "fedadam"/"fedopt" is requested without
            ``initial_parameters``.
    """
    common = dict(
        min_fit_clients=min_clients,
        min_evaluate_clients=min_clients,
        min_available_clients=min_clients,
        fraction_fit=fraction_fit,
        fraction_evaluate=fraction_evaluate,
    )
    if initial_parameters is not None:
        if not isinstance(initial_parameters, fl.common.Parameters):
            initial_parameters = fl.common.ndarrays_to_parameters(initial_parameters)
        common["initial_parameters"] = initial_parameters

    def _build_fedadam():
        # Flower's FedOpt family (including FedAdam) requires initial_parameters.
        if initial_parameters is None:
            raise ValueError(
                "FedAdam/FedOpt requires initial_parameters "
                "(list of NumPy arrays or fl.common.Parameters)")
        return fl.server.strategy.FedAdam(
            eta=1e-1, eta_l=1e-1, tau=1e-9, **common)

    # Builders are invoked lazily so that only the requested strategy is
    # constructed; eagerly building FedAdam without initial_parameters would
    # fail regardless of which algorithm was requested.
    builders = {
        "fedavg": lambda: fl.server.strategy.FedAvg(**common),
        "fedprox": lambda: fl.server.strategy.FedProx(
            proximal_mu=proximal_mu, **common),
        "fedadam": _build_fedadam,
        "fedopt": _build_fedadam,
    }

    key = str(algorithm).lower()
    if key not in builders:
        print(f"Unknown FL algorithm {algorithm!r}; falling back to 'fedavg'")
        key = "fedavg"

    strategy = builders[key]()
    print(f"FL Strategy: {key} | min_clients={min_clients}")
    return strategy
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
### 2. 差分プライバシー (DP-SGD)
|
|
126
|
+
|
|
127
|
+
```python
|
|
128
|
+
def apply_differential_privacy(model, train_loader,
                               target_epsilon=1.0,
                               target_delta=1e-5,
                               max_grad_norm=1.0,
                               noise_multiplier=1.1,
                               epochs=10, lr=1e-3):
    """
    Train a classifier with DP-SGD using Opacus.

    Parameters:
        model: nn.Module — PyTorch model (must be Opacus-compatible; e.g.
            BatchNorm layers are not supported by Opacus)
        train_loader: DataLoader — training data yielding (X, y) batches
        target_epsilon: float | None — privacy budget ε. When set, Opacus
            calibrates the noise so that ε is reached after ``epochs`` epochs,
            and ``noise_multiplier`` is ignored. Pass None to use
            ``noise_multiplier`` directly.
        target_delta: float — privacy parameter δ (also used to report ε)
        max_grad_norm: float — per-sample gradient clipping norm
        noise_multiplier: float — noise multiplier σ, used only when
            ``target_epsilon`` is None
        epochs: int — number of training epochs
        lr: float — SGD learning rate

    Returns:
        pd.DataFrame with one row per epoch: ``epoch``, ``loss`` (mean loss
        over the non-empty batches of that epoch) and ``epsilon`` (ε spent so
        far at ``target_delta``).
    """
    import pandas as pd
    import torch
    from opacus import PrivacyEngine

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    privacy_engine = PrivacyEngine()

    if target_epsilon is not None:
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=train_loader,
            epochs=epochs,
            target_epsilon=target_epsilon,
            target_delta=target_delta,
            max_grad_norm=max_grad_norm,
        )
    else:
        model, optimizer, train_loader = privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=train_loader,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
        )

    criterion = torch.nn.CrossEntropyLoss()
    history = []

    for epoch in range(epochs):
        model.train()
        total_loss, n_batches = 0.0, 0
        for X, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            optimizer.step()
            # The private loader may use Poisson sampling, so batches can be
            # empty and len(loader) is only an expectation; average over the
            # batches that actually contributed a loss.
            if len(y) > 0:
                total_loss += loss.item()
                n_batches += 1

        mean_loss = total_loss / max(n_batches, 1)
        epsilon = privacy_engine.get_epsilon(delta=target_delta)
        history.append({"epoch": epoch + 1,
                        "loss": mean_loss,
                        "epsilon": epsilon})
        print(f"Epoch {epoch+1}: loss={mean_loss:.4f}, "
              f"ε={epsilon:.2f}")

    return pd.DataFrame(history)
|
|
185
|
+
```
|
|
186
|
+
|
|
187
|
+
### 3. 非 IID データ分割
|
|
188
|
+
|
|
189
|
+
```python
|
|
190
|
+
def create_non_iid_splits(dataset_labels, n_clients=5,
                          alpha=0.5, seed=42):
    """
    Partition sample indices across clients with a per-class Dirichlet prior.

    For each distinct label, client proportions are drawn from
    Dir(alpha, ..., alpha). The (shuffled) indices of that label are then cut
    into consecutive chunks of those sizes, one chunk per client.

    Parameters:
        dataset_labels: array-like — label of every sample (any values
            np.unique can sort; they need not be 0..K-1)
        n_clients: int — number of clients
        alpha: float — Dirichlet α (smaller means more label skew)
        seed: int — random seed

    Returns:
        list of ``n_clients`` lists of int sample indices; every sample index
        appears in exactly one list.
    """
    labels = np.asarray(dataset_labels)
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(n_clients)]

    # Iterate over the labels actually present rather than range(n_classes),
    # which silently dropped samples whose labels were not 0..K-1.
    for cls in np.unique(labels):
        class_idx = np.flatnonzero(labels == cls)
        # Shuffle so chunks do not follow the dataset's storage order.
        rng.shuffle(class_idx)
        proportions = rng.dirichlet(np.repeat(alpha, n_clients))
        split_points = (np.cumsum(proportions) * len(class_idx)).astype(int)
        for client, part in enumerate(np.split(class_idx, split_points[:-1])):
            client_indices[client].extend(part.tolist())

    # Distribution summary
    for i, indices in enumerate(client_indices):
        unique, counts = np.unique(labels[indices], return_counts=True)
        dist = dict(zip(unique.tolist(), counts.tolist()))
        print(f"Client {i}: {len(indices)} samples, dist={dist}")

    return client_indices
|
|
221
|
+
```
|
|
222
|
+
|
|
223
|
+
---
|
|
224
|
+
|
|
225
|
+
## パイプライン統合
|
|
226
|
+
|
|
227
|
+
```
|
|
228
|
+
[プライバシー要件] → federated-learning → model-monitoring
|
|
229
|
+
(連合学習) (モデル監視)
|
|
230
|
+
│
|
|
231
|
+
deep-learning ← transfer-learning
|
|
232
|
+
(基盤 NN) (転移学習)
|
|
233
|
+
```
|
|
234
|
+
|
|
235
|
+
## パイプライン出力
|
|
236
|
+
|
|
237
|
+
| ファイル | 説明 | 次スキル |
|
|
238
|
+
|---------|------|---------|
|
|
239
|
+
| `fl_strategy_config.json` | FL 集約設定 | → サーバー起動 |
|
|
240
|
+
| `dp_training_history.csv` | DP 学習履歴 | → model-monitoring |
|
|
241
|
+
| `client_splits.json` | 非 IID 分割情報 | → FL クライアント |
|