@icyfenix-dmla/cli 2026.5.3-821 → 2026.5.4-1250
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +1 -1
- package/scripts/build.js +2 -2
- package/shared_modules/__init__.py +4 -4
- package/shared_modules/bayesian/__init__.py +3 -3
- package/shared_modules/cnn/__init__.py +5 -2
- package/shared_modules/cnn/minimal_preprocess_cache.py +153 -0
- package/shared_modules/cnn/realtime_dataset.py +215 -0
- package/shared_modules/linear/__init__.py +5 -3
- package/shared_modules/neural/__init__.py +2 -2
- package/shared_modules/svm/__init__.py +1 -1
- package/shared_modules/tree/__init__.py +3 -3
- package/shared_modules/unsupervised/__init__.py +2 -2
- package/src/commands/server.js +80 -0
- package/src/commands/update.js +58 -0
- package/src/index.js +11 -0
- package/src/server/cuda_compat_check.py +359 -0
- package/src/server/dmla_progress.py +327 -0
- package/src/server/index.js +18 -5
- package/src/server/kernel_runner.py +515 -0
- package/src/server/routes/sandbox.js +114 -1
- package/src/server/sandbox.js +451 -9
- package/version.json +2 -2
package/package.json
CHANGED
package/scripts/build.js
CHANGED
|
@@ -49,11 +49,11 @@ function copyDir(src, dest, filter = null) {
|
|
|
49
49
|
return true
|
|
50
50
|
}
|
|
51
51
|
|
|
52
|
-
//
|
|
52
|
+
// 复制服务器代码(复制 .js 和 .py 文件)
|
|
53
53
|
console.log('\n📋 复制服务器代码...')
|
|
54
54
|
console.log(` 源目录: ${localServerSrc}`)
|
|
55
55
|
console.log(` 目标目录: ${cliServerDest}`)
|
|
56
|
-
copyDir(localServerSrc, cliServerDest, (name) => name.endsWith('.js'))
|
|
56
|
+
copyDir(localServerSrc, cliServerDest, (name) => name.endsWith('.js') || name.endsWith('.py'))
|
|
57
57
|
|
|
58
58
|
// 复制共享模块(复制所有 .py 文件和 __init__.py)
|
|
59
59
|
console.log('\n📋 复制共享模块...')
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
# shared 模块包初始化
|
|
2
2
|
# 包含统计学习系列文档中可复用的类定义
|
|
3
3
|
|
|
4
|
-
from .linear import *
|
|
5
|
-
from .cnn import *
|
|
6
4
|
from .bayesian import *
|
|
7
|
-
from .
|
|
5
|
+
from .cnn import *
|
|
6
|
+
from .linear import *
|
|
7
|
+
from .neural import *
|
|
8
8
|
from .svm import *
|
|
9
|
+
from .tree import *
|
|
9
10
|
from .unsupervised import *
|
|
10
|
-
from .neural import *
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
#
|
|
2
|
-
from .
|
|
1
|
+
# BAYESIAN 模块
|
|
2
|
+
from .bayesian_network import SimpleBayesianNetwork
|
|
3
3
|
from .gaussian_mixture_model import GaussianMixtureModel
|
|
4
4
|
from .multinomial_naive_bayes import MultinomialNaiveBayes
|
|
5
5
|
|
|
6
|
-
__all__ = ['SimpleBayesianNetwork', 'GaussianMixtureModel', 'MultinomialNaiveBayes']
|
|
6
|
+
__all__ = ['SimpleBayesianNetwork', 'GaussianMixtureModel', 'MultinomialNaiveBayes']
|
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
# CNN 模块
|
|
2
|
-
from .
|
|
2
|
+
from .alexnet import AlexNet
|
|
3
|
+
from .minimal_preprocess_cache import MinimalPreprocessCache
|
|
4
|
+
from .realtime_dataset import RealtimeAugmentDataset
|
|
5
|
+
from .realtime_dataset import RealtimeValDataset
|
|
3
6
|
from .tiny_imagenet_dataset import TinyImageNetDataset
|
|
4
7
|
|
|
5
|
-
__all__ = ['AlexNet', 'TinyImageNetDataset']
|
|
8
|
+
__all__ = ['AlexNet', 'MinimalPreprocessCache', 'RealtimeAugmentDataset', 'RealtimeValDataset', 'TinyImageNetDataset']
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
# MinimalPreprocessCache 类定义
|
|
2
|
+
# 从文档自动提取生成
|
|
3
|
+
|
|
4
|
+
import json
import os
import time

from PIL import Image
|
|
6
|
+
|
|
7
|
+
class MinimalPreprocessCache:
    """Minimal, resumable resize cache for a Tiny-ImageNet-style dataset.

    Every source image is resized to 224x224 (bilinear) and re-saved as JPEG
    (quality 95) under ``cache_dir``:

        cache_dir/train/<class>/<image>.JPEG
        cache_dir/val/val_<annotation line index>.JPEG
        cache_dir/manifest.json   {'train_count', 'val_count', 'val_labels'}

    Interrupted runs can be resumed: images already present in the cache are
    skipped. The manifest is written last, and atomically, so its existence
    marks a complete cache (see ``check_cache_exists``).

    The original author's size estimate (not verified here) is about 600 MB
    on disk and roughly 4 GB of memory during training.
    """

    def __init__(self, data_dir, cache_dir):
        """
        Args:
            data_dir: dataset root. It is expected to contain
                ``train/<class>/images/*.JPEG``, ``val/images/``,
                ``val/val_annotations.txt`` and ``wnids.txt``.
            cache_dir: directory that receives the resized images and the
                manifest.
        """
        self.data_dir = data_dir
        self.cache_dir = cache_dir
        self.train_cache = os.path.join(cache_dir, 'train')
        self.val_cache = os.path.join(cache_dir, 'val')
        self.manifest_path = os.path.join(cache_dir, 'manifest.json')

    def preprocess_image(self, img_path, save_path):
        """Resize one image to 224x224 and save it as JPEG at ``save_path``.

        The output is first written to ``save_path + '.tmp'`` and then renamed
        into place. An interrupted run therefore never leaves a truncated
        ``.JPEG`` that the resume logic would mistake for a finished image.
        """
        tmp_path = save_path + '.tmp'
        # The context manager closes the source file as soon as it is decoded.
        with Image.open(img_path) as src:
            img = src.convert('RGB').resize((224, 224), Image.BILINEAR)
        img.save(tmp_path, 'JPEG', quality=95)
        os.replace(tmp_path, save_path)

    def check_cache_exists(self):
        """Return True once a complete cache has been written (manifest present)."""
        return os.path.exists(self.manifest_path)

    def _read_manifest(self):
        """Return the parsed manifest dict, or None if none has been written."""
        if not os.path.exists(self.manifest_path):
            return None
        with open(self.manifest_path, 'r') as f:
            return json.load(f)

    def get_cache_stats(self):
        """Return ``(train_count, val_count)`` from the manifest, or ``(0, 0)``."""
        manifest = self._read_manifest()
        if manifest is None:
            return 0, 0
        return manifest.get('train_count', 0), manifest.get('val_count', 0)

    def _preprocess_train_set(self, progress):
        """Resize the training set into ``train_cache``.

        A class whose source images are all cached already is skipped with one
        progress update. Otherwise, only the images missing from the cache are
        processed. ``progress.update`` is called once per processed class, with
        the 1-based class index.

        Returns:
            Number of training images present in the cache.
        """
        train_dir = os.path.join(self.data_dir, 'train')
        classes = sorted(os.listdir(train_dir))
        num_classes = len(classes)

        os.makedirs(self.train_cache, exist_ok=True)
        total_count = 0

        for cls_idx, cls in enumerate(classes):
            step = cls_idx + 1
            cls_cache_dir = os.path.join(self.train_cache, cls)
            images_dir = os.path.join(train_dir, cls, 'images')

            cached = set()
            if os.path.isdir(cls_cache_dir):
                cached = {f for f in os.listdir(cls_cache_dir) if f.endswith('.JPEG')}

            if not os.path.isdir(images_dir):
                # No source images for this entry (e.g. a stray file in train/).
                # Count whatever is already cached, since training reads the cache.
                total_count += len(cached)
                continue

            img_names = sorted(f for f in os.listdir(images_dir) if f.endswith('.JPEG'))

            # Resume: compare against the real source list instead of a fixed
            # per-class image count.
            if cached.issuperset(img_names):
                total_count += len(img_names)
                progress.update(step, message=f"跳过已缓存类别 {step}/{num_classes}: {cls}")
                continue

            os.makedirs(cls_cache_dir, exist_ok=True)
            count = 0
            for img_name in img_names:
                if img_name in cached:
                    count += 1
                    continue
                img_path = os.path.join(images_dir, img_name)
                try:
                    self.preprocess_image(img_path, os.path.join(cls_cache_dir, img_name))
                    count += 1
                except Exception as e:
                    # Best effort: a single unreadable image must not abort the whole cache build.
                    print(f"Warning: Failed to process {img_path}: {e}")

            total_count += count
            progress.update(step, message=f"预处理类别 {step}/{num_classes}: {cls} ({count} 张)")

        return total_count

    def _preprocess_val_set(self, progress):
        """Resize the validation set into ``val_cache``.

        Line ``i`` of ``val_annotations.txt`` is cached as ``val_<i>.JPEG``, and
        its label is stored at ``labels[i]``. File names and labels therefore
        stay aligned even when some images are missing or fail, and across
        resumed runs. Annotation lines with fewer than two tab-separated fields
        get label ``-1`` and produce no image.

        Returns:
            ``(cached_count, labels)``: the number of validation images present
            in the cache, and one label per annotation line.
        """
        val_dir = os.path.join(self.data_dir, 'val')
        val_images_dir = os.path.join(val_dir, 'images')
        val_annotations = os.path.join(val_dir, 'val_annotations.txt')

        # wnid -> class index, following the line order of wnids.txt
        wnids_path = os.path.join(self.data_dir, 'wnids.txt')
        with open(wnids_path, 'r') as f:
            wnids = [line.strip() for line in f if line.strip()]
        class_to_idx = {wnid: idx for idx, wnid in enumerate(wnids)}

        with open(val_annotations, 'r') as f:
            val_lines = f.readlines()
        total_val = len(val_lines)

        os.makedirs(self.val_cache, exist_ok=True)

        # Labels come from the annotation file alone, so they are rebuilt in full
        # on every run; only the image files themselves are resumed.
        labels = []
        pending = []  # (line_idx, source path, cache path) not cached yet
        parsed_count = 0
        for line_idx, line in enumerate(val_lines):
            parts = line.strip().split('\t')
            if len(parts) < 2:
                labels.append(-1)
                continue
            parsed_count += 1
            labels.append(class_to_idx.get(parts[1], 0))
            save_path = os.path.join(self.val_cache, f'val_{line_idx}.JPEG')
            if not os.path.exists(save_path):
                pending.append((line_idx, os.path.join(val_images_dir, parts[0]), save_path))

        if not pending:
            progress.update(total_val, message=f"验证集已缓存: {total_val} 张")
            return parsed_count, labels

        progress.reset(total_steps=total_val, description="预处理验证集")
        failed = 0
        for line_idx, img_path, save_path in pending:
            if not os.path.exists(img_path):
                failed += 1
            else:
                try:
                    self.preprocess_image(img_path, save_path)
                except Exception as e:
                    failed += 1
                    print(f"处理图片出现异常 {img_path}: {e}")
            if (line_idx + 1) % 100 == 0:
                progress.update(line_idx + 1, message=f"预处理验证集 {line_idx+1}/{total_val}")
        # Always finish at 100%, even if the last annotation lines were cached already.
        progress.update(total_val, message=f"预处理验证集 {total_val}/{total_val}")

        return parsed_count - failed, labels

    def run(self, progress):
        """Build or resume the cache, then write the manifest.

        Args:
            progress: reporter exposing ``update(step, message=...)``,
                ``reset(total_steps=..., description=...)`` and
                ``complete(message=...)``.

        Returns:
            ``(train_count, val_count)``, as recorded in the manifest.
        """
        start_time = time.time()
        os.makedirs(self.cache_dir, exist_ok=True)
        train_count = self._preprocess_train_set(progress)
        val_count, val_labels = self._preprocess_val_set(progress)

        manifest = {
            'train_count': train_count,
            'val_count': val_count,
            'val_labels': val_labels if val_labels else self._load_existing_val_labels(),
        }
        # Write, then rename: check_cache_exists() treats the manifest's presence
        # as "cache complete", so it must never be visible half-written.
        tmp_path = self.manifest_path + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump(manifest, f)
        os.replace(tmp_path, self.manifest_path)

        elapsed = time.time() - start_time
        progress.complete(message=f"预处理完成: 训练集 {train_count} 张, 验证集 {val_count} 张, 耗时 {elapsed:.1f}s")

        return train_count, val_count

    def _load_existing_val_labels(self):
        """Return ``val_labels`` from an existing manifest, or ``[]``."""
        manifest = self._read_manifest()
        if manifest is None:
            return []
        return manifest.get('val_labels', [])
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
# 实时数据增强 Dataset 类
|
|
2
|
+
# 从缓存读取 JPEG,实时执行数据增强
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import json
|
|
6
|
+
import time
|
|
7
|
+
from PIL import Image
|
|
8
|
+
import torch
|
|
9
|
+
from torch.utils.data import Dataset
|
|
10
|
+
from torchvision import transforms
|
|
11
|
+
|
|
12
|
+
# Per-sample timing log shared by the datasets in this module (one handle per process).
PERF_LOG_PATH = '/data/models/alexnet/dataset_perf_log.txt'
_perf_log_file = None
_perf_counter = 0


def _get_perf_log():
    """Return this process's perf-log handle, opening it lazily on first use.

    The file is opened in append mode. With a multi-worker ``DataLoader``,
    each worker process opens its own handle, and ``'w'`` would let every
    worker truncate the file and overwrite the others' rows. As a
    consequence, rows from earlier runs are kept. The CSV header is written
    only when the file is empty.

    Logging is diagnostic only. If the log cannot be opened (for example, the
    directory is not writable), a handle to ``os.devnull`` is returned instead
    of failing the data pipeline.
    """
    global _perf_log_file
    if _perf_log_file is None:
        try:
            log_dir = os.path.dirname(PERF_LOG_PATH)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            log_file = open(PERF_LOG_PATH, 'a')
        except OSError:
            log_file = open(os.devnull, 'w')
        else:
            # In append mode the position starts at end-of-file, so 0 means empty.
            if log_file.tell() == 0:
                log_file.write("idx,jpeg_decode_ms,to_tensor_ms,augment_ms,total_ms\n")
        _perf_log_file = log_file
    return _perf_log_file
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RealtimeAugmentDataset(Dataset):
    """Dataset that reads cached JPEGs and augments them on the fly.

    Pipeline for each sample:
        1. Decode a cached JPEG from ``cache_dir/train/<class>/``.
        2. On the CPU, apply ToTensor and then, when ``augment`` is set,
           RandomHorizontalFlip, RandomCrop(224, padding=4) and ColorJitter.
        3. Normalize with the ImageNet mean/std. This runs here on the CPU
           unless ``normalize_on_gpu`` is set; in that case the caller applies
           it using ``get_normalize_params()``.

    Augmentation is random, so each epoch sees a different variant of every
    image. Per-stage timings are written to the module perf log on every
    100th sample of each process.
    """

    # ImageNet normalization statistics
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, cache_dir, augment=True, normalize_on_gpu=False,
                 wnids_path='/data/datasets/tiny-imagenet-200/wnids.txt'):
        """
        Args:
            cache_dir: cache root containing ``train/<class>/`` directories and,
                optionally, ``manifest.json``.
            augment: apply random augmentation (True for training, False for
                evaluation).
            normalize_on_gpu: skip Normalize here so the caller can apply it on
                the GPU.
            wnids_path: class-id list whose line order defines the label
                indices. If the file is missing, a class's label is its
                position among the sorted class directories.
        """
        self.cache_dir = cache_dir
        self.augment = augment
        self.normalize_on_gpu = normalize_on_gpu

        manifest_path = os.path.join(cache_dir, 'manifest.json')
        if os.path.exists(manifest_path):
            with open(manifest_path, 'r') as f:
                self.manifest = json.load(f)
        else:
            self.manifest = {'val_labels': []}

        self.image_paths = []
        self.labels = []

        train_cache = os.path.join(cache_dir, 'train')
        if os.path.exists(train_cache):
            if os.path.exists(wnids_path):
                with open(wnids_path, 'r') as f:
                    wnids = [line.strip() for line in f]
                class_to_idx = {wnid: idx for idx, wnid in enumerate(wnids)}
            else:
                class_to_idx = {}

            # Enumerate only real class directories, in sorted order. Sample order
            # and fallback labels then do not depend on listdir order or stray files.
            classes = [c for c in sorted(os.listdir(train_cache))
                       if os.path.isdir(os.path.join(train_cache, c))]
            for cls_idx, cls in enumerate(classes):
                cls_dir = os.path.join(train_cache, cls)
                label = class_to_idx.get(cls, cls_idx)
                for img_name in sorted(os.listdir(cls_dir)):
                    if img_name.endswith('.JPEG'):
                        self.image_paths.append(os.path.join(cls_dir, img_name))
                        self.labels.append(label)

        # Transforms are kept separate so each stage can be timed on its own.
        self.to_tensor = transforms.ToTensor()
        if augment:
            self.augment_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomCrop(224, padding=4),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
            ])
        else:
            self.augment_transform = None

        # Normalize on the CPU unless it is deferred to the GPU.
        self.normalize = None if normalize_on_gpu else transforms.Normalize(mean=self.MEAN, std=self.STD)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, augment and (optionally) normalize one sample.

        Returns:
            ``(image, label)``: an image tensor of shape [3, 224, 224] (this
            assumes the cached JPEGs are 224x224, as the preprocessing step
            writes them) and the integer class label.
        """
        global _perf_counter

        clock = time.perf_counter  # monotonic, high resolution; suited to interval timing
        start_time = clock()
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # 1. JPEG decode. The context manager releases the file handle immediately.
        jpeg_start = clock()
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        jpeg_time = clock() - jpeg_start

        # 2. ToTensor
        to_tensor_start = clock()
        image = self.to_tensor(image)
        to_tensor_time = clock() - to_tensor_start

        # 3. Augmentation (training only)
        augment_time = 0.0
        if self.augment_transform is not None:
            augment_start = clock()
            image = self.augment_transform(image)
            augment_time = clock() - augment_start

        # 4. Normalize, when it runs on the CPU
        if self.normalize is not None:
            image = self.normalize(image)

        total_time = clock() - start_time

        # Log only every 100th sample to keep the log small.
        _perf_counter += 1
        if _perf_counter % 100 == 0:
            log = _get_perf_log()
            log.write(f"{idx},{jpeg_time*1000:.1f},{to_tensor_time*1000:.1f},{augment_time*1000:.1f},{total_time*1000:.1f}\n")
            log.flush()

        return image, label

    def get_normalize_params(self):
        """Return ``(mean, std)`` as [3, 1, 1] tensors, for normalizing on the GPU."""
        return torch.tensor(self.MEAN).view(3, 1, 1), torch.tensor(self.STD).view(3, 1, 1)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class RealtimeValDataset(Dataset):
    """Validation dataset read from the flat cache layout ``cache_dir/val/val_<i>.JPEG``.

    ``val_labels[i]`` in ``manifest.json`` is the label of ``val_<i>.JPEG``.
    A cached file that is missing is skipped together with its label, so every
    returned image is paired with its own label. No augmentation is applied.
    Normalize runs here unless ``normalize_on_gpu`` is set.
    """

    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, cache_dir, normalize_on_gpu=False):
        """
        Args:
            cache_dir: cache root containing ``val/`` and ``manifest.json``.
            normalize_on_gpu: skip Normalize here so the caller can apply it on
                the GPU.
        """
        self.cache_dir = cache_dir
        self.val_cache = os.path.join(cache_dir, 'val')
        self.normalize_on_gpu = normalize_on_gpu

        manifest_path = os.path.join(cache_dir, 'manifest.json')
        if os.path.exists(manifest_path):
            with open(manifest_path, 'r') as f:
                all_labels = json.load(f).get('val_labels', [])
        else:
            all_labels = []

        # Keep each (path, label) pair together. Filtering only the paths would
        # shift every label after the first missing file onto the wrong image.
        self.image_paths = []
        self.labels = []
        if os.path.exists(self.val_cache):
            for i, label in enumerate(all_labels):
                img_path = os.path.join(self.val_cache, f'val_{i}.JPEG')
                if os.path.exists(img_path):
                    self.image_paths.append(img_path)
                    self.labels.append(label)

        # Evaluation transform: ToTensor, plus Normalize when it runs on the CPU.
        steps = [transforms.ToTensor()]
        if not normalize_on_gpu:
            steps.append(transforms.Normalize(mean=self.MEAN, std=self.STD))
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for the idx-th cached validation image."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # The context manager releases the file handle once the image is decoded.
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        image = self.transform(image)

        return image, label

    def get_normalize_params(self):
        """Return ``(mean, std)`` as [3, 1, 1] tensors, for normalizing on the GPU."""
        return torch.tensor(self.MEAN).view(3, 1, 1), torch.tensor(self.STD).view(3, 1, 1)
|
|
@@ -1,6 +1,8 @@
|
|
|
1
|
-
#
|
|
1
|
+
# LINEAR 模块
|
|
2
|
+
from .lasso_regression import LassoRegression
|
|
2
3
|
from .logistic_regression import LogisticRegression
|
|
4
|
+
from .naive_bayes import MultinomialNaiveBayes
|
|
5
|
+
from .naive_bayes import GaussianNaiveBayes
|
|
3
6
|
from .ridge_regression import RidgeRegression
|
|
4
|
-
from .lasso_regression import LassoRegression
|
|
5
7
|
|
|
6
|
-
__all__ = ['LogisticRegression', '
|
|
8
|
+
__all__ = ['LassoRegression', 'LogisticRegression', 'MultinomialNaiveBayes', 'GaussianNaiveBayes', 'RidgeRegression']
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
#
|
|
1
|
+
# TREE 模块
|
|
2
|
+
from .ada_boost import AdaBoost
|
|
2
3
|
from .decision_tree_classifier import DecisionTreeClassifier
|
|
3
4
|
from .random_forest_classifier import RandomForestClassifier
|
|
4
|
-
from .ada_boost import AdaBoost
|
|
5
5
|
|
|
6
|
-
__all__ = ['
|
|
6
|
+
__all__ = ['AdaBoost', 'DecisionTreeClassifier', 'RandomForestClassifier']
|
package/src/commands/server.js
CHANGED
|
@@ -326,6 +326,44 @@ function findServerPath() {
|
|
|
326
326
|
return null
|
|
327
327
|
}
|
|
328
328
|
|
|
329
|
+
/**
 * Locate a Python helper script that ships alongside the local server.
 *
 * Two candidates are checked, in order:
 *   1. the development checkout:
 *      packages/cli/src/commands -> ../../../local-server/src/<fileName>
 *   2. the copy bundled in the npm package by the build step:
 *      packages/cli/src/commands -> ../server/<fileName>
 *
 * @param {string} fileName file name inside the server source directory
 * @returns {string|null} absolute path of the first candidate that exists, or null
 */
function findServerScriptPath(fileName) {
  const candidates = [
    path.resolve(__dirname, '../../../local-server/src', fileName),
    path.resolve(__dirname, '../server', fileName)
  ]
  return candidates.find((candidate) => fs.existsSync(candidate)) || null
}

/**
 * Locate kernel_runner.py (mounted into the container in --dev mode).
 *
 * @returns {string|null} absolute path, or null when not found
 */
function findKernelRunnerPath() {
  return findServerScriptPath('kernel_runner.py')
}

/**
 * Locate dmla_progress.py (mounted into the container in --dev mode).
 *
 * @returns {string|null} absolute path, or null when not found
 */
function findProgressReporterPath() {
  return findServerScriptPath('dmla_progress.py')
}
|
|
366
|
+
|
|
329
367
|
/**
|
|
330
368
|
* 查找共享模块目录
|
|
331
369
|
* --dev 模式下需要挂载此目录
|
|
@@ -427,10 +465,23 @@ export async function startServerSync(port, useGpu = false, dev = false) {
|
|
|
427
465
|
|
|
428
466
|
// 查找共享模块路径(--dev 模式需要)
|
|
429
467
|
const sharedModulesPath = dev ? findSharedModulesPath() : null
|
|
468
|
+
// 查找 kernel_runner.py 路径(--dev 模式需要)
|
|
469
|
+
const kernelRunnerPath = dev ? findKernelRunnerPath() : null
|
|
470
|
+
// 查找 dmla_progress.py 路径(--dev 模式需要)
|
|
471
|
+
const progressReporterPath = dev ? findProgressReporterPath() : null
|
|
472
|
+
|
|
430
473
|
if (dev && !sharedModulesPath) {
|
|
431
474
|
console.log(chalk.yellow('⚠️ --dev 模式需要共享模块目录'))
|
|
432
475
|
console.log(chalk.gray(' 未找到 shared_modules,将仅使用镜像内置模块'))
|
|
433
476
|
}
|
|
477
|
+
if (dev && !kernelRunnerPath) {
|
|
478
|
+
console.log(chalk.yellow('⚠️ --dev 模式需要 kernel_runner.py'))
|
|
479
|
+
console.log(chalk.gray(' 未找到 kernel_runner.py,将仅使用镜像内置版本'))
|
|
480
|
+
}
|
|
481
|
+
if (dev && !progressReporterPath) {
|
|
482
|
+
console.log(chalk.yellow('⚠️ --dev 模式需要 dmla_progress.py'))
|
|
483
|
+
console.log(chalk.gray(' 未找到 dmla_progress.py,将仅使用镜像内置版本'))
|
|
484
|
+
}
|
|
434
485
|
|
|
435
486
|
console.log(chalk.gray(` 镜像类型: ${imageResolution.message}`))
|
|
436
487
|
console.log(chalk.gray(' 同步模式启动...'))
|
|
@@ -438,6 +489,12 @@ export async function startServerSync(port, useGpu = false, dev = false) {
|
|
|
438
489
|
if (dev && sharedModulesPath) {
|
|
439
490
|
console.log(chalk.gray(` 共享模块: ${sharedModulesPath}`))
|
|
440
491
|
}
|
|
492
|
+
if (dev && kernelRunnerPath) {
|
|
493
|
+
console.log(chalk.gray(` 执行器: ${kernelRunnerPath}`))
|
|
494
|
+
}
|
|
495
|
+
if (dev && progressReporterPath) {
|
|
496
|
+
console.log(chalk.gray(` 进度报告: ${progressReporterPath}`))
|
|
497
|
+
}
|
|
441
498
|
console.log()
|
|
442
499
|
|
|
443
500
|
// 设置环境变量
|
|
@@ -452,6 +509,12 @@ export async function startServerSync(port, useGpu = false, dev = false) {
|
|
|
452
509
|
if (sharedModulesPath) {
|
|
453
510
|
process.env.SHARED_MODULES_PATH = sharedModulesPath
|
|
454
511
|
}
|
|
512
|
+
if (kernelRunnerPath) {
|
|
513
|
+
process.env.KERNEL_RUNNER_PATH = kernelRunnerPath
|
|
514
|
+
}
|
|
515
|
+
if (progressReporterPath) {
|
|
516
|
+
process.env.PROGRESS_REPORTER_PATH = progressReporterPath
|
|
517
|
+
}
|
|
455
518
|
}
|
|
456
519
|
|
|
457
520
|
// 动态 import 服务器模块并直接运行
|
|
@@ -542,9 +605,20 @@ export async function startServer(port, useGpu = false, dev = false) {
|
|
|
542
605
|
|
|
543
606
|
// 查找共享模块路径(--dev 模式需要)
|
|
544
607
|
const sharedModulesPath = dev ? findSharedModulesPath() : null
|
|
608
|
+
// 查找 kernel_runner.py 路径(--dev 模式需要)
|
|
609
|
+
const kernelRunnerPath = dev ? findKernelRunnerPath() : null
|
|
610
|
+
// 查找 dmla_progress.py 路径(--dev 模式需要)
|
|
611
|
+
const progressReporterPath = dev ? findProgressReporterPath() : null
|
|
612
|
+
|
|
545
613
|
if (dev && sharedModulesPath) {
|
|
546
614
|
console.log(chalk.gray(` 共享模块: ${sharedModulesPath}`))
|
|
547
615
|
}
|
|
616
|
+
if (dev && kernelRunnerPath) {
|
|
617
|
+
console.log(chalk.gray(` 执行器: ${kernelRunnerPath}`))
|
|
618
|
+
}
|
|
619
|
+
if (dev && progressReporterPath) {
|
|
620
|
+
console.log(chalk.gray(` 进度报告: ${progressReporterPath}`))
|
|
621
|
+
}
|
|
548
622
|
|
|
549
623
|
// 日志文件路径
|
|
550
624
|
const logDir = path.resolve(__dirname, '../../logs')
|
|
@@ -574,6 +648,12 @@ export async function startServer(port, useGpu = false, dev = false) {
|
|
|
574
648
|
if (sharedModulesPath) {
|
|
575
649
|
env.SHARED_MODULES_PATH = sharedModulesPath
|
|
576
650
|
}
|
|
651
|
+
if (kernelRunnerPath) {
|
|
652
|
+
env.KERNEL_RUNNER_PATH = kernelRunnerPath
|
|
653
|
+
}
|
|
654
|
+
if (progressReporterPath) {
|
|
655
|
+
env.PROGRESS_REPORTER_PATH = progressReporterPath
|
|
656
|
+
}
|
|
577
657
|
}
|
|
578
658
|
|
|
579
659
|
// 写入启动日志
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* DMLA CLI update 命令
|
|
3
|
+
* 通过 npm 更新程序
|
|
4
|
+
*/
|
|
5
|
+
import chalk from 'chalk'
|
|
6
|
+
import { execSync } from 'child_process'
|
|
7
|
+
|
|
8
|
+
/** Package whose installed version is reported before and after the update. */
const CLI_PACKAGE = '@icyfenix-dmla/cli'

/**
 * Read the globally installed version of the CLI package.
 *
 * Uses `npm list --json` instead of piping through `grep` and `2>/dev/null`,
 * which exist only in POSIX shells. Stderr is discarded through `stdio`.
 * If `npm list` exits non-zero (for example, when the package is not
 * installed globally because the CLI runs from a checkout or via npx), its
 * JSON report is still read from stdout. A missing entry or unparsable output
 * yields '未知' instead of aborting the update.
 *
 * @returns {string} installed version, or '未知' when it cannot be determined
 */
function getInstalledVersion() {
  let stdout
  try {
    stdout = execSync(`npm list -g ${CLI_PACKAGE} --depth=0 --json`, {
      encoding: 'utf-8',
      stdio: ['ignore', 'pipe', 'ignore']
    })
  } catch (error) {
    stdout = error.stdout
  }
  try {
    const report = JSON.parse(stdout || '{}')
    const deps = (report && report.dependencies) || {}
    const entry = deps[CLI_PACKAGE]
    return (entry && entry.version) || '未知'
  } catch (parseError) {
    return '未知'
  }
}

/**
 * Run the `update` command.
 *
 * Updates the CLI and installer packages globally with `npm update -g`, then
 * reports the version change. If the npm update fails, prints a manual
 * command and exits the process with code 1.
 */
export async function runUpdate() {
  console.log(chalk.blue('更新 DMLA...'))
  console.log()

  try {
    // Version before the update
    const currentVersion = getInstalledVersion()

    console.log(chalk.gray(`当前版本: ${currentVersion}`))
    console.log()

    // Run the npm update
    console.log(chalk.cyan('正在从 npm 更新...'))
    const output = execSync('npm update -g @icyfenix-dmla/cli @icyfenix-dmla/install', {
      encoding: 'utf-8',
      stdio: 'pipe'
    })

    console.log(output)

    // Version after the update
    const newVersion = getInstalledVersion()

    if (newVersion !== currentVersion) {
      console.log()
      console.log(chalk.green(`✓ 更新成功!`))
      console.log(chalk.gray(` ${currentVersion} → ${newVersion}`))
    } else {
      console.log()
      console.log(chalk.yellow('已是最新版本'))
    }

    // Remind the user that the Docker image is updated separately
    console.log()
    console.log(chalk.gray('提示: 如需更新 Docker 镜像,请运行: dmla install'))

  } catch (error) {
    console.error(chalk.red('更新失败:'))
    console.error(chalk.red(error.message))
    console.log()
    console.log(chalk.yellow('建议手动执行: npm update -g @icyfenix-dmla/cli @icyfenix-dmla/install'))
    process.exit(1)
  }
}
|