congrads 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/descriptor.py CHANGED
@@ -122,7 +122,7 @@ class Descriptor:
122
122
  """
123
123
  return tag in self._tag_to_key and tag in self._tag_to_index
124
124
 
125
- def location(self, tag: str) -> tuple[str, int]:
125
+ def location(self, tag: str) -> tuple[str, int | None]:
126
126
  """Get the key and index for a given tag.
127
127
 
128
128
  Looks up the mapping for a registered tag and returns the associated
@@ -132,16 +132,16 @@ class Descriptor:
132
132
  tag (str): The tag identifier. Must be registered.
133
133
 
134
134
  Returns:
135
- tuple ((str, int)): A tuple containing:
135
+ tuple ((str, int | None)): A tuple containing:
136
136
  - The key in the data dictionary which holds the data (str).
137
- - The tensor index where the data is present (int).
137
+ - The tensor index where the data is present or None (int | None).
138
138
 
139
139
  Raises:
140
140
  ValueError: If the tag is not registered in the descriptor.
141
141
  """
142
142
  key = self._tag_to_key.get(tag)
143
143
  index = self._tag_to_index.get(tag)
144
- if key is None or index is None:
144
+ if key is None:
145
145
  raise ValueError(f"Tag '{tag}' is not registered in descriptor.")
146
146
  return key, index
147
147
 
@@ -150,6 +150,7 @@ class Descriptor:
150
150
 
151
151
  Retrieves the key and index associated with a tag and selects
152
152
  the corresponding slice from the given prediction tensor.
153
+ Returns the full tensor if no index was specified when registering the tag.
153
154
 
154
155
  Args:
155
156
  tag (str): The tag identifier. Must be registered.
@@ -157,10 +158,12 @@ class Descriptor:
157
158
 
158
159
  Returns:
159
160
  Tensor: A tensor slice of shape ``(batch_size, 1)`` containing
160
- the predictions for the specified tag.
161
+ the predictions for the specified tag, or the full tensor if no index was specified when registering the tag.
161
162
 
162
163
  Raises:
163
164
  ValueError: If the tag is not registered in the descriptor.
164
165
  """
165
166
  key, index = self.location(tag)
167
+ if index is None:
168
+ return data[key]
166
169
  return data[key][:, index : index + 1]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: congrads
3
- Version: 1.1.0
3
+ Version: 1.1.1
4
4
  Summary: A toolbox for using Constraint Guided Gradient Descent when training neural networks.
5
5
  Author: Wout Rombouts, Quinten Van Baelen, Peter Karsmakers
6
6
  Author-email: Wout Rombouts <wout.rombouts@kuleuven.be>, Quinten Van Baelen <quinten.vanbaelen@kuleuven.be>, Peter Karsmakers <peter.karsmakers@kuleuven.be>
@@ -3,12 +3,12 @@ congrads/checkpoints.py,sha256=VKSl8gOtp_VPF4SVMDdj34WuAJqsv5zVtf2YlSDVOqE,7217
3
3
  congrads/constraints.py,sha256=vtYt1rpmYCekWWES_85L8hgQAkJROP-R5asuizpBQqE,52336
4
4
  congrads/core.py,sha256=_Ah2EE9WxJL7laUBcoMcgGWuZj3Hnuby4Mu4HOfkT9c,36161
5
5
  congrads/datasets.py,sha256=RfffRiA7Qijc69cJTBJhItTZ8x9B-p1kXMjvcfEC_nA,31102
6
- congrads/descriptor.py,sha256=Z_qGtNMQ3BKYjQpUNJxlsmEWBDwYk98ozig7nwm7RkI,6681
6
+ congrads/descriptor.py,sha256=MXaLPt3jOLbHR_p7ZHrBMPYfrxyxyYqU5hvte-Sqfuk,6903
7
7
  congrads/metrics.py,sha256=QxQiLIVaCV1yyryXsKOb1HKtfYFnHNIJMG8fUDG_1JE,4614
8
8
  congrads/networks.py,sha256=UPzPDU0wI2zoOEvi697QBSDOtaa3Rc0rgCb-tCxbjak,2252
9
9
  congrads/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  congrads/transformations.py,sha256=57xXmf87eUHmmLUqzt0bQZdDlAc0SFxShXK6Bku1z34,3327
11
11
  congrads/utils.py,sha256=Bqf5gdioBNot8IhQR298n6WHpQZO8W_Ex_VwB96tlmc,41508
12
- congrads-1.1.0.dist-info/WHEEL,sha256=Pi5uDq5Fdo_Rr-HD5h9BiPn9Et29Y9Sh8NhcJNnFU1c,79
13
- congrads-1.1.0.dist-info/METADATA,sha256=336e2CyqdhojvWovJaHrBuN1OekXA5GqU28_91GMvVY,10534
14
- congrads-1.1.0.dist-info/RECORD,,
12
+ congrads-1.1.1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
13
+ congrads-1.1.1.dist-info/METADATA,sha256=8MvQd7CqMnOyX72kNx3xHaROT237MaQASvMDxQ5GyXA,10534
14
+ congrads-1.1.1.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: uv 0.8.17
2
+ Generator: uv 0.8.24
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any