Thursday, September 9, 2021

Training k-NN with weighted Hamming distance

In this post I'll describe a particular ML classification problem for which parameter optimization with a naive autograd implementation is inadequate because of memory and performance issues. I'll propose a workaround that writes the forward and backward passes directly in CUDA, and then fine-tune its performance.

The full code used in this post can be found at https://github.com/sergioloff/kNN_WHam

Sample problem

For illustration purposes, we'll use the MNIST database of handwritten digits as a toy dataset, restricted to the classes "0" and "not 0". Further, since we'll be dealing exclusively with the Hamming distance, the dataset's pixels will be converted to bits, where a bit is 1 if the pixel's greyscale value exceeds a given threshold.
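The binarization step can be sketched in a few lines of NumPy. This is only an illustration: the function names, the threshold of 127 and the convention that label 1 means "digit 0" are assumptions, not values taken from the accompanying code.

```python
import numpy as np

def binarize(images, threshold=127):
    """Map greyscale pixels to bits: 1 where the value exceeds `threshold`.

    The default threshold is an assumed value, chosen for illustration.
    """
    return (np.asarray(images) > threshold).astype(np.uint8)

def zero_vs_rest(labels):
    """Collapse the ten digit labels to two classes (assumed: 1 = "0", 0 = "not 0")."""
    return (np.asarray(labels) == 0).astype(np.uint8)

# Tiny stand-in for flattened MNIST images (2 samples, 3 pixels each).
images = np.array([[0, 200, 128], [127, 255, 10]], dtype=np.uint8)
labels = np.array([0, 7])

X = binarize(images)      # bits, shape (2, 3)
Y = zero_vs_rest(labels)  # binary labels, shape (2,)
```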



Our objective is to create a classifier based on such training data. It should be able to accurately predict whether a new, unknown digit is a zero or not.

We'll build our classifier using a continuous approximation to the k-nearest neighbours algorithm. Our model parameters will be a weight mask W = {W_ij} which, for each pixel location (i,j), indicates the pixel's importance in the calculation of a sample's class. Pixels that are always 0 or always 1, such as those near the border of the image, carry no information about the digit's classification, so we expect their W_ij to be close to zero. On the other hand, some pixels will be vital for a proper classification, such as those near the centre of the image, and their corresponding W_ij should have a large absolute value. When calculating the distance between the binary images in A ∈ {0,1}^(M×D) and B ∈ {0,1}^(N×D), we'll use a weighted Hamming distance which incorporates W: HDist(A,B,W) = ||(A - B)*W||_L1, where * denotes the element-wise product. We opted for binarization plus Hamming distance over another metric such as the Euclidean distance, since it allows for significant GPU optimizations regarding memory footprint and integer arithmetic.
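As a reference point, here is a minimal NumPy sketch of the weighted Hamming distance between every row of A and every row of B. The function name is hypothetical, and the sketch assumes the L1 norm is taken over the D features of each pair of rows.

```python
import numpy as np

def weighted_hamming(A, B, W):
    """Weighted Hamming distance between each row of A (M,D) and each row of B (N,D).

    Returns an (M,N) matrix. Assumes the L1 norm runs over the feature axis.
    """
    A = np.asarray(A, dtype=np.float32)
    B = np.asarray(B, dtype=np.float32)
    W = np.asarray(W, dtype=np.float32)
    diff = A[:, None, :] - B[None, :, :]  # (M,N,D), entries in {-1, 0, 1}
    return np.abs(diff * W).sum(-1)       # (M,N)

A = [[1, 0, 1], [0, 0, 0]]
B = [[1, 1, 1]]
W = [0.5, -2.0, 0.25]
dist = weighted_hamming(A, B, W)
```

Only the features where the two images differ contribute, and each contributes the absolute value of its weight, so the sign of W does not matter for the distance itself.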

Also, we'll use the whole training set when computing the classification (instead of only the k nearest training samples, as in the traditional algorithm). Each element's contribution to the final classification will be weighted by its distance, so that far-away images contribute little to the final classification score. This modified algorithm has the advantage of having a fully differentiable loss function.

Notation and definitions

Throughout this series, we'll represent our training data samples and labels by X = {X_ij} and Y = {Y_i}, respectively. Both X and Y are binary, i.e., X ∈ {0,1}^(M×D) and Y ∈ {0,1}^M. The classification problem consists in using [X,Y] as a nearest-neighbour classifier matrix, weighted by the trainable vector W ∈ R^D.

Approximate, differentiable weighted k-nearest neighbours

In this algorithm, we'll classify a new datum U_p ∈ {0,1}^D by averaging the individual classification scores over the whole training set [X,Y] = [{X_i}, {Y_i}].

WHD_ip = |(X_i - U_p) * W|

Q_ip = F(1 - WHD_ip / ||W||)

unbalanced_score_p = AVG(i, ||Q_ip * Y_i||)

* is the element-wise product, the modulus is applied element-wise, and the norm is L1.

The function F should be a parametrizable monotonically increasing, differentiable bijection on the unit interval. It will modify the distance score by giving more or less emphasis to large values. In this implementation we shall use a sigmoid adjustable by the parameters j and k:

I also chose to linearly adjust the final score by trainable parameters alpha and beta, so that it balances the distribution before applying a cross entropy loss.

score_p = beta * (unbalanced_score_p + alpha)
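The scoring formulas can be sketched in NumPy as follows. Two things here are assumptions: the weighted distance is summed over the features (matching the sum reduction used later in the post), and F is a plain logistic curve used purely as a stand-in, since the exact sigmoid parametrized by j and k lives in the accompanying code. The stand-in is not a bijection on the unit interval.

```python
import numpy as np

def F(x, j=0.5, k=10.0):
    """Hypothetical stand-in for the post's sigmoid: logistic centred at j with slope k."""
    return 1.0 / (1.0 + np.exp(-k * (x - j)))

def scores(X, Y, U, W, alpha=0.0, beta=1.0):
    """Score each row of U (P,D) against the whole training set X (M,D), Y (M,)."""
    X = np.asarray(X, dtype=np.float64)
    U = np.asarray(U, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    W = np.asarray(W, dtype=np.float64)
    whd = np.abs((X[:, None, :] - U[None, :, :]) * W).sum(-1)  # (M,P), assumed sum over features
    q = F(1.0 - whd / np.abs(W).sum())                         # (M,P), closeness in [0,1] fed to F
    unbalanced = (q * Y[:, None]).mean(axis=0)                 # (P,), average over training rows
    return beta * (unbalanced + alpha)

X = [[1, 1, 0, 0], [0, 0, 1, 1]]
Y = [1, 0]
U = [[1, 1, 0, 0], [0, 0, 1, 1]]  # first query matches the positive sample
W = np.ones(4)
s = scores(X, Y, U, W)
```

The query that coincides with the positive training sample receives the higher score, which is the behaviour the averaging is meant to produce.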

Each epoch we'll shuffle the available train data set [X,Y] and split it into batches Bi, where each batch (the holdout set, or X2i) will be compared against the remaining train data, X-Bi (the batch train set, or X1i).

Note: the holdout set should not be confused with what the literature calls the validation/test sets. In our case, the holdout set is a subset of the actual train data.
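One possible way to produce these splits is sketched here. The generator name and its arguments are hypothetical, and it works on indexes only, which fits the idea of keeping the data itself on the GPU.

```python
import numpy as np

def holdout_batches(num_samples, batch_size, rng):
    """Yield (train_ix, holdout_ix) pairs covering one shuffled epoch.

    Each holdout batch is compared against all remaining training samples.
    """
    perm = rng.permutation(num_samples)
    for start in range(0, num_samples, batch_size):
        holdout_ix = perm[start:start + batch_size]
        train_ix = np.concatenate([perm[:start], perm[start + batch_size:]])
        yield train_ix, holdout_ix

rng = np.random.default_rng(0)
batches = list(holdout_batches(10, 4, rng))
```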

In order to optimize the parameters W, alpha, beta, k, and j, we'll minimize the binary cross-entropy of the score's logits:

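loss = nn.BCEWithLogitsLoss()(score, target)

where score = {score_p} is the vector with the scores over the whole batch holdout set, and target (a name used here for illustration) holds the corresponding holdout labels as floats.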
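For readers who want to see what this loss computes, here is a NumPy sketch of binary cross-entropy on logits, written in the usual numerically stable form. It is an illustration of the formula, not the code used for training.

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw scores (logits), in a numerically stable form."""
    z = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    # max(z, 0) - z*y + log(1 + exp(-|z|)) equals
    # -[y*log(s) + (1 - y)*log(1 - s)] with s = sigmoid(z), without overflow for large |z|.
    per_sample = np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))
    return float(per_sample.mean())

loss = bce_with_logits([2.0, -1.0, 0.0], [1.0, 0.0, 1.0])
```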

On this toy problem, we'll converge after only a few epochs and obtain a classifier with accuracy > 90% on the holdout test dataset.

The performance of this algorithm isn't particularly good when compared to other classifiers on MNIST, but its real interest is as a feature importance / dimensionality reduction pre-processing stage, in the same role as PCA or SVD.

A cursory glance at W tells us that ignoring all the pixels with small W_ij will result in a dimensionality reduction of more than 90% for this particular classification problem.

Autograd and memory size

The difficulty when implementing this algorithm on a GPU, which is also a general problem when using kernel products, is that it's not readily extensible to a large number of features due to memory limitations. The following sum reduction

np.abs((np.expand_dims(X1, 1) - X2) * W).sum(-1)

is applied to a tensor of shape (M,N,D), where M = len(batch train set), N = len(holdout set) and D = total number of features. If our batch has N = 64 and M = 1,000,000, then in float32 this tensor alone already exceeds 7 GiB of GPU memory at D = 32.
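The arithmetic behind that figure is easy to check. The helper below is hypothetical and assumes 4 bytes per element (float32).

```python
def pairwise_tensor_bytes(M, N, D, bytes_per_element=4):
    """Size in bytes of the dense (M,N,D) intermediate tensor (float32 assumed by default)."""
    return M * N * D * bytes_per_element

size = pairwise_tensor_bytes(1_000_000, 64, 32)
size_gib = size / 2**30  # roughly 7.6 GiB
```

Since the size grows linearly with D, even a modest number of features pushes the dense tensor past what a single GPU can hold.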

One option would be to use the KeOps library, which offers implicit auto differentiation for such tensor products through the LazyTensor class:

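from pykeops.torch import LazyTensor

X1_l = LazyTensor(X1[:,None,:])          # symbolic (M,1,D)
X2_l = LazyTensor(X2[None,:,:])          # symbolic (1,N,D)
W_l = LazyTensor(self.W.view([1,1,-1]))  # symbolic (1,1,D)
WHDx_l = ((X1_l - X2_l) * W_l).abs().sum(-1)  # reduce over the feature axis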

Unfortunately KeOps fails with compilation timeouts once the cardinality of the reduction dimension becomes large, usually when D > 128. The library keeps evolving, so I won't be surprised if they've removed this limitation by now.

Alternatively, we can bypass the autodiff engine when computing the WHDx term by writing CUDA code for its forward and backward passes.

The tensor derivative of WHDx wrt W can be annoying to compute due to the scattering of the feature indexes. Here's the pseudocode for the formula:

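# X1: (M,D)
# X2: (N,D)
X1_r = X1.repeat_interleave(repeats=N, dim=0) # X1_r:(M*N,D), row m*N+n holds X1[m]
X2_ri = X2.repeat(M, 1) # X2_ri:(N*M,D), row m*N+n holds X2[n]
diff = X1_r - X2_ri # (M*N,D)
# d|diff*W|/dW = diff * sign(diff*W), weighted by the upstream gradient of each (m,n) pair
grad_W = (grad_output.view(M*N) * (diff * (diff * W).sign()).T).sum(-1) # grad_W:(D,)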

where grad_output is the derivative of the loss wrt WHDx and * is the element-wise product.
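To gain some confidence in the formula, it can be compared against finite differences on a small problem. The NumPy sketch below uses hypothetical function names and random test data, and keeps W away from zero, where the absolute value is not differentiable.

```python
import numpy as np

def whd_forward(X1, X2, W):
    """Weighted Hamming distances, shape (M,N)."""
    diff = X1[:, None, :] - X2[None, :, :]
    return np.abs(diff * W).sum(-1)

def whd_grad_W(X1, X2, W, grad_output):
    """Gradient of sum(grad_output * WHDx) with respect to W, shape (D,)."""
    diff = X1[:, None, :] - X2[None, :, :]  # (M,N,D)
    return (grad_output[:, :, None] * diff * np.sign(diff * W)).sum(axis=(0, 1))

rng = np.random.default_rng(0)
M, N, D = 5, 3, 8
X1 = rng.integers(0, 2, size=(M, D)).astype(np.float64)
X2 = rng.integers(0, 2, size=(N, D)).astype(np.float64)
W = rng.choice([-1.0, 1.0], size=D) * rng.uniform(0.5, 1.5, size=D)
grad_output = rng.normal(size=(M, N))

analytic = whd_grad_W(X1, X2, W, grad_output)

# Central finite differences on the scalar sum(grad_output * WHDx).
eps = 1e-4
numeric = np.zeros(D)
for d in range(D):
    step = np.zeros(D)
    step[d] = eps
    plus = (whd_forward(X1, X2, W + step) * grad_output).sum()
    minus = (whd_forward(X1, X2, W - step) * grad_output).sum()
    numeric[d] = (plus - minus) / (2 * eps)
```

This dense version still materializes the (M,N,D) tensor, so it is only suitable as a small-scale correctness check for a hand-written backward pass.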

CUDA implementation

This is a fairly small project, so we'll simply inline our CUDA functions in our Python code using CuPy. We'll assume that we're optimizing for a single GPU. Multi-GPU setups require additional complexity and are out of the scope of this post.

We'll use the suffix "_T" to denote a pre-computed transpose of a tensor.

Since the data needed for both the forward and backward passes is fairly small when binarized, it will fit wholly in GPU memory. So at each batch iteration we only need to move to the GPU the indexes of the elements used for the k-NN set and for the holdout set. If the data were too large to fit on the GPU at once, even when binarized, a fair compromise would be to use a random subset of the available k-NN training samples, different for every batch, and to copy the data needed for the next batch to the GPU while the computation of the current batch is still running, thereby overlapping a GPU copy with GPU execution.

Forward pass

Taking advantage of having binary data, we can pack each pixel into 1 bit. X1 and X2 then become arrays of uint32 where each element is a packed list of 32 features.
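A possible host-side packing routine is sketched here in NumPy. The function name is hypothetical. The bit order (bit b of word c holds feature c*32 + b) is chosen to match the kernel's inner loop, which reads the lowest bit first, and padding D up to a multiple of 32 with zero bits is an assumption, needed for instance for MNIST's 784 pixels. The transposed layout used by the kernel is not covered.

```python
import numpy as np

def pack_bits_u32(bits):
    """Pack an (n, D) array of 0/1 values into (n, ceil(D/32)) uint32 words.

    Bit b of word c holds feature c*32 + b (least significant bit first).
    """
    bits = np.asarray(bits, dtype=np.uint32)
    n, d = bits.shape
    pad = (-d) % 32                          # zero-pad D up to a multiple of 32 (assumption)
    bits = np.pad(bits, ((0, 0), (0, pad)))
    shifts = np.arange(32, dtype=np.uint32)
    words = bits.reshape(n, -1, 32) << shifts  # move each bit to its position in the word
    return words.sum(axis=-1, dtype=np.uint32)

bits = np.zeros((1, 64), dtype=np.uint8)
bits[0, 0] = 1   # feature 0  -> word 0, bit 0
bits[0, 31] = 1  # feature 31 -> word 0, bit 31
bits[0, 33] = 1  # feature 33 -> word 1, bit 1
packed = pack_bits_u32(bits)
```

Padded bits are zero in every sample, so they never differ between two images and add nothing to the distance.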

This problem is essentially a scan of a space of N*M*D elements, so the bottleneck will be memory access. Once we read a memory block, we must strive to reuse it as much as possible by storing it in registers or shared memory. So the question is: in which order should we scan our problem space? Of all the combinations, M -> D -> N appears to give the best results.

Mapping each global thread to exactly 1 row vector in X1 offers the best compromise in memory reuse, since each thread will access and buffer the same elements of the W vector at the same time.

totThreads = 128 # const

totBlocks = M // totThreads  

 

Each thread will scan each 32-bit feature block in tandem, and we'll cache its corresponding sub-vector of W (containing the 32 W items needed for this feature block) into shared memory as a coalesced memory access.

for (int dimCol = 0; dimCol < totFeatures/32; dimCol++)
{

       sh_W[threadIdx.x] = abs(W[dimCol * 32 + threadIdx.x]);

Done this way, reading W is not a memory access bottleneck.

At this point, each thread can read its corresponding pixels for the current feature block, packed into 32 bits. In this manner there shall be no redundant reads of feature pixels.

unsigned int x1 = uX1_T[uX1_ix]; uX1_ix += M;

Each thread will now synchronously (at least at the warp level) iterate over the holdout vector's bits, which will result in coalesced, non-redundant reads.

for (int n = 0; n < N; n++)
{
      
unsigned int x2 = uX2_T[uX2_ix++];

At this point each thread will hold the elements needed to compute a weighted Hamming distance for that attribute block: x1, x2 and sh_W:

float sum = 0;
unsigned int XORx = x1 ^ x2;
for (int bix = 0; bix < 32; bix++, XORx = XORx >> 1)
      sum += (XORx & 0x1U) == 0 ? 0 : sh_W[bix];
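The same inner loop can be mirrored in plain Python, which is handy for checking a kernel's output on a few hand-picked words. The function name is hypothetical. Two bits that differ always give a difference of +1 or -1, so each differing feature contributes the absolute value of its weight, which is why the kernel caches abs(W).

```python
def weighted_hamming_block(x1, x2, w_abs):
    """Weighted Hamming distance for one block of 32 packed features.

    x1, x2: integers holding 32 packed bits; w_abs: the 32 absolute weights of the block.
    """
    xor = (x1 ^ x2) & 0xFFFFFFFF  # a set bit marks a feature where the two samples differ
    total = 0.0
    for bix in range(32):
        if xor & 1:
            total += w_abs[bix]
        xor >>= 1
    return total

w_abs = [float(i) for i in range(32)]
d = weighted_hamming_block(0b1011, 0b0001, w_abs)  # bits 1 and 3 differ
```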

 One final improvement is to take advantage of our small register footprint and unroll our loops by scanning 4 attribute blocks at a time. It's faster but will force us to lay out the attribute buffers a bit differently (please refer to the accompanying code for further details on this implementation).

Backward pass

Description coming soon (already implemented in the accompanying code).