congrads 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/descriptor.py
CHANGED
|
@@ -122,7 +122,7 @@ class Descriptor:
|
|
|
122
122
|
"""
|
|
123
123
|
return tag in self._tag_to_key and tag in self._tag_to_index
|
|
124
124
|
|
|
125
|
-
def location(self, tag: str) -> tuple[str, int]:
|
|
125
|
+
def location(self, tag: str) -> tuple[str, int | None]:
|
|
126
126
|
"""Get the key and index for a given tag.
|
|
127
127
|
|
|
128
128
|
Looks up the mapping for a registered tag and returns the associated
|
|
@@ -132,16 +132,16 @@ class Descriptor:
|
|
|
132
132
|
tag (str): The tag identifier. Must be registered.
|
|
133
133
|
|
|
134
134
|
Returns:
|
|
135
|
-
tuple ((str, int)): A tuple containing:
|
|
135
|
+
tuple ((str, int | None)): A tuple containing:
|
|
136
136
|
- The key in the data dictionary which holds the data (str).
|
|
137
|
-
- The tensor index where the data is present (int).
|
|
137
|
+
- The tensor index where the data is present or None (int | None).
|
|
138
138
|
|
|
139
139
|
Raises:
|
|
140
140
|
ValueError: If the tag is not registered in the descriptor.
|
|
141
141
|
"""
|
|
142
142
|
key = self._tag_to_key.get(tag)
|
|
143
143
|
index = self._tag_to_index.get(tag)
|
|
144
|
-
if key is None
|
|
144
|
+
if key is None:
|
|
145
145
|
raise ValueError(f"Tag '{tag}' is not registered in descriptor.")
|
|
146
146
|
return key, index
|
|
147
147
|
|
|
@@ -150,6 +150,7 @@ class Descriptor:
|
|
|
150
150
|
|
|
151
151
|
Retrieves the key and index associated with a tag and selects
|
|
152
152
|
the corresponding slice from the given prediction tensor.
|
|
153
|
+
Returns the full tensor if no index was specified when registering the tag.
|
|
153
154
|
|
|
154
155
|
Args:
|
|
155
156
|
tag (str): The tag identifier. Must be registered.
|
|
@@ -157,10 +158,12 @@ class Descriptor:
|
|
|
157
158
|
|
|
158
159
|
Returns:
|
|
159
160
|
Tensor: A tensor slice of shape ``(batch_size, 1)`` containing
|
|
160
|
-
the predictions for the specified tag.
|
|
161
|
+
the predictions for the specified tag, or the full tensor if no index was specified when registering the tag.
|
|
161
162
|
|
|
162
163
|
Raises:
|
|
163
164
|
ValueError: If the tag is not registered in the descriptor.
|
|
164
165
|
"""
|
|
165
166
|
key, index = self.location(tag)
|
|
167
|
+
if index is None:
|
|
168
|
+
return data[key]
|
|
166
169
|
return data[key][:, index : index + 1]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: congrads
|
|
3
|
-
Version: 1.1.0
|
|
3
|
+
Version: 1.1.2
|
|
4
4
|
Summary: A toolbox for using Constraint Guided Gradient Descent when training neural networks.
|
|
5
5
|
Author: Wout Rombouts, Quinten Van Baelen, Peter Karsmakers
|
|
6
6
|
Author-email: Wout Rombouts <wout.rombouts@kuleuven.be>, Quinten Van Baelen <quinten.vanbaelen@kuleuven.be>, Peter Karsmakers <peter.karsmakers@kuleuven.be>
|
|
@@ -36,7 +36,7 @@ Requires-Dist: torch>=2.0.0
|
|
|
36
36
|
Requires-Dist: torchvision>=0.15.1
|
|
37
37
|
Requires-Dist: tqdm>=4.65.0
|
|
38
38
|
Requires-Dist: matplotlib>=3.7.0 ; extra == 'examples'
|
|
39
|
-
Requires-Dist: tensorboard>=2.
|
|
39
|
+
Requires-Dist: tensorboard>=2.18.0 ; extra == 'examples'
|
|
40
40
|
Requires-Python: >=3.11
|
|
41
41
|
Provides-Extra: examples
|
|
42
42
|
Description-Content-Type: text/markdown
|
|
@@ -3,12 +3,12 @@ congrads/checkpoints.py,sha256=VKSl8gOtp_VPF4SVMDdj34WuAJqsv5zVtf2YlSDVOqE,7217
|
|
|
3
3
|
congrads/constraints.py,sha256=vtYt1rpmYCekWWES_85L8hgQAkJROP-R5asuizpBQqE,52336
|
|
4
4
|
congrads/core.py,sha256=_Ah2EE9WxJL7laUBcoMcgGWuZj3Hnuby4Mu4HOfkT9c,36161
|
|
5
5
|
congrads/datasets.py,sha256=RfffRiA7Qijc69cJTBJhItTZ8x9B-p1kXMjvcfEC_nA,31102
|
|
6
|
-
congrads/descriptor.py,sha256=
|
|
6
|
+
congrads/descriptor.py,sha256=MXaLPt3jOLbHR_p7ZHrBMPYfrxyxyYqU5hvte-Sqfuk,6903
|
|
7
7
|
congrads/metrics.py,sha256=QxQiLIVaCV1yyryXsKOb1HKtfYFnHNIJMG8fUDG_1JE,4614
|
|
8
8
|
congrads/networks.py,sha256=UPzPDU0wI2zoOEvi697QBSDOtaa3Rc0rgCb-tCxbjak,2252
|
|
9
9
|
congrads/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
congrads/transformations.py,sha256=57xXmf87eUHmmLUqzt0bQZdDlAc0SFxShXK6Bku1z34,3327
|
|
11
11
|
congrads/utils.py,sha256=Bqf5gdioBNot8IhQR298n6WHpQZO8W_Ex_VwB96tlmc,41508
|
|
12
|
-
congrads-1.1.
|
|
13
|
-
congrads-1.1.
|
|
14
|
-
congrads-1.1.
|
|
12
|
+
congrads-1.1.2.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
|
|
13
|
+
congrads-1.1.2.dist-info/METADATA,sha256=6Hbb_tlG8-4CMXAgkHN3SO_vcx-ntwZc94GB9tMJCEQ,10534
|
|
14
|
+
congrads-1.1.2.dist-info/RECORD,,
|