jcclang 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jcclang-0.1.1/PKG-INFO +9 -0
- jcclang-0.1.1/jcclang/__init__.py +0 -0
- jcclang-0.1.1/jcclang/adapter/__init__.py +17 -0
- jcclang-0.1.1/jcclang/adapter/base_adapter.py +10 -0
- jcclang-0.1.1/jcclang/adapter/modelarts.py +56 -0
- jcclang-0.1.1/jcclang/adapter/octopus.py +49 -0
- jcclang-0.1.1/jcclang/adapter/openi.py +50 -0
- jcclang-0.1.1/jcclang/api/__init__.py +4 -0
- jcclang-0.1.1/jcclang/api/context.py +24 -0
- jcclang-0.1.1/jcclang/api/data_loader/llama_loader.py +99 -0
- jcclang-0.1.1/jcclang/api/data_loader/pandas_loader.py +52 -0
- jcclang-0.1.1/jcclang/api/data_loader/tensorflow_loader.py +53 -0
- jcclang-0.1.1/jcclang/api/data_loader/torch_loader.py +78 -0
- jcclang-0.1.1/jcclang/api/data_loader/transformers_hook/vfile_io2.py +57 -0
- jcclang-0.1.1/jcclang/api/data_loader/transformers_hook/virtual_file_io.py +141 -0
- jcclang-0.1.1/jcclang/api/data_loader/transformers_loader.py +105 -0
- jcclang-0.1.1/jcclang/api/data_prepare.py +35 -0
- jcclang-0.1.1/jcclang/api/decorators.py +107 -0
- jcclang-0.1.1/jcclang/api/jobhub.py +88 -0
- jcclang-0.1.1/jcclang/api/jobhub_api.py +290 -0
- jcclang-0.1.1/jcclang/api/jobhub_api2.py +44 -0
- jcclang-0.1.1/jcclang/core/__init__.py +0 -0
- jcclang-0.1.1/jcclang/core/const.py +40 -0
- jcclang-0.1.1/jcclang/core/context.py +103 -0
- jcclang-0.1.1/jcclang/core/io/__init__.py +12 -0
- jcclang-0.1.1/jcclang/core/io/abc.py +35 -0
- jcclang-0.1.1/jcclang/core/io/chunked.py +244 -0
- jcclang-0.1.1/jcclang/core/io/ext.py +96 -0
- jcclang-0.1.1/jcclang/core/io/utils.py +36 -0
- jcclang-0.1.1/jcclang/core/jobhub/__init__.py +3 -0
- jcclang-0.1.1/jcclang/core/jobhub/abc.py +175 -0
- jcclang-0.1.1/jcclang/core/jobhub/exceptions.py +6 -0
- jcclang-0.1.1/jcclang/core/jobhub/jobhub.py +271 -0
- jcclang-0.1.1/jcclang/core/jobhub/rpc.py +118 -0
- jcclang-0.1.1/jcclang/core/jobhub/types.py +189 -0
- jcclang-0.1.1/jcclang/core/jobhub/utils.py +89 -0
- jcclang-0.1.1/jcclang/core/jobhub/yamux/__init__.py +3 -0
- jcclang-0.1.1/jcclang/core/jobhub/yamux/exceptions.py +28 -0
- jcclang-0.1.1/jcclang/core/jobhub/yamux/yamux.py +721 -0
- jcclang-0.1.1/jcclang/core/logger.py +55 -0
- jcclang-0.1.1/jcclang/core/metadata.py +18 -0
- jcclang-0.1.1/jcclang/core/model.py +45 -0
- jcclang-0.1.1/jcclang/core/registry.py +22 -0
- jcclang-0.1.1/jcclang/core/schema.py +22 -0
- jcclang-0.1.1/jcclang/core/utils/__init__.py +0 -0
- jcclang-0.1.1/jcclang/core/utils/args_mapping.py +9 -0
- jcclang-0.1.1/jcclang/core/utils/presign.py +88 -0
- jcclang-0.1.1/jcclang/core/utils/presign_test.py +29 -0
- jcclang-0.1.1/jcclang/core/utils/rw_lock.py +70 -0
- jcclang-0.1.1/jcclang/examples/__init__.py +0 -0
- jcclang-0.1.1/jcclang/examples/data_loader/llama_streaming.py +37 -0
- jcclang-0.1.1/jcclang/examples/data_loader/pandas_demo.py +162 -0
- jcclang-0.1.1/jcclang/examples/data_loader/pt_emnist.py +95 -0
- jcclang-0.1.1/jcclang/examples/data_loader/pt_llama.py +35 -0
- jcclang-0.1.1/jcclang/examples/data_loader/pt_llama_hook.py +80 -0
- jcclang-0.1.1/jcclang/examples/data_loader/pt_mnist.py +165 -0
- jcclang-0.1.1/jcclang/examples/data_loader/pt_mnist2.py +150 -0
- jcclang-0.1.1/jcclang/examples/data_loader/pt_mnist_dirct.py +192 -0
- jcclang-0.1.1/jcclang/examples/data_loader/tf_demo.py +27 -0
- jcclang-0.1.1/jcclang/examples/data_prepare/__init__.py +0 -0
- jcclang-0.1.1/jcclang/examples/data_prepare/data_preprocess.py +197 -0
- jcclang-0.1.1/jcclang/examples/data_prepare/fl.py +269 -0
- jcclang-0.1.1/jcclang/examples/data_prepare/train_helloworld.py +53 -0
- jcclang-0.1.1/jcclang/examples/fl_2/data_partition.py +50 -0
- jcclang-0.1.1/jcclang/examples/fl_2/fl_agent.py +711 -0
- jcclang-0.1.1/jcclang/examples/fl_2/fl_agent2.py +711 -0
- jcclang-0.1.1/jcclang/examples/fl_2/fl_agent3.py +691 -0
- jcclang-0.1.1/jcclang/examples/fl_3/agent01.py +59 -0
- jcclang-0.1.1/jcclang/examples/fl_3/agent02.py +59 -0
- jcclang-0.1.1/jcclang/examples/fl_3/client.py +386 -0
- jcclang-0.1.1/jcclang/examples/fl_3/const.py +9 -0
- jcclang-0.1.1/jcclang/examples/fl_3/data_partition.py +84 -0
- jcclang-0.1.1/jcclang/examples/fl_3/leader.py +615 -0
- jcclang-0.1.1/jcclang/examples/fl_3/leader01.py +63 -0
- jcclang-0.1.1/jcclang/examples/fl_3/model.py +71 -0
- jcclang-0.1.1/jcclang/examples/fl_3/utils.py +468 -0
- jcclang-0.1.1/jcclang/examples/s_a_fl_jobhub/__init__.py +0 -0
- jcclang-0.1.1/jcclang/examples/s_a_fl_jobhub/entry.py +44 -0
- jcclang-0.1.1/jcclang/examples/s_a_fl_jobhub/fl_client.py +196 -0
- jcclang-0.1.1/jcclang/examples/s_a_fl_jobhub/fl_server.py +217 -0
- jcclang-0.1.1/jcclang/examples/s_a_fl_jobhub/jobhub_test2.py +58 -0
- jcclang-0.1.1/jcclang/examples/workflow/__init__.py +0 -0
- jcclang-0.1.1/jcclang/examples/workflow/workflow01.py +21 -0
- jcclang-0.1.1/jcclang/nodes/__init__.py +18 -0
- jcclang-0.1.1/jcclang/nodes/base_node.py +191 -0
- jcclang-0.1.1/jcclang/nodes/bind_node.py +24 -0
- jcclang-0.1.1/jcclang/nodes/data_return_node.py +23 -0
- jcclang-0.1.1/jcclang/nodes/end_node.py +21 -0
- jcclang-0.1.1/jcclang/nodes/start_node.py +21 -0
- jcclang-0.1.1/jcclang/nodes/train_node.py +33 -0
- jcclang-0.1.1/jcclang/tests/__init__.py +0 -0
- jcclang-0.1.1/jcclang/tests/decorator_test.py +11 -0
- jcclang-0.1.1/jcclang/tests/input_test.py +20 -0
- jcclang-0.1.1/jcclang/tests/io_test.py +15 -0
- jcclang-0.1.1/jcclang/tests/job_hub_rev_var_test.py +79 -0
- jcclang-0.1.1/jcclang/tests/job_hub_stream_test.py +54 -0
- jcclang-0.1.1/jcclang/tests/job_hub_var_sync_test.py +90 -0
- jcclang-0.1.1/jcclang/tests/job_hub_var_test.py +96 -0
- jcclang-0.1.1/jcclang/tests/node_test.py +52 -0
- jcclang-0.1.1/jcclang/tests/node_test2.py +131 -0
- jcclang-0.1.1/jcclang/tests/node_test3.py +104 -0
- jcclang-0.1.1/jcclang/tests/node_test4.py +136 -0
- jcclang-0.1.1/jcclang/tests/node_test5.py +123 -0
- jcclang-0.1.1/jcclang/tests/output_test.py +13 -0
- jcclang-0.1.1/jcclang/tests/remotefile/__init__.py +0 -0
- jcclang-0.1.1/jcclang/tests/remotefile/common_dataset.py +40 -0
- jcclang-0.1.1/jcclang/tests/remotefile/common_demo.py +18 -0
- jcclang-0.1.1/jcclang/tests/remotefile/driver.py +94 -0
- jcclang-0.1.1/jcclang/tests/remotefile/jcweaver_dataset.py +29 -0
- jcclang-0.1.1/jcclang/tests/remotefile/jcweaver_tf_dataset.py +47 -0
- jcclang-0.1.1/jcclang/tests/remotefile/local_test.py +6 -0
- jcclang-0.1.1/jcclang/tests/remotefile/pt_demo.py +22 -0
- jcclang-0.1.1/jcclang/tests/remotefile/tensorfflow_demo.py +19 -0
- jcclang-0.1.1/jcclang/tests/remotefile/tf_demo.py +27 -0
- jcclang-0.1.1/jcclang/tests/split_test.py +16 -0
- jcclang-0.1.1/jcclang/tests/virtualfile/pandas_dirct_demo.py +189 -0
- jcclang-0.1.1/jcclang/virtualfile/__init__.py +0 -0
- jcclang-0.1.1/jcclang/virtualfile/block_fetcher.py +78 -0
- jcclang-0.1.1/jcclang/virtualfile/block_fetcher_test.py +15 -0
- jcclang-0.1.1/jcclang/virtualfile/cache_mgr.py +427 -0
- jcclang-0.1.1/jcclang/virtualfile/driver/__init__.py +0 -0
- jcclang-0.1.1/jcclang/virtualfile/driver/base_driver.py +29 -0
- jcclang-0.1.1/jcclang/virtualfile/driver/jcs.py +89 -0
- jcclang-0.1.1/jcclang/virtualfile/driver/jcs_test.py +12 -0
- jcclang-0.1.1/jcclang/virtualfile/virtual_file.py +231 -0
- jcclang-0.1.1/jcclang.egg-info/PKG-INFO +9 -0
- jcclang-0.1.1/jcclang.egg-info/SOURCES.txt +129 -0
- jcclang-0.1.1/jcclang.egg-info/dependency_links.txt +1 -0
- jcclang-0.1.1/jcclang.egg-info/top_level.txt +1 -0
- jcclang-0.1.1/pyproject.toml +22 -0
- jcclang-0.1.1/setup.cfg +4 -0
jcclang-0.1.1/PKG-INFO
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: jcclang
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: A platform-adaptive task framework for cloud-edge scenarios
|
|
5
|
+
Author-email: jeshua <ren1366929814@gmail.com>
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: Operating System :: OS Independent
|
|
8
|
+
Requires-Python: >=3.7
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
File without changes
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .modelarts import ModelArtsAdapter
|
|
2
|
+
from .octopus import OctopusAdapter
|
|
3
|
+
from .openi import OpenIAdapter
|
|
4
|
+
from ..core.const import Platform
|
|
5
|
+
from ..core.logger import jcwLogger
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_adapter(platform: str):
    """Create the adapter that matches a platform name.

    The name is lower-cased and compared against the ``Platform`` constants
    in the order MODELARTS, OPENI, OCTOPUS.

    :param platform: platform name (compared case-insensitively).
    :return: a new adapter instance, or ``None`` (after logging an error)
        when the platform is not supported.
    """
    normalized = platform.lower()
    candidates = (
        (Platform.MODELARTS, ModelArtsAdapter),
        (Platform.OPENI, OpenIAdapter),
        (Platform.OCTOPUS, OctopusAdapter),
    )
    for known_name, adapter_cls in candidates:
        if normalized == known_name:
            return adapter_cls()

    jcwLogger.error(f"Unsupported platform: {normalized}")
    return None
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from jcclang.adapter.base_adapter import BaseAdapter
|
|
5
|
+
from jcclang.core.const import DataType
|
|
6
|
+
from jcclang.core.logger import jcwLogger
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ModelArtsAdapter(BaseAdapter):
    """Adapter that resolves input/output locations from command-line options.

    Options are registered lazily on a private ``argparse`` parser the first
    time they are needed and are parsed with ``parse_known_args`` so that
    unrelated command-line arguments are ignored.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(description='Model Training with input parameter')
        # Last path produced by output_prepare(); reported by after_task().
        self.output = ""

    def before_task(self, inputs, context: dict):
        """Hook run before the task; only logs."""
        jcwLogger.info("modelarts before task")

    def after_task(self, outputs, context: dict):
        """Hook run after the task; logs the last prepared output path."""
        # The value is interpolated into the message instead of being passed
        # as a stray positional argument with no matching placeholder.
        jcwLogger.info(f"modelarts after task, output {self.output}")

    def _ensure_argument(self, dest: str, default: str):
        """Register ``--<dest>`` on the parser unless it already exists."""
        # argparse rejects a second registration of the same option string,
        # so check the already-registered actions first.
        if not any(action.dest == dest for action in self.parser._actions):
            self.parser.add_argument(f'--{dest}', default=default, type=str,
                                     help=f'{dest} (default: %(default)s)')

    def input_prepare(self, data_type: str, file_path: str):
        """Resolve ``file_path`` against the directory given for ``data_type``.

        :param data_type: one of ``DataType.DATASET``, ``DataType.MODEL`` or
            ``DataType.CODE``.
        :param file_path: path relative to the selected input directory.
        :return: a ``pathlib.Path`` joining the option value and
            ``file_path``, or ``""`` (after logging an error) for an unknown
            data type.
        """
        if data_type == DataType.DATASET:
            dest, default = 'dataset_input', './data'
        elif data_type == DataType.MODEL:
            dest, default = 'model_input', './models'
        elif data_type == DataType.CODE:
            dest, default = 'code_input', './src'
        else:
            jcwLogger.error(f"Unknown data type for input: {data_type}")
            return ""

        self._ensure_argument(dest, default)
        args, _ = self.parser.parse_known_args()
        return Path(getattr(args, dest)) / file_path

    def output_prepare(self, data_type: str, file_path: str):
        """Resolve ``file_path`` against the ``--output`` directory.

        The result is also stored on ``self.output`` so that ``after_task``
        can report it.

        :param data_type: accepted for interface compatibility; not used.
        :param file_path: path relative to the output directory.
        :return: the joined path as a POSIX-style string.
        """
        self._ensure_argument('output', '/output')
        args, _ = self.parser.parse_known_args()
        self.output = (Path(args.output) / file_path).as_posix()
        return self.output
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from jcclang.adapter.base_adapter import BaseAdapter
|
|
4
|
+
from jcclang.core.const import DataType
|
|
5
|
+
from jcclang.core.logger import jcwLogger
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class OctopusAdapter(BaseAdapter):
    """Adapter that resolves input/output locations from environment variables."""

    def __init__(self):
        self.output = ""

    def before_task(self, inputs, context: dict):
        """Hook run before the task; only logs."""
        jcwLogger.info("execute before task")

    def after_task(self, outputs, context: dict):
        """Hook run after the task; only logs."""
        jcwLogger.info("execute after task")

    def input_prepare(self, data_type: str, file_path: str):
        """Resolve ``file_path`` against the directory configured for ``data_type``.

        The directory is read from the environment variable
        ``dataset_input``, ``model_input`` or ``code_input``.

        :param data_type: one of ``DataType.DATASET``, ``DataType.MODEL`` or
            ``DataType.CODE``.
        :param file_path: path relative to the selected input directory.
        :return: the joined path, or ``""`` (after logging an error) when
            the data type is unknown or the variable is unset/empty.
        """
        if data_type == DataType.DATASET:
            env_name = "dataset_input"
        elif data_type == DataType.MODEL:
            env_name = "model_input"
        elif data_type == DataType.CODE:
            env_name = "code_input"
        else:
            jcwLogger.error(f"Unknown data type for input: {data_type}")
            return ""

        base_dir = os.environ.get(env_name)
        if not base_dir:
            jcwLogger.error(f"{env_name} is not set")
            return ""
        return os.path.join(base_dir, file_path)

    def output_prepare(self, data_type: str, file_path: str):
        """Resolve ``file_path`` against the output directory, creating it if needed.

        The directory comes from the ``output`` environment variable and
        falls back to ``./output``.

        :param data_type: accepted for interface compatibility; not used.
        :param file_path: path relative to the output directory.
        :return: the joined output path.
        """
        data_path = os.environ.get("output", "./output")

        if not os.path.exists(data_path):
            jcwLogger.info(f"create output directory: {data_path}")
            # exist_ok avoids a FileExistsError when another process creates
            # the directory between the check above and this call.
            os.makedirs(data_path, exist_ok=True)

        return os.path.join(data_path, file_path)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from jcclang.adapter.base_adapter import BaseAdapter
|
|
5
|
+
from jcclang.core.const import DataType
|
|
6
|
+
from jcclang.core.logger import jcwLogger
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OpenIAdapter(BaseAdapter):
    """Adapter that resolves locations through the ``c2net`` context.

    ``c2net`` is imported inside ``__init__`` so that importing this module
    does not require the package to be installed.
    """

    def __init__(self):
        from c2net.context import prepare, upload_output
        self._prepare = prepare()
        self._upload_output = upload_output
        self.output = ""
        self.parser = argparse.ArgumentParser(description='Model Training with input parameter')

    def before_task(self, inputs, context: dict):
        """Hook run before the task; only logs."""
        jcwLogger.info("execute before task")

    def after_task(self, outputs, context: dict):
        """Hook run after the task; logs and calls ``upload_output``."""
        jcwLogger.info("execute after task")
        self._upload_output()

    def _option_value(self, dest: str, default: str = './data'):
        """Return the parsed value of ``--<dest>``, registering it if needed."""
        # Registering the same option twice raises argparse.ArgumentError,
        # so repeated input_prepare() calls must not re-add it.
        if not any(action.dest == dest for action in self.parser._actions):
            self.parser.add_argument(f'--{dest}', default=default, type=str,
                                     help=f'{dest} (default: %(default)s)')
        args, _ = self.parser.parse_known_args()
        return getattr(args, dest)

    def input_prepare(self, data_type: str, file_path: str):
        """Return the prepared location for the input named on the command line.

        The option value (``--dataset_input``, ``--model_input`` or
        ``--code_input``) has its extension stripped and is joined onto the
        matching directory of the ``c2net`` context.

        :param data_type: one of ``DataType.DATASET``, ``DataType.MODEL`` or
            ``DataType.CODE``.
        :param file_path: accepted for interface compatibility; this adapter
            does not use it.
        :return: the joined path, or ``""`` (after logging an error) for an
            unknown data type.
        """
        if data_type == DataType.DATASET:
            name_no_ext, _ = os.path.splitext(self._option_value('dataset_input'))
            return os.path.join(self._prepare.dataset_path, name_no_ext)

        if data_type == DataType.MODEL:
            name_no_ext, _ = os.path.splitext(self._option_value('model_input'))
            return os.path.join(self._prepare.pretrain_model_path, name_no_ext)

        if data_type == DataType.CODE:
            name_no_ext, _ = os.path.splitext(self._option_value('code_input'))
            return os.path.join(self._prepare.code_path, name_no_ext)

        jcwLogger.error(f"Unknown data type for input: {data_type}")
        return ""

    def output_prepare(self, data_type: str, file_path: str):
        """Join ``file_path`` onto the context's output path and remember it.

        :param data_type: accepted for interface compatibility; not used.
        :param file_path: path relative to the output directory.
        :return: the joined output path (also stored on ``self.output``).
        """
        self.output = os.path.join(self._prepare.output_path, file_path)
        return self.output
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import trio
|
|
2
|
+
from api.jobhub import JobHub
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Context:
    """Bundle of the per-job objects handed to job code.

    Stores a trio nursery, the job-set and job identifiers, and the JobHub
    handle, and exposes each through a method of the same name.
    """

    def __init__(
            self, nursery: trio.Nursery, job_set_id: str, job_id: str, job_hub: JobHub
    ):
        self._job_hub = job_hub
        self._job_id = job_id
        self._job_set_id = job_set_id
        self._nursery = nursery

    def nursery(self):
        """Return the nursery given at construction."""
        return self._nursery

    def job_set_id(self) -> str:
        """Return the job-set identifier given at construction."""
        return self._job_set_id

    def job_id(self) -> str:
        """Return the job identifier given at construction."""
        return self._job_id

    def job_hub(self) -> JobHub:
        """Return the JobHub handle given at construction."""
        return self._job_hub
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from safetensors.torch import safe_open
|
|
5
|
+
|
|
6
|
+
# ============================
|
|
7
|
+
# 1️⃣ 模型配置类
|
|
8
|
+
# ============================
|
|
9
|
+
|
|
10
|
+
class ModelConfig:
    """Plain configuration object describing the model structure.

    Each constructor argument is stored unchanged on an attribute of the
    same name.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, max_seq_len):
        # Vocabulary and embedding dimensions.
        self.vocab_size, self.hidden_size = vocab_size, hidden_size
        # Depth and attention-head count.
        self.num_layers, self.num_heads = num_layers, num_heads
        self.max_seq_len = max_seq_len
|
|
20
|
+
|
|
21
|
+
# ============================
|
|
22
|
+
# 2️⃣ 自定义模型类
|
|
23
|
+
# ============================
|
|
24
|
+
|
|
25
|
+
class GPTLayer(nn.Module):
    """One block: self-attention then a feed-forward network.

    Each sub-layer output is added to its input and the sum is passed
    through a LayerNorm.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        inner_size = 4 * hidden_size
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, inner_size),
            nn.GELU(),
            nn.Linear(inner_size, hidden_size),
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers to ``x``."""
        # MultiheadAttention returns (output, weights); only the output is used.
        attended = self.attn(x, x, x)[0]
        normed = self.ln1(x + attended)
        return self.ln2(normed + self.ffn(normed))
|
|
42
|
+
|
|
43
|
+
class GPTModel(nn.Module):
    """Token embedding, a stack of ``GPTLayer`` blocks, a final LayerNorm and
    a bias-free linear head projecting to the vocabulary size."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.embed_tokens = nn.Embedding(config.vocab_size, width)
        blocks = []
        for _ in range(config.num_layers):
            blocks.append(GPTLayer(width, config.num_heads))
        self.layers = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(width)
        self.head = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, input_ids):
        """Return the head's output (logits) for ``input_ids``."""
        hidden = self.embed_tokens(input_ids)
        for block in self.layers:
            hidden = block(hidden)
        return self.head(self.ln_f(hidden))
|
|
60
|
+
|
|
61
|
+
# ============================
|
|
62
|
+
# 3️⃣ 流式 Safetensors 权重加载器
|
|
63
|
+
# ============================
|
|
64
|
+
|
|
65
|
+
class StreamingSafeTensorsLoader:
    """Load safetensors weights on demand onto a chosen device.

    The file is opened once in ``__init__`` and individual tensors are
    fetched by name when requested.
    """

    def __init__(self, filename: str, device="cpu"):
        """
        :param filename: path of the safetensors file to open.
        :param device: default device tensors are moved to.
        """
        self.filename = filename
        self.device = device
        # NOTE(review): the handle is kept open for the loader's lifetime and
        # is never explicitly released here.
        self._file = safe_open(filename, framework="pt")
        self.keys = self._file.keys()

    def load_tensor(self, name: str, device=None):
        """Load a single tensor by name and move it to a device.

        :param name: tensor name inside the file.
        :param device: target device; falls back to the loader's default.
        :return: the tensor on the selected device.
        """
        dev = device or self.device
        return self._file.get_tensor(name).to(dev)

    def load_state_dict(self, model: nn.Module, device=None):
        """Copy every matching tensor from the file into ``model``.

        Only names reported by ``model.named_parameters()`` are considered.
        A parameter missing from the file is reported with a printed warning
        and left untouched; the weights are applied with ``strict=False``.

        :param model: module that receives the weights.
        :param device: target device; falls back to the loader's default.
        """
        dev = device or self.device
        # Build the lookup once: membership on the key list would be a
        # linear scan repeated for every parameter.
        available = set(self.keys)
        state_dict = {}
        for name, _param in model.named_parameters():
            if name in available:
                state_dict[name] = self.load_tensor(name, dev)
            else:
                print(f"[Warning] weight {name} not found in safetensors")
        model.load_state_dict(state_dict, strict=False)
|
|
95
|
+
|
|
96
|
+
# ============================
|
|
97
|
+
# 4️⃣ 使用示例
|
|
98
|
+
# ============================
|
|
99
|
+
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import io
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
from jcclang.core.model import VirtualFileParams, Source
|
|
6
|
+
from jcclang.virtualfile.virtual_file import VirtualFile
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Pandas:
    """pandas I/O layer that reads its input through a ``VirtualFile``."""

    def __init__(self, source: Source, virtual_file_params: VirtualFileParams = None):
        """
        :param source: the source handed to ``VirtualFile``.
        :param virtual_file_params: parameters for ``VirtualFile``; a default
            ``VirtualFileParams()`` is created when omitted.
        """
        if virtual_file_params is None:
            virtual_file_params = VirtualFileParams()
        self.virtual_file_params = virtual_file_params
        self.source = source

    def _open_virtual(self):
        """Return a new ``VirtualFile`` for the configured source."""
        return VirtualFile(source=self.source, params=self.virtual_file_params)

    def _read_bytes(self) -> bytes:
        """Read the whole virtual file and return its content."""
        vf = self._open_virtual()
        try:
            return vf.read()
        finally:
            # Close the file even when read() raises.
            vf.close()

    def _buffer(self) -> io.BytesIO:
        """Return the file content wrapped in an in-memory binary buffer."""
        return io.BytesIO(self._read_bytes())

    # ==============================
    # Pandas-like API
    # ==============================

    def read_csv(self, **kwargs) -> pd.DataFrame:
        """Parse the source with ``pandas.read_csv``; kwargs are forwarded."""
        return pd.read_csv(self._buffer(), **kwargs)

    def read_json(self, **kwargs) -> pd.DataFrame:
        """Parse the source with ``pandas.read_json``; kwargs are forwarded."""
        return pd.read_json(self._buffer(), **kwargs)

    def read_parquet(self, **kwargs) -> pd.DataFrame:
        """Parse the source with ``pandas.read_parquet``; kwargs are forwarded."""
        return pd.read_parquet(self._buffer(), **kwargs)

    def read_excel(self, **kwargs) -> pd.DataFrame:
        """Parse the source with ``pandas.read_excel``; kwargs are forwarded."""
        return pd.read_excel(self._buffer(), **kwargs)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import tensorflow as tf
|
|
2
|
+
from jcclang.core.model import Sources, VirtualFileParams
|
|
3
|
+
from jcclang.virtualfile.virtual_file import VirtualFile
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Dataset:
    """Feeds ``VirtualFile``-backed samples into ``tf.data``."""

    def __init__(self, sources: Sources, decoder=None,
                 virtual_file_params: VirtualFileParams = None):
        """
        :param sources: ``Sources`` object whose ``items`` are read one by one.
        :param decoder: callable applied to the raw data of each item; the
            identity function is used when omitted.
        :param virtual_file_params: parameters for ``VirtualFile``; a default
            ``VirtualFileParams()`` is created when omitted.
        """
        if virtual_file_params is None:
            virtual_file_params = VirtualFileParams()
        self.virtual_file_params = virtual_file_params
        self.sources = sources
        self.decode_fn = decoder or (lambda x: x)  # no decoding by default

    def generator(self):
        """Yield ``(sample, label)`` for every item in ``sources``."""
        for info in self.sources.items:
            # Read the data; close the file even when read() raises.
            vf = VirtualFile(info, params=self.virtual_file_params)
            try:
                raw = vf.read()
            finally:
                vf.close()

            # Decode the data.
            sample = self.decode_fn(raw)

            yield sample, info.label

    def to_tf_dataset(self, output_types=tf.float32, output_shapes=None, batch_size=32, shuffle=True):
        """Build a ``tf.data.Dataset`` from :meth:`generator`.

        Labels are declared as scalar ``tf.int32`` values.

        :param output_types: element type of the samples, e.g. ``tf.float32``.
        :param output_shapes: sample shape, e.g. ``(28, 28)`` or ``(None,)``.
        :param batch_size: batch size.
        :param shuffle: when true, shuffle with a buffer as large as the
            number of source items.
        :return: the batched, prefetching dataset.
        """
        ds = tf.data.Dataset.from_generator(
            self.generator,
            output_types=(output_types, tf.int32),
            output_shapes=(output_shapes, ())
        )
        if shuffle:
            ds = ds.shuffle(buffer_size=len(self.sources.items))
        ds = ds.batch(batch_size)
        ds = ds.prefetch(tf.data.AUTOTUNE)
        return ds
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from io import BytesIO
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch.utils.data import Dataset as Ds
|
|
5
|
+
|
|
6
|
+
from jcclang.core.model import VirtualFileParams, Sources
|
|
7
|
+
from jcclang.virtualfile.virtual_file import VirtualFile
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Dataset(Ds):
    """Map-style torch dataset whose samples are read through ``VirtualFile``."""

    def __init__(self, sources: Sources, transform=None, decoder=None,
                 virtual_file_params: VirtualFileParams = None):
        """
        :param sources: ``Sources`` object; one sample per entry in ``items``.
        :param transform: optional callable applied to the decoded sample.
        :param decoder: callable applied to the raw data; the identity
            function is used when omitted.
        :param virtual_file_params: parameters for ``VirtualFile``; a default
            ``VirtualFileParams()`` is created when omitted.
        """
        if virtual_file_params is None:
            virtual_file_params = VirtualFileParams()
        self.virtual_file_params = virtual_file_params
        self.sources = sources
        self.transform = transform
        self.decoder = decoder or (lambda x: x)

    def __len__(self):
        """Return the number of source items."""
        return len(self.sources.items)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for the item at ``idx``."""
        info = self.sources.items[idx]

        # Read the data; close the file even when read() raises.
        vf = VirtualFile(info, params=self.virtual_file_params)
        try:
            raw = vf.read()
        finally:
            vf.close()

        # Decode the data.
        sample = self.decoder(raw)

        if self.transform:
            sample = self.transform(sample)

        return sample, info.label
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class JCWeaverModel:
    """Load PyTorch model weights through a JCWeaver ``VirtualFile``."""

    def __init__(self, sources: Sources, virtual_file_params: VirtualFileParams = None):
        """
        :param sources: ``Sources`` object describing the model file.
        :param virtual_file_params: parameters for ``VirtualFile``; a default
            ``VirtualFileParams()`` is created when omitted.
        """
        if virtual_file_params is None:
            virtual_file_params = VirtualFileParams()
        self.virtual_file_params = virtual_file_params
        self.sources = sources

    def _read_raw(self):
        """Read the data of the first source item.

        NOTE(review): the previous code referenced an undefined name here;
        reading the first entry of ``sources.items`` is assumed to be the
        intent -- confirm against callers.

        :raises ValueError: when ``sources`` has no items.
        """
        items = self.sources.items
        if not items:
            raise ValueError("JCWeaverModel requires at least one source item")
        vf = VirtualFile(items[0], params=self.virtual_file_params)
        try:
            return vf.read()
        finally:
            # Close the file even when read() raises.
            vf.close()

    def load_state_dict(self, model_class, map_location="cpu"):
        """Read a state dict and load it into a model.

        :param model_class: either a class, which is instantiated with no
            arguments, or an existing model instance.
        :param map_location: forwarded to ``torch.load``.
        :return: the model with the state dict loaded.
        """
        # NOTE: torch.load unpickles data; only use it on trusted sources.
        state_dict = torch.load(BytesIO(self._read_raw()), map_location=map_location)
        if isinstance(model_class, type):
            model = model_class()
        else:
            model = model_class
        model.load_state_dict(state_dict)
        return model

    def load_torch_model(self, map_location="cpu"):
        """Read and return a complete serialized PyTorch model object.

        :param map_location: forwarded to ``torch.load``.
        """
        # NOTE: torch.load unpickles data; only use it on trusted sources.
        return torch.load(BytesIO(self._read_raw()), map_location=map_location)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import io
|
|
2
|
+
|
|
3
|
+
from jcclang.core.model import Sources, VirtualFileParams
|
|
4
|
+
from jcclang.virtualfile.virtual_file import VirtualFile
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class VirtualFileIO(io.RawIOBase):
    """File-like wrapper around a ``VirtualFile`` supporting seek/tell/read.

    All operations are delegated to the wrapped object.
    """

    def __init__(self, vf: "VirtualFile"):
        self.vf = vf
        self.closed_flag = False

    def readable(self):
        """Report that the stream can be read."""
        # RawIOBase answers False by default, which makes buffered wrappers
        # and capability checks reject the stream.
        return True

    def seekable(self):
        """Report that the stream supports seek()/tell()."""
        return True

    def read(self, size=-1):
        """Read ``size`` units from the wrapped file (``-1`` is passed through)."""
        return self.vf.read(size)

    def readinto(self, b):
        """Fill the writable buffer ``b`` and return the number of bytes stored."""
        view = memoryview(b).cast("B")
        data = self.vf.read(len(view))
        count = len(data)
        view[:count] = data
        return count

    def seek(self, offset, whence=0):
        """Move the wrapped file's position and return the new position."""
        self.vf.seek(offset, whence)
        return self.vf.tell()

    def tell(self):
        """Return the wrapped file's current position."""
        return self.vf.tell()

    def close(self):
        """Close the wrapped file once; later calls do nothing."""
        if not self.closed_flag:
            self.vf.close()
            self.closed_flag = True

    @property
    def closed(self):
        """Whether :meth:`close` has been called."""
        return self.closed_flag
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def virtual_cached_file(pretrained_model_name_or_path, filename, **kwargs):
    """Replacement for ``transformers.utils.cached_file`` backed by VirtualFile.

    ``filename`` is matched against the full ``path`` of every source first
    and, failing that, against the last ``/``-separated component of each
    path. When several sources share a key, the last one wins.

    :param pretrained_model_name_or_path: accepted for signature
        compatibility; not used.
    :param filename: name of the file to resolve.
    :param kwargs: must contain ``sources`` (a ``Sources`` object) and
        ``vparams`` (a ``VirtualFileParams`` object); both are removed
        from ``kwargs``.
    :return: a ``VirtualFileIO`` wrapping the matching source.
    :raises FileNotFoundError: when no source matches ``filename``.
    """
    source_set = kwargs.pop("sources")
    file_params = kwargs.pop("vparams")

    # Exact match on the full path.
    by_path = {}
    for item in source_set.items:
        by_path[item.path] = item

    if filename in by_path:
        chosen = by_path[filename]
    else:
        # Fall back to matching only the file name (path stripped).
        by_name = {}
        for item in source_set.items:
            by_name[item.path.split("/")[-1]] = item
        chosen = by_name.get(filename)
        if chosen is None:
            raise FileNotFoundError(f"{filename} not found in virtual sources")

    # Hand back a file-like object that supports seek/read.
    return VirtualFileIO(VirtualFile(chosen, params=file_params))
|