nodev 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nodev/utils.py ADDED
@@ -0,0 +1,172 @@
1
+
2
+ import os
3
+ import subprocess
4
+ def _dot_func(f):
5
+ txt = ''
6
+ func_name = f.__class__.__name__
7
+
8
+ # 函数节点(蓝色)
9
+ txt += '{} [label="{}", color=lightblue, style=filled]\n'.format(
10
+ id(f), func_name
11
+ )
12
+
13
+ # 输入 -> 函数
14
+ for x in f.inputs:
15
+ txt += '{} -> {}\n'.format(id(x), id(f))
16
+
17
+ # 函数 -> 输出
18
+ for y in f.outputs:
19
+ txt += '{} -> {}\n'.format(id(f), id(y()))
20
+
21
+ return txt
22
+
23
+ def _dot_var(v,verbose=False):
24
+ dot_var = '{} [label="{}", color=orange, style=filled]\n'
25
+ name=''if v.name is None else v.name
26
+ if verbose and v.data is not None:
27
+ if v.name is not None:
28
+ name+=':'
29
+ name+=str(v.shape)+' '+str(v.dtype)
30
+ return dot_var.format(id(v),name)
31
def max_backward_shape(x_shape, axis):
    """Return *x_shape* with every reduced axis replaced by 1.

    Args:
        x_shape: the input shape (any sequence supporting ``len`` and ``list``).
        axis: ``None`` to reduce over all axes, a single int, or a tuple of
            ints; negative indices count from the end as in Python indexing.

    Returns:
        A tuple with the same length as *x_shape*. This is the shape a
        reduction over *axis* would have with ``keepdims=True`` (presumably
        used to broadcast a max/min gradient back to the input -- confirm
        with callers).
    """
    if axis is None:
        return (1,) * len(x_shape)
    axes = axis if isinstance(axis, tuple) else (axis,)
    reduced = list(x_shape)
    for ax in axes:
        reduced[ax] = 1
    return tuple(reduced)
44
def get_dot_graph(output, verbose=True):
    """Build Graphviz DOT source for the computation graph ending at *output*.

    Starting from ``output.creator``, the graph is walked backwards: every
    function reached is emitted once via ``_dot_func``, and each of its
    ``inputs`` is emitted via ``_dot_var`` before that input's own
    ``creator`` is queued. A variable that feeds several functions may be
    emitted more than once, always with the same node id.

    A leaf variable (``creator is None``) yields a graph containing only its
    own node instead of raising.

    Args:
        output: the variable to start from; must expose ``creator``.
        verbose: forwarded to ``_dot_var``; when true, labels also show the
            shape and dtype of variables that hold data.

    Returns:
        The complete DOT source, wrapped in ``digraph g { ... }``.
    """
    parts = [_dot_var(output, verbose)]
    funcs = []
    seen_set = set()

    def add_func(f):
        # ``None`` means "no creator" (a leaf), so there is nothing to expand.
        if f is not None and f not in seen_set:
            funcs.append(f)
            seen_set.add(f)

    add_func(output.creator)
    while funcs:
        func = funcs.pop()
        parts.append(_dot_func(func))
        for x in func.inputs:
            parts.append(_dot_var(x, verbose))
            add_func(x.creator)

    return 'digraph g {\n' + ''.join(parts) + '}'
66
def reshape_sum_backward(gy, x_shape, axis, keepdims):
    """Reshape the upstream gradient of a ``sum`` so it broadcasts to *x_shape*.

    Args:
        gy: gradient of the sum's output; must provide ``.shape`` and
            ``.reshape``.
        x_shape: shape of the array that was summed.
        axis: the ``axis`` argument the sum used -- ``None``, an int, or a
            tuple of ints (negative values count from the end).
        keepdims: the ``keepdims`` argument the sum used.

    Returns:
        With ``axis=None``, *gy* reshaped to all-ones of rank ``len(x_shape)``.
        With ``keepdims=True``, *gy* unchanged (the same object). Otherwise,
        *gy* reshaped with a size-1 dimension re-inserted at each summed
        axis position.
    """
    ndim = len(x_shape)
    if axis is None:
        return gy.reshape((1,) * ndim)
    if keepdims:
        # Reduced axes are already present as size-1 dimensions.
        return gy
    axes = axis if isinstance(axis, tuple) else (axis,)
    # Normalize negative axes, then insert in ascending order so earlier
    # insertions do not shift the positions of later ones.
    positions = sorted(ax + ndim if ax < 0 else ax for ax in axes)
    new_shape = list(gy.shape)
    for pos in positions:
        new_shape.insert(pos, 1)
    return gy.reshape(new_shape)
85
def plot_dot_graph(output, verbose=True, to_file='graph.png'):
    """Render the computation graph of *output* to an image with Graphviz.

    The DOT source from ``get_dot_graph`` is written to
    ``~/.nodev/tmp_graph.dot``. The external ``dot`` program is then run on
    it, with the output format taken from the extension of *to_file*
    (e.g. ``graph.png`` -> ``-Tpng``). The equivalent command line is printed
    afterwards.

    Args:
        output: the variable whose graph is drawn.
        verbose: forwarded to ``get_dot_graph``.
        to_file: destination image path; its extension selects the format.

    Notes:
        If ``dot`` cannot be found, a message is printed to stderr and no
        exception is raised, matching the previous shell-based behaviour.
        A non-zero exit status from ``dot`` is not checked.
    """
    import sys

    dot_graph = get_dot_graph(output, verbose)

    tmp_dir = os.path.join(os.path.expanduser('~'), '.nodev')
    os.makedirs(tmp_dir, exist_ok=True)
    graph_path = os.path.join(tmp_dir, 'tmp_graph.dot')

    # Labels may contain non-ASCII variable names; write UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open(graph_path, 'w', encoding='utf-8') as f:
        f.write(dot_graph)

    extension = os.path.splitext(to_file)[1][1:]
    # An argument list with shell=False keeps paths containing spaces or
    # shell metacharacters intact and uninterpreted.
    args = ['dot', '-T{}'.format(extension), graph_path, '-o', to_file]
    try:
        subprocess.run(args)
    except FileNotFoundError:
        print('dot: command not found (is Graphviz installed?)', file=sys.stderr)
    print('dot -T{} {} -o {}'.format(extension, graph_path, to_file))
100
+
101
+
102
def sum_to(x, shape):
    """Sum array *x* over its broadcast dimensions so the result has *shape*.

    This undoes NumPy-style broadcasting. Leading axes that *x* has beyond
    ``len(shape)`` are summed and then squeezed away. Axes where *shape* has
    size 1 are summed with ``keepdims=True`` so they stay as size-1
    dimensions. (Supplementary helper; not covered in the main text.)
    """
    target_ndim = len(shape)
    extra = x.ndim - target_ndim

    # Leading axes that broadcasting prepended.
    leading = tuple(range(extra))
    # Axes that were stretched from size 1, offset past the leading ones.
    stretched = tuple(extra + i for i, size in enumerate(shape) if size == 1)

    summed = x.sum(leading + stretched, keepdims=True)
    # Drop the extra leading dimensions.
    return summed.squeeze(leading) if extra > 0 else summed
124
+
125
+
126
+ import os
127
+ import urllib.request
128
+
129
# Module-level download cache location: ~/.dezero.
# NOTE(review): get_file in this module sets its own local cache directory
# ("../data") and does not read this value; also the directory is named after
# DeZero while plot_dot_graph uses ~/.nodev -- confirm which is intended.
cache_dir = os.path.join(os.path.expanduser('~'), '.dezero')
130
+
131
def pair(x):
    """Normalize *x* to a 2-tuple.

    Args:
        x: an int, which is used for both entries, or a tuple of exactly
            two items.

    Returns:
        ``(x, x)`` for an int, or *x* itself for a 2-tuple.

    Raises:
        ValueError: if *x* is a tuple whose length is not 2, or is neither an
            int nor a tuple.
    """
    if isinstance(x, int):
        return (x, x)
    if isinstance(x, tuple):
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let wrong-length tuples through silently.
        if len(x) != 2:
            raise ValueError(
                'expected a tuple of length 2, got length {}'.format(len(x))
            )
        return x
    raise ValueError(
        'expected an int or a 2-tuple, got {}'.format(type(x).__name__)
    )
139
def get_file(url, file_name=None, cache_dir=None):
    """Download a file if it is not in the cache.

    Args:
        url: address to fetch with ``urllib.request.urlretrieve``.
        file_name: name to save the file under; defaults to the part of
            *url* after its last ``/``.
        cache_dir: directory to store downloads in. Defaults to
            ``"../data"``, which is resolved relative to the current working
            directory. Missing parent directories are created.

    Returns:
        The path of the cached file.

    Raises:
        Any exception raised by ``urlretrieve`` (e.g.
        ``urllib.error.URLError``), or ``KeyboardInterrupt``. In both cases
        the exception is re-raised after the partial download is removed.
    """
    if cache_dir is None:
        cache_dir = "../data"
    if file_name is None:
        file_name = url[url.rfind('/') + 1:]

    file_path = os.path.join(cache_dir, file_name)

    # Create the cache directory, including any missing parents.
    os.makedirs(cache_dir, exist_ok=True)

    # This function only creates file_path by renaming a finished download
    # (below), so an existing file is treated as a complete cached copy.
    if os.path.exists(file_path):
        return file_path

    print("加载: " + file_name)

    # Download to a temporary name and rename on success. An interrupted or
    # killed download therefore never leaves a truncated file at file_path
    # that a later call would mistake for a cached copy.
    tmp_path = file_path + '.part'
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, file_path)
    except (Exception, KeyboardInterrupt):
        # Remove the incomplete download before propagating the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print("Done")

    return file_path
168
def get_conv_outsize(size, k, s, p):
    """Return the output length along one axis of a sliding-window operation.

    Args:
        size: input length along the axis.
        k: window (kernel) length.
        s: stride.
        p: padding added on each side.

    Returns:
        ``(size + 2*p - k) // s + 1`` (floor division).
    """
    padded = size + p * 2
    return (padded - k) // s + 1
170
+
171
+
172
+ #print(pair(1),pair((1,4)))
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.1
2
+ Name: nodev
3
+ Version: 1.0.0
4
+ Summary: 一个轻量级深度学习第三方库,2026年设计。
5
+ Author-email: Yang Yaqing <youngtoothh@foxmail.com>
6
+ License: MIT
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.8
11
+ Description-Content-Type: text/markdown
12
+ Requires-Dist: requests >=2.28.0
13
+ Requires-Dist: numpy
14
+
@@ -0,0 +1,17 @@
1
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ nodev/__init__.py,sha256=_6TL6fjt94ICY750euKt7Fy6dGC5vHy5mQoR1v4arEw,597
3
+ nodev/core.py,sha256=BZN6ewqtGj49cKrFLC_PYGAtfuQhzQiXFH6nkTXn4Xc,9215
4
+ nodev/cuda.py,sha256=wone0xXTZdznQ0f9cZ5iW9JPgMrAk6-ePsVaPvDOoks,630
5
+ nodev/dataloaders.py,sha256=mxs4FGJvBmrFAZ1TVsCxaLUxbmUlTaNwhn-BxRwJEPk,2000
6
+ nodev/dataset.py,sha256=AOUMciTaCDMxpJq-wTCN_D_pLu5uo6uvRDPiP0hgs80,5705
7
+ nodev/layers.py,sha256=KLbP9zkyzFQy-QMD5ntWSq0cuIgrcNFRRhQF8Bh-ZLg,5879
8
+ nodev/model.py,sha256=Y5mVSD1L3367vKMUgcfUQ2lN8qy7BXv2B5BilBSLZS4,4312
9
+ nodev/operations.py,sha256=0-wtkfeyGWCdfXl9rNFMF21aK8q_Dr6JEEbIwS8o3pk,10293
10
+ nodev/operations_conv.py,sha256=PwzrfX5fyHwwJLKKCmxJd9_3rWMjuU_2SXqyeNoU_e4,15855
11
+ nodev/optimizer.py,sha256=nLeZg0STkSl0q4N9mCAE5UJGyhJNVCwdIaPxyZoMvSw,1874
12
+ nodev/transforms.py,sha256=66Y1CslwoIfMVPuWt1fh1oDl4_lJEXAaWwiR1dPllJk,1305
13
+ nodev/utils.py,sha256=SrPbOdSCORFAEUY0_GVro_4MYW24uzHNO3H0g7wDe_o,4321
14
+ nodev-1.0.0.dist-info/METADATA,sha256=Isew386qYzm2p4aWrz77sSJ2K9GZrVp7AS3xxAMa5CM,460
15
+ nodev-1.0.0.dist-info/WHEEL,sha256=BNRMDyzLkkcmlv0J8ppDQkk2VED33SesJDynr9ED1gc,91
16
+ nodev-1.0.0.dist-info/top_level.txt,sha256=mV5FbarWYa_abJuO6m6ZhbgC3rw8Lig5WBvL3xtHa_A,15
17
+ nodev-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.3.4)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ __init__
2
+ nodev