Unflatten in PyTorch

I need to change the shape of a tensor from [2, 48, 196] to [2, 48, 14, 14]. I read that PyTorch has an "unflatten" operation, but I couldn't understand how to use it. Is there an example?

1 Answer

Here is an example for your question.

import torch

x = torch.randn([2, 48, 196])                # example input of shape [2, 48, 196]
unflatten = torch.nn.Unflatten(2, (14, 14))  # split dim 2 (size 196) into 14 x 14
output = unflatten(x)

If you check output.shape, you get torch.Size([2, 48, 14, 14]).

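Unflatten expands one dimension of a tensor into several dimensions with the desired shape. In your case, you want to expand dimension 2 (size 196) into the new shape (14, 14). The product of the new sizes must equal the size of the dimension being unflattened; here 14 * 14 = 196.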
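As a sketch of what this means for the data layout (assuming a contiguous tensor, such as the one torch.randn returns), unflattening dim 2 gives the same result as reshape or view, and flatten undoes it. Tensor.unflatten is the method form of the same operation.

```python
import torch

x = torch.randn(2, 48, 196)

# Method form of the same operation as torch.nn.Unflatten(2, (14, 14))
a = x.unflatten(2, (14, 14))

# For a contiguous tensor, reshape/view give the same result:
# the 196 elements of the last dim are read row by row into 14 rows of 14
b = x.reshape(2, 48, 14, 14)
c = x.view(x.shape[0], x.shape[1], 14, 14)

print(a.shape)            # torch.Size([2, 48, 14, 14])
print(torch.equal(a, b))  # True
print(torch.equal(a, c))  # True

# flatten is the inverse: merging dims 2 and 3 recovers the original tensor
print(torch.equal(a.flatten(2), x))  # True
```

reshape/view are handy when you just need the new shape once; nn.Unflatten is useful when the reshaping has to be a layer inside a model.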

torch.nn.Unflatten takes two parameters:

  1. The first parameter is dim: the dimension you want to unflatten. In your case, it is 2.
  2. The second parameter is unflattened_size: the new shape of the unflattened dimension. In your case, it is (14, 14).
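A small sketch of how the two parameters behave: dim can be negative, unflattened_size can have more than two entries, and a size whose product does not match is only rejected when the module is called. The shapes (7, 2, 14) and (14, 15) are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 48, 196)

# dim may be negative: -1 refers to the last dimension (same as dim 2 here)
neg = nn.Unflatten(-1, (14, 14))(x)
print(neg.shape)  # torch.Size([2, 48, 14, 14])

# unflattened_size may have more than two entries if the product matches (7 * 2 * 14 = 196)
three = nn.Unflatten(2, (7, 2, 14))(x)
print(three.shape)  # torch.Size([2, 48, 7, 2, 14])

# 14 * 15 = 210 != 196: constructing the module works, calling it raises an error
bad = nn.Unflatten(2, (14, 15))
msg = None
try:
    bad(x)
except RuntimeError as err:
    msg = str(err)
print("RuntimeError:", msg)
```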

Therefore, your Unflatten module should look like: unflatten = torch.nn.Unflatten(2, (14, 14))
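Because nn.Unflatten is a module, it is typically placed inside nn.Sequential, for example after a layer that works on the flattened 196-element dimension. The model below is a hypothetical toy example; its layers and sizes are made up only to show the shapes flowing through.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: layer choices and sizes are for illustration only
model = nn.Sequential(
    nn.Conv2d(48, 48, kernel_size=3, padding=1),  # [N, 48, 14, 14] -> [N, 48, 14, 14]
    nn.Flatten(start_dim=2),                      # [N, 48, 14, 14] -> [N, 48, 196]
    nn.Linear(196, 196),                          # acts on the last dim: [N, 48, 196]
    nn.Unflatten(2, (14, 14)),                    # [N, 48, 196] -> [N, 48, 14, 14]
)

out = model(torch.randn(2, 48, 14, 14))
print(out.shape)  # torch.Size([2, 48, 14, 14])
```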

