PyTorch Puzzles

Some tensor brain teasers

Dimension shift

Here's a puzzle posed by François Fleuret:

Probably dumb @PyTorch question : how to [elegantly] write
x = x.permute(0, ..., d-1, d+1, ..., x.dim()-1, d)
?

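In other words: move axis d to the last position while keeping all other axes in their original order. As a concrete example, take this x, which has shape (1, 1, 4, 2, 2):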

x = torch.tensor(
    [[[[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]], [[4, 4], [4, 4]]]]]
)
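Before reaching for a dedicated function, the permutation from the tweet can be built by hand: list every axis except d in order, then append d. A minimal sketch, where the helper name `shift_naive` is hypothetical (not part of PyTorch) and negative d is normalized the way PyTorch indexes dimensions:

```python
import torch

def shift_naive(x: torch.Tensor, d: int) -> torch.Tensor:
    # Reject out-of-range axes instead of letting the modulo wrap them silently.
    if not -x.dim() <= d < x.dim():
        raise IndexError(f"axis {d} out of range for a {x.dim()}-D tensor")
    # Normalize a negative axis index (e.g. -1 -> x.dim() - 1).
    d = d % x.dim()
    # (0, ..., d-1, d+1, ..., x.dim()-1, d): every other axis, then d last.
    order = [i for i in range(x.dim()) if i != d] + [d]
    return x.permute(*order)

x = torch.tensor(
    [[[[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]], [[4, 4], [4, 4]]]]]
)
print(x.shape)                  # torch.Size([1, 1, 4, 2, 2])
print(shift_naive(x, 2).shape)  # torch.Size([1, 1, 2, 2, 4])
```

This works, but the list comprehension is exactly the bookkeeping the question hopes to avoid.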

Answer: movedim, or its alias moveaxis. It moves one axis to a new position and leaves the remaining axes in their original order.
import torch

def shift(x: torch.Tensor, d: int) -> torch.Tensor:
    # Move axis d to the last position; all other axes keep their order.
    return x.movedim(d, -1)
Two quick checks on the example x above:
>>> x.movedim(-1, -1).equal(x)
True
>>> x.movedim(2, -1).equal(x.permute(0, 1, 3, 4, 2))
True
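To gain more confidence than two hand-picked cases give, one can compare movedim against the explicit permutation for every axis, negative indices included. A small sketch, assuming a random 5-D tensor is representative (the shape (2, 3, 4, 5, 6) is arbitrary, chosen so every axis has a distinct size). `shift` is redefined so the snippet runs on its own, and `explicit_order` is a hypothetical helper:

```python
import torch

def shift(x: torch.Tensor, d: int) -> torch.Tensor:
    return x.movedim(d, -1)

def explicit_order(ndim: int, d: int) -> list:
    # (0, ..., d-1, d+1, ..., ndim-1, d), assuming d is already non-negative.
    return [i for i in range(ndim) if i != d] + [d]

x = torch.randn(2, 3, 4, 5, 6)
for d in range(-x.dim(), x.dim()):
    expected = x.permute(*explicit_order(x.dim(), d % x.dim()))
    assert shift(x, d).equal(expected)
    # moveaxis is documented as an alias of movedim, so it should agree too.
    assert x.moveaxis(d, -1).equal(expected)
print("all axes OK")
```

Because every axis has a different size, a wrong permutation would usually show up as a shape mismatch, not just a value mismatch.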