How to clear CUDA memory in PyTorch


I am trying to get the output of a neural network that I have already trained. The input is a 300×300 image. I am using a batch size of 1, but I still get a "CUDA error: out of memory" error after I have successfully got the output for 25 images.

I tried torch.cuda.empty_cache(), but this still doesn’t seem to solve the problem. Code:

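import torch

# model (the trained network) and train_x (the raw 300x300 images) are defined earlier (not shown)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Reshape to (N, channels=1, 300, 300) and move the whole dataset to the GPU
train_x = torch.tensor(train_x, dtype=torch.float32).view(-1, 1, 300, 300)
train_x = train_x.to(device)
dataloader = torch.utils.data.DataLoader(train_x, batch_size=1, shuffle=False)

right = []
for i, left in enumerate(dataloader):
    print(i)
    temp = model(left).view(-1, 1, 300, 300)
    right.append(temp.to('cpu'))  # keep a CPU copy of each output
    del temp
    torch.cuda.empty_cache()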

The loop runs 25 times and then fails with the out-of-memory error, every time.

On every iteration I send a new image through the network, so I don't need to keep the previous results on the GPU after each iteration. Is there any way to achieve this?

I figured out where I was going wrong. I am posting the solution as an answer for others who might be struggling with the same problem.

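Basically, whenever I pass data through my network, PyTorch builds a computational graph and keeps the intermediate results (activations) in GPU memory, in case I want to compute gradients during backpropagation. Each output I appended to right was still attached to its graph (moving a tensor to the CPU with .to('cpu') does not detach it), so the GPU memory used by every previous iteration stayed allocated. Since I only wanted to perform forward propagation, I simply needed to run my model under torch.no_grad().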
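To make the difference concrete, here is a small sketch that compares an ordinary forward pass with one run under torch.no_grad(). The tiny convolutional model is a hypothetical stand-in for the trained network, which the post does not show. The check uses grad_fn, which is set on a tensor that is attached to an autograd graph and is None otherwise.

```python
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the trained network (the real model is not shown)
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 1, kernel_size=3, padding=1),
).to(device)
x = torch.randn(1, 1, 300, 300, device=device)

# Ordinary forward pass: the weights require gradients, so autograd records a graph
out = model(x)
cpu_copy = out.to("cpu")
print(out.grad_fn is not None)       # True
print(cpu_copy.grad_fn is not None)  # True: .to("cpu") does not detach from the graph

# Forward pass under no_grad: no graph is recorded
with torch.no_grad():
    out_ng = model(x)
print(out_ng.grad_fn is None)        # True
print(out_ng.requires_grad)          # False

# detach() is another way to drop the graph link for a tensor you want to keep
kept = out.detach().to("cpu")
print(kept.grad_fn is None)          # True
```

Note that detach() only cuts the link for the stored tensor; the forward pass still builds the graph first, whereas torch.no_grad() skips recording it altogether.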

Thus, the for loop in my code could be rewritten as:

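for i, left in enumerate(dataloader):
    print(i)
    # Disable gradient tracking: no graph is recorded, so the output holds
    # no references to the intermediate activations
    with torch.no_grad():
        temp = model(left).view(-1, 1, 300, 300)
    right.append(temp.to('cpu'))
    del temp
    torch.cuda.empty_cache()  # optional: the caching allocator reuses freed blocks anyway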

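Running the model inside torch.no_grad() tells PyTorch not to record operations for autograd, so the outputs carry no graph and the intermediate activations are not kept after each forward pass, which frees my GPU memory. This also explains why torch.cuda.empty_cache() alone did not help: it only releases cached memory that no tensor is using, and the activations were still referenced by the graphs attached to the tensors stored in right.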
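A slightly more complete sketch of the same fix, written as a hypothetical helper predict_all (the name and signature are illustrative, not from the post). It assumes the dataset stays on the CPU and only the current batch is moved to the GPU, and it also calls model.eval(), which the original answer does not mention but which is the usual step for inference when a model contains dropout or batch-norm layers. In PyTorch 1.9 and later, torch.inference_mode() can be used in place of torch.no_grad() for the same purpose.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def predict_all(model, dataloader, device):
    """Run a forward-only pass over every batch and return the outputs on the CPU.

    Hypothetical helper: the name and signature are illustrative only.
    """
    model.eval()  # switch dropout/batch-norm layers to inference behaviour
    outputs = []
    with torch.no_grad():  # no autograd graph is recorded for anything in this block
        for batch in dataloader:
            batch = batch.to(device)  # only the current batch occupies GPU memory
            pred = model(batch).view(-1, 1, 300, 300)
            outputs.append(pred.cpu())  # plain tensor with no grad_fn
    return outputs


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Conv2d(1, 1, kernel_size=3, padding=1).to(device)  # hypothetical stand-in model
images = torch.randn(5, 1, 300, 300)  # hypothetical stand-in for train_x, kept on the CPU
loader = DataLoader(images, batch_size=1, shuffle=False)

right = predict_all(model, loader, device)
print(len(right), right[0].shape, right[0].requires_grad)
```

Keeping the full dataset on the CPU and moving one batch at a time means peak GPU usage is roughly one batch plus the model, independent of how many images are processed.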


The answers/resolutions are collected from Stack Overflow and are licensed under CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.