class: center, middle, inverse, title-slide

# Lec 24 - pytorch - GPU

## Statistical Computing and Computation

### Sta 663 | Spring 2022

### Dr. Colin Rundel

---
exclude: true

```python
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import scipy
import torch

import os
import math

plt.rcParams['figure.dpi'] = 200

np.set_printoptions(
  edgeitems=30, linewidth=200,
  precision = 5, suppress=True
  #formatter=dict(float=lambda x: "%.5g" % x)
)

pd.set_option("display.width", 130)
pd.set_option("display.max_columns", 10)
pd.set_option("display.precision", 6)
```

```r
knitr::opts_chunk$set(
  fig.align="center",
  cache=FALSE
)

library(lme4)
```

```
## Loading required package: Matrix
```

```r
local({
  hook_err_old <- knitr::knit_hooks$get("error")  # save the old hook
  knitr::knit_hooks$set(error = function(x, options) {
    # now do whatever you want to do with x, and pass
    # the new x to the old hook
    x = sub("## \n## Detailed traceback:\n.*$", "", x)
    x = sub("Error in py_call_impl\\(.*?\\)\\: ", "", x)
    #x = stringr::str_wrap(x, width = 100)
    hook_err_old(x, options)
  })

  hook_warn_old <- knitr::knit_hooks$get("warning")  # save the old hook
  knitr::knit_hooks$set(warning = function(x, options) {
    x = sub("<string>:1: ", "", x)
    #x = stringr::str_wrap(x, width = 100)
    hook_warn_old(x, options)
  })

  hook_msg_old <- knitr::knit_hooks$get("output")  # save the old hook
  knitr::knit_hooks$set(output = function(x, options) {
    if (is.null(options$wrap))
      options$wrap = TRUE

    x = stringr::str_replace(x, "(## ).* ([A-Za-z]+Warning:)", "\\1\\2")
    x = stringr::str_split(x, "\n")[[1]]
    #x = stringr::str_wrap(x, width = 120, exdent = 3)
    x = stringr::str_remove_all(x, "\r")
    if (options$wrap)
      x = stringi::stri_wrap(x, width=120, exdent = 3, normalize=FALSE)
    x = paste(x, collapse="\n")

    #x = stringr::str_wrap(x, width = 100)
    hook_msg_old(x, options)
  })
})
```

---

## CUDA

> CUDA (or Compute Unified Device Architecture) is a parallel computing platform and application programming interface (API) that allows software to use certain types of graphics processing unit (GPU) for general purpose processing, an approach called general-purpose computing on GPUs (GPGPU). CUDA is a software layer that gives direct access to the GPU's virtual instruction set and parallel computational elements, for the execution of compute kernels.

<br/>

Core libraries:

.col3_left[
* cuBLAS
* cuSOLVER
* cuSPARSE
]

.col3_mid[
* cuFFT
* cuTENSOR
* cuRAND
]

.col3_right[
* Thrust
* cuDNN
]

.footnote[[Source](https://en.wikipedia.org/wiki/CUDA)]

---

## CUDA Kernels

```c
// Kernel - Adding two matrices MatA and MatB
__global__ void MatAdd(float MatA[N][N], float MatB[N][N], float MatC[N][N])
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int j = blockIdx.y * blockDim.y + threadIdx.y;
    if (i < N && j < N)
        MatC[i][j] = MatA[i][j] + MatB[i][j];
}

int main()
{
    ...
    // Matrix addition kernel launch from host code
    dim3 threadsPerBlock(16, 16);
    dim3 numBlocks(
        (N + threadsPerBlock.x - 1) / threadsPerBlock.x,
        (N + threadsPerBlock.y - 1) / threadsPerBlock.y
    );

    MatAdd<<<numBlocks, threadsPerBlock>>>(MatA, MatB, MatC);
    ...
}
```

---
class: center, middle

<img src="imgs/gpu_bench1.png" width="60%" style="display: block; margin: auto;" />

---
class: center, middle

<img src="imgs/gpu_bench2.png" width="60%" style="display: block; margin: auto;" />

---

## GPU Status

```bash
nvidia-smi
```

```
## Wed Apr  6 10:22:09 2022
## +-----------------------------------------------------------------------------+
## | NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
## |-------------------------------+----------------------+----------------------+
## | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
## | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
## |                               |                      |               MIG M. |
## |===============================+======================+======================|
## |   0  Tesla P100-PCIE...  Off  | 00000000:02:00.0 Off |                    0 |
## | N/A   42C    P0    33W / 250W |   1521MiB / 16280MiB |      0%      Default |
## |                               |                      |                  N/A |
## +-------------------------------+----------------------+----------------------+
## |   1  Tesla P100-PCIE...  Off  | 00000000:03:00.0 Off |                    0 |
## | N/A   39C    P0    27W / 250W |      2MiB / 16280MiB |      0%      Default |
## |                               |                      |                  N/A |
## +-------------------------------+----------------------+----------------------+
##
## +-----------------------------------------------------------------------------+
## | Processes:                                                                  |
## |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
## |        ID   ID                                                   Usage      |
## |=============================================================================|
## |    0   N/A  N/A   2749475      C   ...tudio-server/bin/rsession     1519MiB |
## +-----------------------------------------------------------------------------+
```

---

## Torch GPU Information

```python
torch.cuda.is_available()
```

```
## True
```

```python
torch.cuda.device_count()
```

```
## 2
```

```python
torch.cuda.get_device_name("cuda:0")
```

```
## 'Tesla P100-PCIE-16GB'
```

```python
torch.cuda.get_device_name("cuda:1")
```

```
## 'Tesla P100-PCIE-16GB'
```

```python
torch.cuda.get_device_properties(0)
```

```
## _CudaDeviceProperties(name='Tesla P100-PCIE-16GB', major=6, minor=0, total_memory=16280MB, multi_processor_count=56)
```

```python
torch.cuda.get_device_properties(1)
```

```
## _CudaDeviceProperties(name='Tesla P100-PCIE-16GB', major=6, minor=0, total_memory=16280MB, multi_processor_count=56)
```

---

## GPU Tensors

Whether a computation runs on the GPU is determined by where its tensors live: to use the GPU, we allocate the tensors on (or copy them to) a GPU device.
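A common device-agnostic pattern (a minimal sketch, not code from the lecture) is to choose the device once, falling back to the CPU when `torch.cuda.is_available()` returns `False`, and then pass it via `device=` or `.to()`. The variable names here are assumptions of the sketch.

```python
import torch

# Pick the first GPU when one is visible, otherwise fall back to the CPU,
# so the same script runs on a laptop and on a GPU server.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Allocate directly on the chosen device ...
x = torch.linspace(0, 1, 5, device=device)
# ... or create on the CPU and copy it over with .to()
y = torch.ones(5, 2).to(device)

z = x @ y  # both operands live on `device`, so no device-mismatch error
print(z.device, z.is_cuda)

# .cpu() copies the result back to host memory, which .numpy() requires
z_np = z.cpu().numpy()
```

Calling `.numpy()` directly on a CUDA tensor raises an error, which is why the copy back with `.cpu()` comes first.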
.pull-left[
```python
cpu = torch.device('cpu')
cuda0 = torch.device('cuda:0')
cuda1 = torch.device('cuda:1')

x = torch.linspace(0,1,5, device=cuda0)
y = torch.randn(5,2, device=cuda0)
z = torch.rand(2,3, device=cpu)
```

```python
x
```

```
## tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000], device='cuda:0')
```

```python
y
```

```
## tensor([[ 0.6870,  2.4937],
##         [ 0.1939, -0.1403],
##         [ 0.5835, -0.7860],
##         [-0.4810, -0.0132],
##         [-1.8345, -1.3653]], device='cuda:0')
```

```python
z
```

```
## tensor([[0.7300, 0.4717, 0.5368],
##         [0.8460, 0.4802, 0.5372]])
```
]

--

.pull-right[
```python
x @ y
```

```
## tensor([-1.8550, -1.8033], device='cuda:0')
```

```python
y @ z
```

```
## RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)
```

```python
y @ z.to(cuda0)
```

```
## tensor([[ 2.6112,  1.5216,  1.7083],
##         [ 0.0228,  0.0241,  0.0287],
##         [-0.2391, -0.1022, -0.1090],
##         [-0.3623, -0.2332, -0.2653],
##         [-2.4942, -1.5211, -1.7181]], device='cuda:0')
```
]

---

## NN Layers + GPU

NN layers (i.e. their parameters) also need to be moved to the GPU before they can be applied to GPU tensors:

```python
nn = torch.nn.Linear(5,5)
X = torch.randn(10,5).cuda()
```

--

```python
nn(X)
```

```
## RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)
```

--

```python
nn.cuda()(X)
```

```
## tensor([[ 0.5100,  0.7798,  0.8372,  0.3515, -0.1180],
##         [ 0.7318,  0.0803,  0.5940,  0.2220,  0.0618],
##         [-0.4500, -0.4220,  0.8975, -0.3160,  0.4002],
##         [ 0.1245, -0.1168, -0.0518,  0.2124,  0.1446],
##         [ 0.2666, -0.1466,  0.2908, -0.0510,  0.3981],
##         [ 0.2873, -0.0099,  0.3003, -0.1141, -0.0164],
##         [-0.1170, -0.2701,  0.2517,  0.1761, -0.1481],
##         [ 0.3376, -0.1155,  0.0512,  0.2084,  0.0318],
##         [-0.0619, -0.0692,  0.6493,  0.1803, -0.0687],
##         [-0.4949, -1.0181,  0.0118, -0.1059, -0.0487]], device='cuda:0',
##        grad_fn=<AddmmBackward0>)
```

```python
nn.to(device="cuda")(X)
```

```
## tensor([[ 0.5100,  0.7798,  0.8372,  0.3515, -0.1180],
##         [ 0.7318,  0.0803,  0.5940,  0.2220,  0.0618],
##         [-0.4500, -0.4220,  0.8975, -0.3160,  0.4002],
##         [ 0.1245, -0.1168, -0.0518,  0.2124,  0.1446],
##         [ 0.2666, -0.1466,  0.2908, -0.0510,  0.3981],
##         [ 0.2873, -0.0099,  0.3003, -0.1141, -0.0164],
##         [-0.1170, -0.2701,  0.2517,  0.1761, -0.1481],
##         [ 0.3376, -0.1155,  0.0512,  0.2084,  0.0318],
##         [-0.0619, -0.0692,  0.6493,  0.1803, -0.0687],
##         [-0.4949, -1.0181,  0.0118, -0.1059, -0.0487]], device='cuda:0',
##        grad_fn=<AddmmBackward0>)
```

---

## Back to MNIST

Same MNIST data as last time (1x8x8 images):

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()
X, y = digits.data, digits.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, shuffle=True, random_state=1234
)

X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train)
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test)
```

--

To use the GPU for computation, we need to copy these tensors to the GPU:

```python
X_train_cuda = X_train.to(device=cuda0)
y_train_cuda = y_train.to(device=cuda0)
X_test_cuda = X_test.to(device=cuda0)
y_test_cuda = y_test.to(device=cuda0)
```

---

## Convolutional NN

```python
class mnist_conv_model(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)

        self.cnn  = torch.nn.Conv2d(
            in_channels=1, out_channels=8,
            kernel_size=3, stride=1, padding=1
        ).to(device=self.device)
        self.relu = torch.nn.ReLU().to(device=self.device)
        self.pool = torch.nn.MaxPool2d(kernel_size=2).to(device=self.device)
        self.lin  = torch.nn.Linear(8 * 4 * 4, 10).to(device=self.device)

    def forward(self, X):
        out = self.cnn(X.view(-1, 1, 8, 8))
        out = self.relu(out)
        out = self.pool(out)
        out = self.lin(out.view(-1, 8 * 4 * 4))
        return out

    def fit(self, X, y, lr=0.001, n=1000, acc_step=10):
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses = []
        for i in range(n):
            opt.zero_grad()
            loss = torch.nn.CrossEntropyLoss()(self(X), y)
            loss.backward()
            opt.step()
            losses.append(loss.item())

        return losses

    def accuracy(self, X, y):
        val, pred = torch.max(self(X), dim=1)
        return( (pred == y).sum() / len(y) )
```

---

## CPU vs CUDA

.pull-left[
```python
m = mnist_conv_model(device="cpu")
loss = m.fit(X_train, y_train, n=1000)
loss[-5:]
```

```
## [0.04613681882619858, 0.046090248972177505, 0.046043746173381805, 0.04599732160568237, 0.04595096781849861]
```

```python
m.accuracy(X_test, y_test)
```

```
## tensor(0.9778)
```
]

.pull-right[
```python
m_cuda = mnist_conv_model(device="cuda")
loss = m_cuda.fit(X_train_cuda, y_train_cuda, n=1000)
loss[-5:]
```

```
## [0.036959268152713776, 0.036920323967933655, 0.0368814654648304, 0.03684266656637192, 0.036803923547267914]
```

```python
m_cuda.accuracy(X_test_cuda, y_test_cuda)
```

```
## tensor(0.9750, device='cuda:0')
```
]

---

```python
m_cuda = mnist_conv_model(device="cuda")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
loss = m_cuda.fit(X_train_cuda, y_train_cuda, n=1000)
end.record()

torch.cuda.synchronize()
print(start.elapsed_time(end))
```

```
## 2772.14794921875
```

```python
m = mnist_conv_model(device="cpu")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
loss = m.fit(X_train, y_train, n=1000)
end.record()

torch.cuda.synchronize()
print(start.elapsed_time(end))
```

```
## 8505.6484375
```

`elapsed_time()` returns milliseconds, so the 1000 full-batch iterations took about 2.8 s on the GPU versus about 8.5 s on the CPU. (In the CPU case the CUDA events only bracket host-side work; `time.perf_counter()` is the more conventional timer for CPU code.)

---

## CPU vs GPU Profiles

.small[
```python
m_cuda = mnist_conv_model(device="cuda")
with torch.autograd.profiler.profile(with_stack=True) as prof_cuda:
    tmp = m_cuda(X_train_cuda)
```

```python
print(prof_cuda.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##           aten::cudnn_convolution        72.08%     821.000us        76.21%     868.000us     868.000us             1
##                  cudaLaunchKernel         6.06%      69.000us         6.06%      69.000us       9.857us             7
##                       aten::addmm         3.60%      41.000us         5.27%      60.000us      60.000us             1
##                   aten::clamp_min         2.19%      25.000us         6.50%      74.000us      37.000us             2
##                        aten::add_         2.11%      24.000us         2.63%      30.000us      30.000us             1
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 1.139ms
```
]

--

.small[
```python
m = mnist_conv_model(device="cpu")
with torch.autograd.profiler.profile(with_stack=True, profile_memory=True) as prof_cpu:
    tmp = m(X_train)
```

```python
print(prof_cpu.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
##                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
##          aten::mkldnn_convolution        41.88%       3.086ms        42.05%       3.098ms       3.098ms       2.81 Mb           0 b             1
##     aten::max_pool2d_with_indices        41.86%       3.084ms        41.86%       3.084ms       3.084ms       2.10 Mb       2.10 Mb             1
##                   aten::clamp_min        13.15%     969.000us        26.25%       1.934ms     967.000us       5.61 Mb       2.81 Mb             2
##                       aten::addmm         1.68%     124.000us         1.95%     144.000us     144.000us      56.13 Kb      56.13 Kb             1
##                 aten::convolution         0.23%      17.000us        42.40%       3.124ms       3.124ms       2.81 Mb           0 b             1
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 7.368ms
```
]

---
class: middle

.center[
## CIFAR10
]

.footnote[[homepage](https://www.cs.toronto.edu/~kriz/cifar.html)]

---

## Loading the data

```python
import torchvision

training_data = torchvision.datasets.CIFAR10(
    root="/data",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor()
)
```

```
## Files already downloaded and verified
```

```python
test_data = torchvision.datasets.CIFAR10(
    root="/data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)
```

```
## Files already downloaded and verified
```

---

## CIFAR10 data

```python
training_data.classes
```

```
## ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
```

```python
training_data.data.shape
```

```
## (50000, 32, 32, 3)
```

```python
test_data.data.shape
```

```
## (10000, 32, 32, 3)
```

--

```python
training_data[0]
```

```
## (tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
##          [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
##          [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
##          ...,
##          [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
##          [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
##          [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],
##
##         [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
##          [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
##          [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
##          ...,
##          [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
##          [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
##          [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],
##
##         [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
##          [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
##          [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.1647],
##          ...,
##          [0.3765, 0.1333, 0.1020,  ..., 0.2745, 0.0275, 0.0784],
##          [0.3765, 0.1647, 0.1176,  ..., 0.3686, 0.1333, 0.1333],
##          [0.4549, 0.3686, 0.3412,  ..., 0.5490, 0.3294, 0.2824]]]), 6)
```

---

<img src="Lec24_files/figure-html/unnamed-chunk-25-1.png" width="85%" style="display: block; margin: auto;" />

---

## Data Loaders

```python
batch_size = 100

training_loader = torch.utils.data.DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
```

--

```python
training_loader
```

```
## <torch.utils.data.dataloader.DataLoader object at 0x7f586c2e8280>
```

--

```python
X, y = next(iter(training_loader))
X.shape
```

```
## torch.Size([100, 3, 32, 32])
```

```python
y.shape
```

```
## torch.Size([100])
```

---

.small[
```python
class cifar_conv_model(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 6, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(6, 16, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Flatten(),
            torch.nn.Linear(16 * 5 * 5, 120),
            torch.nn.ReLU(),
            torch.nn.Linear(120, 84),
            torch.nn.ReLU(),
            torch.nn.Linear(84, 10)
        ).to(device=self.device)

    def forward(self, X):
        return self.model(X)

    def fit(self, loader, epochs=10, n_report=250, lr=0.001):
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)

        for epoch in range(epochs):
            running_loss = 0.0
            for i, (X, y) in enumerate(loader):
                X, y = X.to(self.device), y.to(self.device)
                opt.zero_grad()
                loss = torch.nn.CrossEntropyLoss()(self(X), y)
                loss.backward()
                opt.step()

                # print statistics
                running_loss += loss.item()
                if i % n_report == (n_report-1):    # print every n_report mini-batches
                    print(f'[Epoch {epoch + 1}, Minibatch {i + 1:4d}] loss: {running_loss / n_report:.3f}')
                    running_loss = 0.0
```
]

.footnote[Based on [source](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)]

---

## Forward step performance

.small[
```python
m_cuda = cifar_conv_model(device="cuda")
X, y = next(iter(training_loader))

with torch.autograd.profiler.profile(with_stack=True) as prof_cuda:
    X, y = X.to(device="cuda"), y.to(device="cuda")
    tmp = m_cuda(X)
```

```python
print(prof_cuda.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                          aten::to        38.77%     725.000us        53.21%     995.000us     497.500us             2
##                  cudaLaunchKernel         8.56%     160.000us         8.56%     160.000us       8.889us            18
##           aten::cudnn_convolution         6.63%     124.000us        10.37%     194.000us      97.000us             2
##             cudaStreamSynchronize         6.26%     117.000us         6.26%     117.000us      58.500us             2
##                       aten::addmm         6.26%     117.000us         9.14%     171.000us      57.000us             3
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 1.870ms
```
]

--

.small[
```python
m_cpu = cifar_conv_model(device="cpu")
X, y = next(iter(training_loader))

with torch.autograd.profiler.profile(with_stack=True) as prof_cpu:
    tmp = m_cpu(X)
```

```python
print(prof_cpu.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##          aten::mkldnn_convolution        62.37%       6.180ms        62.56%       6.198ms       3.099ms             2
##     aten::max_pool2d_with_indices        27.46%       2.721ms        27.46%       2.721ms       1.361ms             2
##                   aten::clamp_min         4.25%     421.000us         8.36%     828.000us     103.500us             8
##                       aten::addmm         3.22%     319.000us         3.72%     369.000us     123.000us             3
##                 aten::convolution         0.50%      50.000us        63.24%       6.266ms       3.133ms             2
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 9.908ms
```
]

---

## Fit - 1 epoch

.small[
```python
m_cuda = cifar_conv_model(device="cuda")

with torch.autograd.profiler.profile(with_stack=True) as prof_cuda:
    m_cuda.fit(loader=training_loader, epochs=1, n_report=501)
```

```python
print(prof_cuda.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                                        cudaLaunchKernel        15.50%     619.786ms        15.50%     619.786ms      12.790us         48460
##                                 Optimizer.step#SGD.step        11.77%     470.570ms        22.73%     908.671ms       1.817ms           500
## enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         8.34%     333.381ms         8.43%     337.105ms     672.864us           501
##                                              aten::add_         7.75%     309.836ms        12.29%     491.312ms      30.745us         15980
##                       Optimizer.zero_grad#SGD.zero_grad         2.99%     119.458ms         7.02%     280.481ms     560.962us           500
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 3.998s
```
]

--

.small[
```python
m_cpu = cifar_conv_model(device="cpu")

with torch.autograd.profiler.profile(with_stack=True) as prof_cpu:
    m_cpu.fit(loader=training_loader, epochs=1, n_report=501)
```

```python
print(prof_cpu.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                                aten::mkldnn_convolution        31.96%        2.457s        32.11%        2.468s       2.468ms          1000
##                              aten::convolution_backward        29.23%        2.247s        29.31%        2.253s       2.253ms          1000
##                           aten::max_pool2d_with_indices        14.50%        1.114s        14.50%        1.114s       1.114ms          1000
##                                aten::threshold_backward         5.42%     416.788ms         5.42%     416.788ms     208.394us          2000
## enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         3.07%     235.781ms         3.09%     237.270ms     473.593us           501
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 7.688s
```
]

---

## Fit - 2 epochs

.small[
```python
m_cuda = cifar_conv_model(device="cuda")

with torch.autograd.profiler.profile(with_stack=True) as prof_cuda:
    m_cuda.fit(loader=training_loader, epochs=2, n_report=501)
```

```python
print(prof_cuda.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                                        cudaLaunchKernel        14.73%        1.221s        14.73%        1.221s      12.593us         96960
##                                 Optimizer.step#SGD.step        11.35%     940.821ms        21.78%        1.805s       1.805ms          1000
## enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         9.87%     818.526ms        10.02%     830.834ms     829.176us          1002
##                                              aten::add_         7.30%     604.797ms        11.52%     955.289ms      29.871us         31980
##                       Optimizer.zero_grad#SGD.zero_grad         3.02%     250.166ms         7.08%     586.920ms     586.920us          1000
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 8.289s
```
]

--

.small[
```python
m_cpu = cifar_conv_model(device="cpu")

with torch.autograd.profiler.profile(with_stack=True) as prof_cpu:
    m_cpu.fit(loader=training_loader, epochs=2, n_report=501)
```

```python
print(prof_cpu.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                                aten::mkldnn_convolution        29.41%        4.998s        29.61%        5.032s       2.516ms          2000
##                              aten::convolution_backward        28.05%        4.766s        28.19%        4.790s       2.395ms          2000
##                           aten::max_pool2d_with_indices        15.23%        2.588s        15.23%        2.588s       1.294ms          2000
##                                aten::threshold_backward         6.09%        1.034s         6.09%        1.034s     258.579us          4000
## enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         4.00%     679.409ms         4.01%     682.062ms     680.701us          1002
## -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 16.993s
```
]

---

## Loaders & Accuracy

```python
def accuracy(model, loader, device):
    total, correct = 0, 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device=device), y.to(device=device)
            pred = model(X)
            # the class with the highest energy is what we choose as prediction
            val, idx = torch.max(pred, 1)
            total += pred.size(0)
            correct += (idx == y).sum().item()

    return correct / total
```

---

## Model fitting

```python
m = cifar_conv_model("cuda")
m.fit(training_loader, epochs=10, n_report=500, lr=0.01)

## [Epoch 1, Minibatch  500] loss: 2.098
## [Epoch 2, Minibatch  500] loss: 1.692
## [Epoch 3, Minibatch  500] loss: 1.482
## [Epoch 4, Minibatch  500] loss: 1.374
## [Epoch 5, Minibatch  500] loss: 1.292
## [Epoch 6, Minibatch  500] loss: 1.226
## [Epoch 7, Minibatch  500] loss: 1.173
## [Epoch 8, Minibatch  500] loss: 1.117
## [Epoch 9, Minibatch  500] loss: 1.071
## [Epoch 10, Minibatch  500] loss: 1.035
```

```python
accuracy(m, training_loader, "cuda")
## 0.63444
accuracy(m, test_loader, "cuda")
## 0.572
```

---

## More epochs

Calling `fit()` again on the existing model continues training from its current parameters (note that this call uses the default `lr=0.001`, not the `lr=0.01` used above):

```python
m.fit(training_loader, epochs=10, n_report=500)

## [Epoch 1, Minibatch  500] loss: 0.885
## [Epoch 2, Minibatch  500] loss: 0.853
## [Epoch 3, Minibatch  500] loss: 0.839
## [Epoch 4, Minibatch  500] loss: 0.828
## [Epoch 5, Minibatch  500] loss: 0.817
## [Epoch 6, Minibatch  500] loss: 0.806
## [Epoch 7, Minibatch  500] loss: 0.798
## [Epoch 8, Minibatch  500] loss: 0.787
## [Epoch 9, Minibatch  500] loss: 0.780
## [Epoch 10, Minibatch  500] loss: 0.773
```

```python
accuracy(m, training_loader, "cuda")
## 0.73914
accuracy(m, test_loader, "cuda")
## 0.624
```

---

## More epochs (again)

```python
m.fit(training_loader, epochs=10, n_report=500)

## [Epoch 1, Minibatch  500] loss: 0.764
## [Epoch 2, Minibatch  500] loss: 0.756
## [Epoch 3, Minibatch  500] loss: 0.748
## [Epoch 4, Minibatch  500] loss: 0.739
## [Epoch 5, Minibatch  500] loss: 0.733
## [Epoch 6, Minibatch  500] loss: 0.726
## [Epoch 7, Minibatch  500] loss: 0.718
## [Epoch 8, Minibatch  500] loss: 0.710
## [Epoch 9, Minibatch  500] loss: 0.702
## [Epoch 10, Minibatch  500] loss: 0.698
```

```python
accuracy(m, training_loader, "cuda")
## 0.76438
accuracy(m, test_loader, "cuda")
## 0.6217
```

---

## The VGG16 model

```python
class VGG16(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)
        self.model = self.make_layers()

    def forward(self, X):
        return self.model(X)

    def make_layers(self):
        cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [torch.nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           torch.nn.BatchNorm2d(x),
                           torch.nn.ReLU(inplace=True)]
                in_channels = x
        layers += [
            torch.nn.AvgPool2d(kernel_size=1, stride=1),
            torch.nn.Flatten(),
            torch.nn.Linear(512,10)
        ]

        return torch.nn.Sequential(*layers).to(self.device)
```

.footnote[Based on code from [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar), original [paper](https://arxiv.org/abs/1409.1556)]

---

.small[
```python
VGG16("cuda").model
```

```
## Sequential(
##   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (2): ReLU(inplace=True)
##   (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (5): ReLU(inplace=True)
##   (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
##   (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (9): ReLU(inplace=True)
##   (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (12): ReLU(inplace=True)
##   (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
##   (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (16): ReLU(inplace=True)
##   (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (19): ReLU(inplace=True)
##   (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (22): ReLU(inplace=True)
##   (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
##   (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (26): ReLU(inplace=True)
##   (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (29): ReLU(inplace=True)
##   (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (32): ReLU(inplace=True)
##   (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
##   (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (36): ReLU(inplace=True)
##   (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (39): ReLU(inplace=True)
##   (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
##   (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
##   (42): ReLU(inplace=True)
##   (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
##   (44): AvgPool2d(kernel_size=1, stride=1, padding=0)
##   (45): Flatten(start_dim=1, end_dim=-1)
##   (46): Linear(in_features=512, out_features=10, bias=True)
## )
```
]

---

## Minibatch performance

.small[
```python
m_cuda = VGG16(device="cuda")
X, y = next(iter(training_loader))

with torch.autograd.profiler.profile(with_stack=True) as prof_cuda:
    X, y = X.to(device="cuda"), y.to(device="cuda")
    tmp = m_cuda(X)
```

```python
print(prof_cuda.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                          aten::to        69.03%      51.397ms        69.38%      51.653ms      25.826ms             2
##            aten::cudnn_batch_norm        12.10%       9.007ms        15.14%      11.273ms     867.154us            13
##                        cudaMalloc        11.64%       8.664ms        11.64%       8.664ms     866.400us            10
##     aten::max_pool2d_with_indices         1.41%       1.050ms         2.93%       2.182ms     436.400us             5
##                aten::_convolution         1.20%     897.000us        11.12%       8.282ms     637.077us            13
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 74.451ms
```
]

--

.small[
```python
m_cpu = VGG16(device="cpu")
X, y = next(iter(training_loader))

with torch.autograd.profiler.profile(with_stack=True) as prof_cpu:
    tmp = m_cpu(X)
```

```python
print(prof_cpu.key_averages().table(sort_by='self_cpu_time_total', row_limit=5))
```

```
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
##          aten::mkldnn_convolution        85.80%        1.033s        85.92%        1.035s      79.587ms            13
##           aten::native_batch_norm         7.13%      85.896ms         8.33%     100.295ms       7.715ms            13
##     aten::max_pool2d_with_indices         4.91%      59.109ms         4.91%      59.109ms      11.822ms             5
##                         aten::sum         1.05%      12.592ms         1.05%      12.645ms     972.692us            13
##                   aten::clamp_min         0.70%       8.387ms         0.70%       8.387ms     645.154us            13
## ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
## Self CPU time total: 1.204s
```
]

---

## Fitting

```python
def fit(model, loader, epochs=10, n_report=250, lr = 0.01):
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (X, y) in enumerate(loader):
            X, y = X.to(model.device), y.to(model.device)
            opt.zero_grad()
            loss = torch.nn.CrossEntropyLoss()(model(X), y)
            loss.backward()
            opt.step()

            running_loss += loss.item()
            if i % n_report == (n_report-1):
                print(f'[Epoch {epoch + 1}, Minibatch {i + 1:4d}] loss: {running_loss / n_report:.3f}')
                running_loss = 0.0
```

---

## `lr = 0.01`

```python
m = VGG16(device="cuda")
fit(m, training_loader, epochs=10, n_report=500, lr=0.01)

## [Epoch 1, Minibatch  500] loss: 1.345
## [Epoch 2, Minibatch  500] loss: 0.790
## [Epoch 3, Minibatch  500] loss: 0.577
## [Epoch 4, Minibatch  500] loss: 0.445
## [Epoch 5, Minibatch  500] loss: 0.350
## [Epoch 6, Minibatch  500] loss: 0.274
## [Epoch 7, Minibatch  500] loss: 0.215
## [Epoch 8, Minibatch  500] loss: 0.167
## [Epoch 9, Minibatch  500] loss: 0.127
## [Epoch 10, Minibatch  500] loss: 0.103
```

--

```python
accuracy(model=m, loader=training_loader, device="cuda")
## 0.97008
accuracy(model=m, loader=test_loader, device="cuda")
## 0.8318
```

---

## `lr = 0.001`

```python
m = VGG16(device="cuda")
fit(m, training_loader, epochs=10, n_report=500, lr=0.001)

## [Epoch 1, Minibatch  500] loss: 1.279
## [Epoch 2, Minibatch  500] loss: 0.827
## [Epoch 3, Minibatch  500] loss: 0.599
## [Epoch 4, Minibatch  500] loss: 0.428
## [Epoch 5, Minibatch  500] loss: 0.303
## [Epoch 6, Minibatch  500] loss: 0.210
## [Epoch 7, Minibatch  500] loss: 0.144
## [Epoch 8, Minibatch  500] loss: 0.108
## [Epoch 9, Minibatch  500] loss: 0.088
## [Epoch 10, Minibatch  500] loss: 0.063
```

--

```python
accuracy(model=m, loader=training_loader, device="cuda")
## 0.9815
accuracy(model=m, loader=test_loader, device="cuda")
## 0.7816
```

---

## Report

```python
from sklearn.metrics import classification_report

def report(model, loader, device):
    y_true, y_pred = [], []
    with torch.no_grad():
        for X, y in loader:
            X = X.to(device=device)
            y_true.append( y.cpu().numpy() )
            y_pred.append( model(X).max(1)[1].cpu().numpy() )

    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    return classification_report(y_true, y_pred, target_names=loader.dataset.classes)
```

---

```python
print(report(model=m, loader=test_loader, device="cuda"))

##               precision    recall  f1-score   support
##
##     airplane       0.82      0.88      0.85      1000
##   automobile       0.95      0.89      0.92      1000
##         bird       0.85      0.70      0.77      1000
##          cat       0.68      0.74      0.71      1000
##         deer       0.84      0.83      0.83      1000
##          dog       0.81      0.73      0.77      1000
##         frog       0.83      0.92      0.87      1000
##        horse       0.87      0.87      0.87      1000
##         ship       0.89      0.92      0.90      1000
##        truck       0.86      0.93      0.89      1000
##
##     accuracy                           0.84     10000
##    macro avg       0.84      0.84      0.84     10000
## weighted avg       0.84      0.84      0.84     10000
```
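
---

## Evaluation mode

The `accuracy()` and `report()` helpers call the model without first calling `model.eval()`. VGG16 contains `BatchNorm2d` layers, which behave differently in training and evaluation mode: in training mode they normalize with per-minibatch statistics and keep updating their running statistics (`torch.no_grad()` does not change this), so these numbers may differ somewhat from eval-mode results. Below is a minimal sketch of an eval-mode variant; `accuracy_eval` is a hypothetical name, and the toy model and random data exist only to show the mechanics on a CPU.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def accuracy_eval(model, loader, device):
    """Hypothetical variant of accuracy() that evaluates in eval mode."""
    was_training = model.training
    model.eval()                      # BatchNorm uses running stats, Dropout is disabled
    total, correct = 0, 0
    with torch.no_grad():             # no autograd graph is needed for evaluation
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            pred = model(X).argmax(dim=1)
            total += y.size(0)
            correct += (pred == y).sum().item()
    model.train(was_training)         # restore the previous mode for further fitting
    return correct / total

# Toy usage on random data (CPU only), just to show the mechanics
torch.manual_seed(0)
toy = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.BatchNorm1d(8), torch.nn.Linear(8, 3)
)
data = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,)))
loader = DataLoader(data, batch_size=8)

acc = accuracy_eval(toy, loader, "cpu")
print(acc, toy.training)
```

Restoring the previous mode with `model.train(was_training)` means the helper can be called between `fit()` calls without changing how subsequent training behaves, and the BatchNorm running statistics are left untouched by evaluation.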