x-transformers 2.4.10__py3-none-any.whl → 2.4.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +13 -7
- {x_transformers-2.4.10.dist-info → x_transformers-2.4.11.dist-info}/METADATA +1 -1
- {x_transformers-2.4.10.dist-info → x_transformers-2.4.11.dist-info}/RECORD +5 -5
- {x_transformers-2.4.10.dist-info → x_transformers-2.4.11.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.10.dist-info → x_transformers-2.4.11.dist-info}/licenses/LICENSE +0 -0
| @@ -183,6 +183,7 @@ class AutoregressiveWrapper(Module): | |
| 183 183 | 
             
                    prompts,
         | 
| 184 184 | 
             
                    seq_len,
         | 
| 185 185 | 
             
                    beams = 4,
         | 
| 186 | 
            +
                    return_beams_and_scores = False,
         | 
| 186 187 | 
             
                    eos_token = None,
         | 
| 187 188 | 
             
                    temperature = 1.,
         | 
| 188 189 | 
             
                    stochastic = False,
         | 
| @@ -285,7 +286,7 @@ class AutoregressiveWrapper(Module): | |
| 285 286 | 
             
                        out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
         | 
| 286 287 | 
             
                        samples = rearrange(samples, 'b beams -> (b beams) 1')
         | 
| 287 288 |  | 
| 288 | 
            -
                        if should_cache:
         | 
| 289 | 
            +
                        if should_cache and is_first:
         | 
| 289 290 | 
             
                            modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))
         | 
| 290 291 |  | 
| 291 292 | 
             
                        # concat sample
         | 
| @@ -303,14 +304,13 @@ class AutoregressiveWrapper(Module): | |
| 303 304 |  | 
| 304 305 | 
             
                            scores = scores[:, :beams]
         | 
| 305 306 | 
             
                            top_beams_indices = sort_indices[:, :beams]
         | 
| 306 | 
            -
             | 
| 307 | 
            +
             | 
| 308 | 
            +
                            top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
         | 
| 307 309 |  | 
| 308 310 | 
             
                            flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
         | 
| 309 311 |  | 
| 310 312 | 
             
                            out = out[flattened_beam_indices]
         | 
| 311 313 |  | 
| 312 | 
            -
                            modify_cached_kv(cache, lambda t: t[flattened_beam_indices])
         | 
| 313 | 
            -
             | 
| 314 314 | 
             
                        scores = rearrange(scores, 'b beams -> (b beams)')
         | 
| 315 315 |  | 
| 316 316 | 
             
                        if not exists(eos_token):
         | 
| @@ -331,11 +331,17 @@ class AutoregressiveWrapper(Module): | |
| 331 331 |  | 
| 332 332 | 
             
                    out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)
         | 
| 333 333 |  | 
| 334 | 
            -
                    out = out[ | 
| 334 | 
            +
                    out = out[..., orig_seq_len:]
         | 
| 335 335 |  | 
| 336 | 
            -
                    out, = unpack(out, packed_shape, '* n')
         | 
| 336 | 
            +
                    out, = unpack(out, packed_shape, '* beams n') # prompt may have no batch dimension
         | 
| 337 337 |  | 
| 338 | 
            -
                     | 
| 338 | 
            +
                    if not return_beams_and_scores:
         | 
| 339 | 
            +
                        return out[..., 0, :]
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    scores = rearrange(scores, '(b beams) -> beams b', b = batch)
         | 
| 342 | 
            +
                    out = rearrange(out, 'b beams n -> beams b n')
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                    return out, scores
         | 
| 339 345 |  | 
| 340 346 | 
             
                @torch.no_grad()
         | 
| 341 347 | 
             
                @eval_decorator
         | 
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
         | 
| 2 2 | 
             
            x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
         | 
| 3 | 
            -
            x_transformers/autoregressive_wrapper.py,sha256= | 
| 3 | 
            +
            x_transformers/autoregressive_wrapper.py,sha256=n8ueNBMvIjO4B1J7VvSyDzJvqUi9YmCrri1p44n-FTY,17831
         | 
| 4 4 | 
             
            x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
         | 
| 5 5 | 
             
            x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
         | 
| 6 6 | 
             
            x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
         | 
| @@ -12,7 +12,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80, | |
| 12 12 | 
             
            x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
         | 
| 13 13 | 
             
            x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
         | 
| 14 14 | 
             
            x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
         | 
| 15 | 
            -
            x_transformers-2.4. | 
| 16 | 
            -
            x_transformers-2.4. | 
| 17 | 
            -
            x_transformers-2.4. | 
| 18 | 
            -
            x_transformers-2.4. | 
| 15 | 
            +
            x_transformers-2.4.11.dist-info/METADATA,sha256=N0EjJyBQ_2EjiQRJK-Rlvt7lzCn8XqFWXFiUyqUDwU8,90224
         | 
| 16 | 
            +
            x_transformers-2.4.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
         | 
| 17 | 
            +
            x_transformers-2.4.11.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
         | 
| 18 | 
            +
            x_transformers-2.4.11.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |