Ben Levy and Jacob Gildenblat, SagivTech

PyTorch is an incredible Deep Learning Python framework. It makes prototyping and debugging deep learning algorithms easier, and has great support for multi-GPU training.
However, as always with Python, you need to be careful to avoid writing low-performing code.
This is especially important in deep learning, where you're spending money on all those GPUs.
Speeding up your training code can thus have the same effect as buying many expensive GPUs.

In this post we will share a few lessons we learned while getting our PyTorch training code to run faster.

Data-loading and pre-processing

PyTorch offers a DataLoader class for loading images in batches, and supports prefetching batches using multiple workers.
Prefetching means that while the GPU is crunching, the workers are loading the data, so the I/O-bound latency is hidden behind the GPU computation.
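For many workloads the built-in machinery is all you need. Below is a minimal sketch of that standard approach; the ImageFileDataset class, the training_path_pairs list and the pre-processing inside __getitem__ are placeholder assumptions of ours, not code from this post:

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ImageFileDataset(Dataset):
  """Loads one (image, label) pair per index from a list of path pairs."""
  def __init__(self, path_pairs):
    self.path_pairs = path_pairs

  def __len__(self):
    return len(self.path_pairs)

  def __getitem__(self, index):
    image_path, label_path = self.path_pairs[index]
    image = cv2.imread(image_path, 1)
    label = cv2.imread(label_path, 1)
    #Pre-processing / augmentation would go here
    image = np.float32(np.transpose(image, (2, 0, 1)))
    label = np.float32(np.transpose(label, (2, 0, 1)))
    return torch.from_numpy(image), torch.from_numpy(label)

#4 workers pre-fetch and pre-process batches while the GPU trains
train_loader = DataLoader(ImageFileDataset(training_path_pairs),
                          batch_size=16, shuffle=True,
                          num_workers=4, pin_memory=True)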

PyTorch lets you write your own custom dataset and augmentation code, and then handles the parallel loading for you through DataLoader.
Loading and augmenting the data in background workers, while the training forward/backward passes run on the GPU, is crucial for a fast training loop.
To understand how this works under the hood, we will look at our own small version of this.

Data Loading

Let's create a data generator where every input image is loaded and pre-processed in parallel from different threads.
The different threads read the data and push batches onto a shared queue.
Since the threads iterate over a generator that fetches the data from disk, which is a common pattern, we need to take extra care to make that generator thread-safe.
We also need to be careful with random operations in the pre-processing step when it runs in multiple threads (for example applying augmentations stochastically, or adding random noise): storing numpy.random.RandomState objects in Python thread-local data gives every thread its own random number generator with its own seed.
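Here is a minimal sketch of that thread-local trick (the add_random_noise helper is a hypothetical example of ours, not part of the generator below):

import threading
import numpy as np

_thread_local = threading.local()

def get_thread_rng():
  """Returns a per-thread RandomState, created lazily with a seed
  derived from the thread id so every worker gets a different stream."""
  if not hasattr(_thread_local, 'rng'):
    seed = threading.current_thread().ident % (2 ** 32)
    _thread_local.rng = np.random.RandomState(seed)
  return _thread_local.rng

def add_random_noise(image):
  #Each thread draws from its own generator, so no lock is needed
  rng = get_thread_rng()
  noise = rng.normal(0.0, 5.0, image.shape)
  return np.float32(image) + noise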

import threading
import numpy as np
import cv2
import random 

class threadsafe_iter:
  """Takes an iterator/generator and makes it thread-safe by
  serializing calls to the `next` method of the given iterator/generator.
  """
  def __init__(self, it):
    self.it = it
    self.lock = threading.Lock()

  def __iter__(self):
    return self

  def next(self):
    with self.lock:
      return self.it.next()

def get_path_i(paths_count):
  """Cyclic generator of path indices
  """
  current_path_id = 0
  while True:
    yield current_path_id
    current_path_id = (current_path_id + 1) % paths_count

class InputGen:
  def __init__(self, paths, batch_size):
    self.paths = paths
    self.index = 0
    self.batch_size = batch_size
    self.init_count = 0
    self.lock = threading.Lock() #mutex for input path
    self.yield_lock = threading.Lock() #mutex for generator yielding of batch
    self.path_id_generator = threadsafe_iter(get_path_i(len(self.paths))) 
    self.images = []
    self.labels = []
    
  def get_samples_count(self):
    """ Returns the total number of images needed to train an epoch """
    return len(self.paths)

  def get_batches_count(self):
    """ Returns the total number of batches needed to train an epoch """
    return int(self.get_samples_count() / self.batch_size)

  def pre_process_input(self, im, lb):
    """ Do your pre-processing here.
        Needs to be a thread-safe function. """
    return im, lb

  def next(self):
    return self.__iter__()

  def __iter__(self):
    while True:
      #At the start of each epoch we shuffle the data paths
      with self.lock: 
        if (self.init_count == 0):
          random.shuffle(self.paths)
          self.images, self.labels, self.batch_paths = [], [], []
          self.init_count = 1
      #Iterates through the input paths in a thread-safe manner
      for path_id in self.path_id_generator: 
        img, label = self.paths[path_id]
        img = cv2.imread(img, 1)
        label_img = cv2.imread(label,1)
        img, label = self.pre_process_input(img,label_img)
        #Concurrent access by multiple threads to the lists below
        with self.yield_lock: 
          if (len(self.images)) < self.batch_size:
            self.images.append(img)
            self.labels.append(label)
          if len(self.images) % self.batch_size == 0:					
            yield np.float32(self.images), np.float32(self.labels)
            self.images, self.labels = [], []
      #At the end of an epoch we re-init data-structures
      with self.lock: 
        self.init_count = 0
  def __call__(self):
    return self.__iter__()

Now to demonstrate the usage of the loader, here is an example training loop:

class thread_killer(object):
  """Boolean object for signaling a worker thread to terminate
  """
  def __init__(self):
    self.to_kill = False
  
  def __call__(self):
    return self.to_kill
  
  def set_tokill(self,tokill):
    self.to_kill = tokill
  
def threaded_batches_feeder(tokill, batches_queue, dataset_generator):
  """Threaded worker for pre-processing input data.
  tokill is a thread_killer object that indicates whether a thread should be terminated
  dataset_generator is the training/validation dataset generator
  batches_queue is a limited size thread-safe Queue instance.
  """
  while tokill() == False:
    for batch, (batch_images, batch_labels) \
      in enumerate(dataset_generator):
        #We fill the queue with newly fetched batches until we reach the max size.
        batches_queue.put((batch, (batch_images, batch_labels)), block=True)
        if tokill() == True:
          return

def threaded_cuda_batches(tokill,cuda_batches_queue,batches_queue):
  """Thread worker for transferring pytorch tensors into
  GPU. batches_queue is the queue that fetches numpy cpu tensors.
  cuda_batches_queue receives numpy cpu tensors and transfers them to GPU space.
  """
  while tokill() == False:
    batch, (batch_images, batch_labels) = batches_queue.get(block=True)
    batch_images_np = np.transpose(batch_images, (0, 3, 1, 2))
    batch_images = torch.from_numpy(batch_images_np)
    batch_labels = torch.from_numpy(batch_labels)

    batch_images = Variable(batch_images).cuda()
    batch_labels = Variable(batch_labels).cuda()
    cuda_batches_queue.put((batch, (batch_images, batch_labels)), block=True)
    if tokill() == True:
      return

if __name__ == '__main__':
  import time
  import sys
  import torch
  from torch.autograd import Variable
  from threading import Thread
  from Queue import Empty, Full, Queue
  
  num_epochs = 1000
  #model is some PyTorch CNN model, assumed to be defined and initialized elsewhere
  model.cuda()
  model.train()
  batch_size = 64
  batches_per_epoch = 64
  #training_set_list is supposed to be a list of (image path, label path)
  #pairs for all the training images.
  training_set_list = None
  #Our train batches queue can hold at most 12 batches at any given time.
  #Once the queue is full, further put() calls block.
  train_batches_queue = Queue(maxsize=12)
  #Our queue of numpy batches waiting to be transferred to the GPU.
  #Once the queue is full, further put() calls block.
  #We set maxsize to 3 due to GPU memory size limitations.
  cuda_batches_queue = Queue(maxsize=3)


  training_set_generator = InputGen(training_set_list, batch_size)
  train_thread_killer = thread_killer()
  train_thread_killer.set_tokill(False)
  preprocess_workers = 4


  #We launch 4 threads to load and pre-process the input images
  for _ in range(preprocess_workers):
    t = Thread(target=threaded_batches_feeder, \
           args=(train_thread_killer, train_batches_queue, training_set_generator))
    t.start()
  cuda_transfers_thread_killer = thread_killer()
  cuda_transfers_thread_killer.set_tokill(False)
  cudathread = Thread(target=threaded_cuda_batches, \
           args=(cuda_transfers_thread_killer, cuda_batches_queue, train_batches_queue))
  cudathread.start()

  
  #We let the queues fill up before we start the training
  time.sleep(8)
  for epoch in range(num_epochs):
    for batch in range(batches_per_epoch):
      
      #Thanks to the queue mechanism, fetching a ready GPU batch takes close to zero time
      _, (batch_images, batch_labels) = cuda_batches_queue.get(block=True)
            
      #train_batch is your method for a single training step.
      #No need for pinned memory here: the CUDA transfers are already hidden by the queues.
      loss, accuracy = train_batch(batch_images, batch_labels)

  train_thread_killer.set_tokill(True)
  cuda_transfers_thread_killer.set_tokill(True)	
  for _ in range(preprocess_workers):
    try:
      #Enforcing thread shutdown by unblocking any pending put() calls
      train_batches_queue.get(block=True, timeout=1)
      cuda_batches_queue.get(block=True, timeout=1)
    except Empty:
      pass
  print "Training done"

Our sample code works as follows:

  • Init of the model on the GPU.
  • Init of two queues:
    – Input images queue: holds up to 12 pre-processed batches at any given time, filled by 4 worker threads over the program's lifetime.
    – CUDA batches queue: a single thread takes batches from the input images queue and transfers them to GPU memory (up to 3 at a time, due to GPU memory limitations).
  • Training loop, where we fetch a ready GPU batch in close to zero time and feed it to our train_batch method to perform an optimization step.
  • Resource termination, where we signal all threads to terminate.

As a consequence, the scheme above hides virtually all of the expensive overhead of host-to-GPU memory transfers, I/O and pre-processing behind the GPU computation.

Some other lessons we learned along the way

Use cProfile to measure run times for everything.
It's as easy as this:

python -m cProfile -o 100_percent_gpu_utilization.prof train.py
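If you prefer to stay in the terminal, the same .prof file can also be inspected with the standard pstats module; a minimal sketch:

import pstats

p = pstats.Stats('100_percent_gpu_utilization.prof')
#Show the 20 most expensive calls by cumulative time
p.sort_stats('cumulative').print_stats(20)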

The generated .prof file can be visualized with the awesome snakeviz utility, using the following command:

snakeviz 100_percent_gpu_utilization.prof

The output looks like this:

In the above figure (made for a run of 2 training epochs, 100 batches in total) we see that our main training function, train_batch, consumes 82% of the training time, all of it inside PyTorch's own building blocks: adam.py (the optimizer step), the network forward/backward passes, and the backward call on the loss autograd variable. The rest of the time goes to session initialization, sleeps, and the validation accuracy computation (7%), which is calculated at the end of every epoch.
Notice how no time is spent on data-loading / preprocessing!

Do numpy-like operations on the GPU wherever you can

PyTorch tensors can do a lot of the things NumPy can do, but on the GPU.
We had a lot of operations, like argmax, that were being done in NumPy on the CPU.
These innocent-looking operations add up when they are applied to every batch of data.
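As a hedged sketch (outputs is assumed to be the network output, already living on the GPU), replacing a NumPy argmax with the equivalent GPU computation looks like this:

#Slow: copies the outputs from the GPU to the CPU on every batch
predictions = np.argmax(outputs.cpu().data.numpy(), axis=1)

#Fast: the same argmax computed directly on the GPU tensor
_, predictions = torch.max(outputs.data, 1)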

Free up memory using del

This is a common pitfall for new PyTorch users, and we think it isn't documented well enough.
After you're done with a PyTorch tensor or variable, delete it with Python's del operator to free up its memory.

Avoid unnecessary transfer of data from the GPU

CUDA copies are expensive.
It turned out that a lot of our CUDA copies were for batch statistics: the loss, the accuracy, and other metrics.
Instead of copying the loss and other metrics to the CPU for display after every batch, aggregate them on the GPU and copy them to the CPU once, at the end of every epoch.
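A hedged sketch of what this can look like (the loss comes from a forward/backward pass that is not shown, and batches_per_epoch is defined as above):

import torch

#Accumulator kept on the GPU, so nothing is copied back per batch
epoch_loss = torch.zeros(1).cuda()

for batch in range(batches_per_epoch):
  #... forward/backward pass for this batch, producing the loss ...
  epoch_loss += loss.data  #stays on the GPU

#A single GPU-to-CPU copy per epoch instead of one per batch
print "mean epoch loss: ", (epoch_loss / batches_per_epoch).cpu()[0]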

Avoid using the PyTorch DataParallel layer with Tensor.cuda() in parallel

The DataParallel layer is used for distributing computations across multiple GPUs/CPUs.
Empirically, using the PyTorch DataParallel layer together with calls to Tensor.cuda() from other threads, as in the threaded CUDA queue loop shown above, yielded wrong training results for us, probably because the feature was still immature in PyTorch version 0.1.12_2. While profiling the code, the timing relations between the two already hinted at a problem.
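For reference, the standard, standalone usage of DataParallel is just a thin wrapper around the model (a minimal sketch, reusing the model and batch_images names from the code above):

import torch

#Replicates the model and splits each input batch across all visible GPUs
model = torch.nn.DataParallel(model).cuda()
outputs = model(batch_images)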

Use pinned memory, and use async=True to parallelize data transfer and GPU number crunching

Let’s look at the next code snippet:

batch_images = batch_images.pin_memory()
batch_labels = batch_labels.pin_memory()
batch_labels = Variable(batch_labels).cuda(async=True)
outputs = model(batch_images)
loss = criterion(outputs, batch_labels)

batch_labels isn't actually needed until the last line.
By putting it in pinned memory and making its copy to the GPU asynchronous, the transfer adds no latency: it happens while the model forward pass on the line before it is running.

In this post we shared a few lessons we learned about making PyTorch training code run faster. We invite you to share your own!

Jacob Gildenblat

Team Leader Deep Learning, SagivTech

Ben Levy

Algorithms & Software Developer, SagivTech
