An annotated guide to the Kolmogorov-Arnold Network

This post is analogous to and heavily inspired by the Annotated Transformer, but for KANs. It is fully functional as a standalone notebook, and provides intuition along with the code. Most of the code was written to be easy to follow and to mimic the structure of a standard deep learning model in PyTorch, but some parts like training loops and visualization code were adapted from the original codebase. We decided to remove some sections of the original paper that were deemed unimportant, and we also include some extra work to motivate future research on these models.

The original paper is titled “KAN: Kolmogorov-Arnold Networks” and was written by **Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljačić, Thomas Y. Hou, and Max Tegmark.**

Deep neural networks have been the driving force of developments in AI in the last decade. However, they currently suffer from several known issues such as a lack of interpretability, scaling issues, and data inefficiency – in other words, while they are powerful, they are not a perfect solution.

Kolmogorov-Arnold Networks (KANs) are an alternative representation to standard multi-layer perceptrons (MLPs). In short, they parameterize activation functions by re-wiring the “multiplication” in an MLP’s weight matrix-vector multiplication into function application. While KANs do not yet enjoy the same theoretical guarantees or empirical track record as MLPs, they are an exciting prospect for the field of AI and deserve some time for exploration.

I have separated this article into two halves. Parts I & II describe a minimal KAN architecture and training loop without an emphasis on B-spline optimizations. You can use the minimal KAN notebook if you’re interested in KANs at a high level. Parts III & IV describe B-spline specific optimizations and an application of KANs, which includes a bit of extra machinery in the KAN code. You can use the full KAN notebook if you want to follow along there.

Before jumping into the implementation details, it is important to take a step back and understand why one should even care about these models. It is quite well known that Multi-layer Perceptrons (MLPs) enjoy the “Universal Approximation Theorem”, which guarantees the **existence** of an MLP that can approximate any continuous function on a compact domain to arbitrary accuracy.

KANs admit a similar guarantee through the Kolmogorov-Arnold representation theorem, though with a caveat: the guarantee applies to *continuous, smooth* functions. The theorem states that any such function \(f(x_1,...,x_n)\) can be written as

\[ f(x_1,...,x_n) = \sum_{q=1}^{2n+1} \Phi_q \left( \sum_{p=1}^{n} \Phi_{q,p}(x_p) \right), \]

where \(\Phi_{q,p}, \Phi_{q}\) are univariate functions from \(\mathbb{R}\) to \(\mathbb{R}\). In theory, we can parameterize and learn these (potentially non-smooth and highly irregular) univariate functions \(\Phi_{q,p}, \Phi_{q}\) by optimizing a loss function similar to any other deep learning model. But it’s not that obvious how one would “parameterize” a function the same way you would parameterize a weight matrix. For now, just assume that it is possible to parameterize these functions – the original authors choose to use a B-spline, but there is little reason to be stuck on this choice.

The expression from the theorem above does not describe a KAN with $L$ layers. This was an initial point of confusion for me. The universal approximation guarantee is only for models specifically in the form of the Kolmogorov-Arnold representation, but currently we have no notion of a “layer” or anything scalable. In fact, the number of parameters in the above theorem is a function of the number of covariates and not the choice of the engineer! Instead, the authors define a KAN layer \(\mathcal{K}_{m,n}\) with input dimension \(n\) and output dimension \(m\) as a parameterized matrix of univariate functions, \(\Phi = \{\Phi_{i,j}\}_{i \in [m], j \in [n]}\).

It may seem like the authors pulled this expression out of nowhere, but it is easy to see that the KAN representation theorem can be re-written as follows. For a set of covariates \(\boldsymbol{x} = (x_1,x_2,...,x_n)\), we can write any *continuous, smooth* function \(f(x_1,...,x_n) : \mathcal{D} \rightarrow \mathbb{R}\) over a bounded domain \(\mathcal{D}\) in the form

\[ f(\boldsymbol{x}) = \mathcal{K}_{1,2n+1} \left( \mathcal{K}_{2n+1,n}(\boldsymbol{x}) \right). \]

The KAN architecture is therefore written as a composition of these stacked KAN layers, similar to how you would compose an MLP. I want to emphasize that unless the KAN is written in the form above, there is currently no *proven* universal approximation guarantee.

When first hearing about KANs, I was under the impression that the Kolmogorov-Arnold Representation Theorem was an analogous guarantee for KANs, but this is seemingly *not true*. Recall from the Kolmogorov-Arnold representation theorem that our guarantee is only for specific 2-layer KAN models. Instead, the authors prove that there exists a KAN using B-splines as the univariate functions \(\{\Phi_{i,j}\}_{i \in [m], j \in [n]}\) that can approximate a composition of continuously-differentiable functions within some *nice* error margin

*tldr; no, we have not shown that a generic KAN model serves as the same type of universal approximator as an MLP (yet).*

We talked quite extensively about “learnable activation functions”, but this notion might be unclear to some readers. In order to parameterize a function, we have to define some kind of “base” function that uses coefficients. When learning the function, we are actually learning the coefficients. The original Kolmogorov-Arnold representation theorem places no conditions on the family of learnable univariate activation functions. Ideally, we would want some kind of parameterized family of functions that can approximate any function *on a bounded domain*, whether it be non-smooth, fractal, or something with some other kind of nasty property.

**Enter the B-spline**. B-splines are a generalization of spline functions, which themselves are piecewise polynomials. Polynomials of degree/order \(k\) are written as \(p(x) = a_0 + a_1x + a_2x^2 + ... + a_kx^k\) and can be parameterized according to their coefficients \(a_0,a_1,...,a_k\). From the Stone-Weierstrass theorem, polynomials can uniformly approximate any continuous function on a bounded interval, which makes them a natural base family.

Rather than being chunked explicitly like a spline, a B-spline function is written as a sum of basis functions of the form

\[ s(x) = \sum_{i} c_i B_{i,k}(x), \]

where \(G\) denotes the number of grid points and therefore determines the number of basis functions (which we have not defined yet), $k$ is the order of the B-spline, and \(c_i\) are learnable parameters. Like a spline, a B-spline has a set of $G$ grid points.

We can plot an example for the basis functions of a B-spline with $G=5$ grid points of order $k=3$. In other words, the augmented grid size is $G+2k=11$:

When implementing B-splines for our KAN, we are not interested in the function \(f(\cdot)\) itself, rather we care about efficiently computing the function evaluated at a point \(f(x)\). We will later see a nice iterative bottom-up dynamic programming formulation of the Cox-de Boor recursion.
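To make the recursion concrete before the tensorized version, here is a direct (naive, recursive) transcription of the Cox-de Boor recursion in plain Python. The function name and knot list `t` are my own for illustration:

```python
def cox_de_boor(i: int, k: int, x: float, t: list) -> float:
    """Evaluate the i-th B-spline basis function of order k at x over knots t."""
    if k == 0:
        # Order-0 basis: indicator of the half-open knot interval [t_i, t_{i+1})
        return 1.0 if t[i] <= x < t[i + 1] else 0.0
    # Guard against division by zero for repeated knots
    left = 0.0
    if t[i + k] != t[i]:
        left = (x - t[i]) / (t[i + k] - t[i]) * cox_de_boor(i, k - 1, x, t)
    right = 0.0
    if t[i + k + 1] != t[i + 1]:
        right = (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * cox_de_boor(i + 1, k - 1, x, t)
    return left + right
```

A quick sanity check: on a uniform knot vector, the basis functions form a partition of unity (they sum to one) in the interior of the grid.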

In this section, we describe a barebones, minimal KAN model. The goal is to show that the architecture is structured quite similarly to deep learning code that the reader has most likely seen in the past. To summarize the components, we modularize our code into (1) a high-level KAN module, (2) the KAN layer, (3) the parameter initialization scheme, and (4) the plotting function for interpreting the model activations.

If you’re using Colab, you can run the following as if they were code blocks. This implementation is also quite GPU-unfriendly, so a CPU will suffice.

In an attempt to make this code barebones, I’ve tried to use as few dependencies as possible. I’ve also included type annotations for the code.

The following config file holds some preset hyperparameters described in the paper. Most of these can be changed and may not even apply to a more generic KAN architecture.

If you understand how MLPs work, then the following architecture should look familiar. As always, given some set of input features \((x_1,...,x_n)\) and a desired output \((y_1,...,y_m)\), we can think of our KAN as a function \(f : \mathbb{R}^{n} \rightarrow \mathbb{R}^{m}\) parameterized by weights \(\theta\). Like any other deep learning model, we can decompose KANs in a layer-wise fashion and offload the computational details to the layer class. We will fully describe our model in terms of a list of integers `layer_widths`, where the first number denotes the input dimension \(n\), and the last number denotes the output dimension \(m\).

The representation used at each layer is quite intuitive. For an input \(x \in \mathbb{R}^{n}\), we can directly compare a standard MLP layer with output dimension \(m\) to an equivalent KAN layer:

In other words, both layers can be written in terms of a generalized matrix-vector operation, where for an MLP it is scalar multiplication, while for a KAN it is some *learnable* non-linear function \(\Phi_{i,k}\). Interestingly, both layers look extremely similar!
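To make the comparison concrete, here is a sketch with the batch dimension dropped. The function names are mine, and the KAN activations below are fixed functions rather than learned ones:

```python
import numpy as np

def mlp_layer(W: np.ndarray, x: np.ndarray, sigma) -> np.ndarray:
    # MLP: y_i = sigma( sum_k W[i, k] * x[k] ) -- scalar multiplication
    return sigma(W @ x)

def kan_layer(Phi: list, x: np.ndarray) -> np.ndarray:
    # KAN: y_i = sum_k Phi[i][k](x[k]) -- the "multiplication" becomes
    # application of a (learnable) univariate function per edge
    m, n = len(Phi), len(Phi[0])
    return np.array([sum(Phi[i][k](x[k]) for k in range(n)) for i in range(m)])

# Example: a 2-input, 1-output KAN layer with fixed activations
Phi = [[np.sin, np.square]]
y = kan_layer(Phi, np.array([0.0, 3.0]))  # sin(0) + 3^2 = 9
```

The double loop is only for exposition; the actual layer vectorizes this operation.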

Let’s think through how we would perform this computation. For our analysis, we will ignore the batch dimension, as generally this is an easy extension. Suppose we have a KAN layer \(\mathcal{K}_{m,n}\) with input dimension \(n\) and output dimension \(m\). As we discussed earlier, for input \((x_1,x_2,...,x_n)\),

In matrix form, this can be nicely written as

The observant reader may notice that this looks exactly like the matrix-vector product $Wx$ used in an MLP. In other words, we have to compute and materialize

To finish off the abstract KAN layer (remember, we haven’t defined what the learnable activation function is), the authors define each learnable activation function $\Phi_{i,j}(\cdot)$ in terms of a learnable function $s_{i,j}(\cdot)$ and a fixed base function $b(\cdot)$ to add residual connections to the network:

\[ \Phi_{i,j}(x) = w^{(b)}_{i,j} \, b(x) + w^{(s)}_{i,j} \, s_{i,j}(x). \]

We can modularize the operation above into a “weighted residual layer” that acts over a matrix of \((\text{out\_dim}, \text{in\_dim})\) values. This layer is parameterized by each \(w^{(b)}_{i,j}\) and \(w^{(s)}_{i,j}\), so we can store \(\boldsymbol{w}^{(b)}\) and \(\boldsymbol{w}^{(s)}\) as parameterized weight matrices. The paper also specifies the initialization scheme \(w^{(b)}_{i,j} \sim \mathcal{N}(0, 0.1)\) and \(w^{(s)}_{i,j} = 1\).
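As a sketch of this module (the class name is mine, and I assume the paper's choice of \(b(x) = \mathrm{silu}(x)\) as the fixed base function):

```python
import torch
import torch.nn as nn

class WeightedResidualLayer(nn.Module):
    """Sketch of the residual combination w_b * b(x) + w_s * s(x),
    applied entrywise over an (out_dim, in_dim) matrix of activations."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Initialization described in the text: w_b ~ N(0, 0.1), w_s = 1
        self.w_b = nn.Parameter(0.1 * torch.randn(out_dim, in_dim))
        self.w_s = nn.Parameter(torch.ones(out_dim, in_dim))

    def forward(self, x: torch.Tensor, spline: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_dim) raw inputs; spline: (batch, out_dim, in_dim)
        # The base term b(x) = silu(x) is broadcast over the output dimension.
        base = torch.nn.functional.silu(x)[:, None, :]
        return self.w_b * base + self.w_s * spline
```

Note that the base term depends only on the raw input, so it is computed once and broadcast across all output rows.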

With these operations laid out in math, we have enough information to write a basic KAN layer by abstracting away the choice of learnable activation \(s_{i,j}(\cdot)\). Note that in the code below, the variables `spline_order`, `grid_size`, and `grid_range` are specific to B-splines as the activation, and are only passed through the constructor. You can ignore them for now. In summary, we will first compute the matrix

followed by the weighted residual across each entry; then we will finally sum along the rows to get our layer output. We also define a `cache()` function to store the input vector \(\boldsymbol{x}\) and the \(\Phi \boldsymbol{x}\) matrix to compute regularization terms defined later.

Recall from the section on B-splines that each activation $s_{i,j}(\cdot)$ is a sum of products of learnable coefficients and basis functions, defined over a bounded grid range `[low_bound, up_bound]`.

Again recall the Cox-de Boor recurrence from before. As a general rule of thumb we would like to avoid writing recurrent functions in the forward pass of a model. A common trick is to turn our recurrence into a dynamic-programming solution, which we make clear by writing in array notation:

*The tricky part is writing this in tensor notation*

The following explanation is a bit verbose, so bear with me. Our grid initialization function above generates a rank-3 tensor of shape `(out_dim, in_dim, G+2k+1)`, while the input $x$ is a rank-2 tensor of shape `(batch_size, in_dim)`. We first notice that our grid applies to every input in the batch, so we broadcast it to a rank-4 tensor of shape `(batch_size, out_dim, in_dim, G+2k+1)`. For the input $x$, we similarly need a copy for every output dimension and every basis function to evaluate over, giving us the same shape through broadcasting. We can align the `in_dim` axis of both the grid and the input because $j$ aligns in $s_{i,j}(x_j)$, while the basis functions occupy the last dimension of our tensors. We write out the vectorized DP in this form because, with $j$ fixed, all basis functions can be evaluated in parallel while we iterate the recurrence over the spline order. This notation may be confusing, but the operation is actually quite simple – I would recommend ignoring the batch dimension and drawing out what you need to do.

*tldr; we need to compute something for each element in a batch, for each activation, for each B-spline basis. we can use broadcasting to do this concisely, from the code below*
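Putting the shapes together, a broadcast implementation of the bottom-up recursion might look like the following sketch (the function name is mine):

```python
import torch

def b_spline_bases(x: torch.Tensor, grid: torch.Tensor, k: int) -> torch.Tensor:
    """Evaluate all B-spline basis functions via bottom-up Cox-de Boor DP.

    x:    (batch_size, in_dim) inputs
    grid: (out_dim, in_dim, n_knots) sorted knot points per activation
    returns: (batch_size, out_dim, in_dim, n_knots - 1 - k) basis values
    """
    # Broadcast x against the grid: (batch, 1, in_dim, 1)
    x = x[:, None, :, None]
    # Order-0 bases: indicators of each half-open knot interval
    bases = ((x >= grid[..., :-1]) & (x < grid[..., 1:])).to(x.dtype)
    # Raise the order one step at a time; each step shrinks the last axis by 1
    for p in range(1, k + 1):
        left = (x - grid[..., : -(p + 1)]) / (grid[..., p:-1] - grid[..., : -(p + 1)])
        right = (grid[..., p + 1 :] - x) / (grid[..., p + 1 :] - grid[..., 1:-p])
        bases = left * bases[..., :-1] + right * bases[..., 1:]
    return bases
```

On a uniform grid, the returned bases sum to one for inputs inside the original (unaugmented) range, which makes for an easy sanity check.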

With the B-spline logic out of the way, we have all of our intermediate computation logic done. We still have to define our parameters \(c_i\) and compute the B-splines from the basis functions, but this is just a simple element-wise multiplication and sum. We can now pass the B-spline output into the weighted residual layer defined earlier and compute our output vector. In summary, we are computing

If you’ve gotten to this point, congratulations! You’ve read through the hardest and most important part of this article. The rest of this post talks about a generic model training loop, visualization functions, and optimizations that can be made to B-spline specific KANs. If you’re interested in future directions for these models, I’d recommend reading into Awesome-KAN and getting started! Otherwise, if you’d like to have a deeper understanding of the original KAN paper, keep reading!

Rather unsurprisingly, regularization is an important component of KANs. The authors of KAN motivate two types of regularization – L1 regularization to limit the number of active activation functions, and entropy regularization to penalize duplicate activation functions.

L1 regularization for a weight matrix \(W\) in an MLP is straightforward – just take the sum of the absolute values of its entries. However, for activation functions, using the parameters of the function is not necessarily a good choice. Instead, the magnitude of the **function evaluated on the data** is used. More formally, suppose we have a batch of inputs \(\{x^{(b)}_1,...,x^{(b)}_n \}_{b \in \mathcal{B}}\) into a KAN layer $\mathcal{K}_{m,n}$. The L1 norm of an activation from input node $j$ to output node $i$ is defined as the mean absolute value of that activation on $x_j$, averaged over the batch. In other words,

The L1 norm of the layer is then defined as the sum over all of its activations,

\[ \left| \Phi \right|_1 = \sum_{i=1}^{m} \sum_{j=1}^{n} \left| \Phi_{i,j} \right|_1. \]

In addition to wanting sparse activations for better interpretability and performance, we also want to discourage duplicate activations. To do so, the authors treat the normalized per-activation norms \(\left| \Phi_{i,j} \right|_1 / \left| \Phi \right|_1\) as a distribution and define an entropy regularization term over it.

The regularization term is just a weighted sum of the two terms above. These regularization expressions are not specific to the B-splines representation chosen by the authors, but their effect on other choices of learnable activation functions is underexplored at the moment.

In this section, we will discuss the basic training loop for a KAN, including a script for visualizing the network activations. As you will notice, the framework for training a KAN is almost identical to a standard deep learning train loop.

Despite the extra machinery necessary to apply our model parameters to our input, it is easy to see that the operations themselves are differentiable. In other words, barring some extra optimization tricks that we will discuss in Part III, the training loop for KANs is basically just a generic deep learning train loop that takes advantage of autodifferentiation and backpropagation. We first define a function for generating training data for a function \(f(x_1,...,x_n)\) over a bounded domain \(\mathcal{D} \subset \mathbb{R}^{n}\).
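A minimal sketch of such a generator is below – the dictionary keys and default arguments are my own convention and may differ from the original codebase:

```python
import torch

def create_dataset(f, n_var: int, ranges=(-1.0, 1.0),
                   n_train: int = 1000, n_test: int = 100) -> dict:
    """Sample inputs uniformly from a bounded box and label them with f.

    f takes a (batch, n_var) tensor and returns a (batch, out_dim) tensor.
    """
    lo, hi = ranges

    def sample(n: int):
        x = torch.rand(n, n_var) * (hi - lo) + lo
        return x, f(x)

    train_x, train_y = sample(n_train)
    test_x, test_y = sample(n_test)
    return {"train_input": train_x, "train_label": train_y,
            "test_input": test_x, "test_label": test_y}
```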

As the reader will see below, the KAN training loop is extremely simple, and uses the familiar `zero_grad()`, `backward()`, `step()` PyTorch loop. We do not even use the L-BFGS optimizer from the original codebase.
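A minimal sketch of that loop, using Adam and an MSE loss for simplicity (the function name and `results` keys are my own):

```python
import torch

def train(model, dataset: dict, steps: int = 100, lr: float = 1e-2, reg_fn=None) -> dict:
    # Standard zero_grad / backward / step loop with Adam and an MSE loss
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    results = {"train_loss": [], "test_loss": []}
    for _ in range(steps):
        opt.zero_grad()
        pred = model(dataset["train_input"])
        loss = torch.mean((pred - dataset["train_label"]) ** 2)
        if reg_fn is not None:
            # Optional hook for the L1/entropy regularization terms
            loss = loss + reg_fn(model)
        loss.backward()
        opt.step()
        results["train_loss"].append(loss.item())
        with torch.no_grad():
            test_pred = model(dataset["test_input"])
            test_loss = torch.mean((test_pred - dataset["test_label"]) ** 2)
            results["test_loss"].append(test_loss.item())
    return results
```

Any `nn.Module` works here, so the loop can be smoke-tested with a plain linear model before plugging in the KAN.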

We can also define a simple plotting function that takes the `results` dictionary from above.

We mostly adapt the network visualization code from the original repository. While the code is quite dense, all we need to do is plot our stored activations per layer, save the plots, then draw out the grid of network connections. You can mostly skim this code unless you’re interested in prettifying the visualizations.

For example, we can visualize the base network activations with the script below.

We can put this all together with a simple example. I would recommend scaling this further to a more interesting task, but for now you can verify that the model training is correct. Consider a function of the form \(f(x_1,x_2) = \exp \left( \sin(\pi x_1) + x_2^3 \right)\). We are going to learn this function using a KAN of the form \(f(x) = \mathcal{K}_{1,1} \left( \mathcal{K}_{1,2} \left( x_1, x_2 \right) \right)\).

The attentive reader may have noticed that the choice of B-spline is somewhat arbitrary, and the KAN itself is not necessarily tied to this choice of function approximator. In fact, B-splines are not the only choice to use, even among the family of different spline regressors.

A large portion of the original paper covers computation tricks to construct KANs with B-splines as the learnable activation function. While the authors prove a (type of) universal approximation theorem for KANs with B-splines, there are other choices of parameterized function classes that can be explored, potentially for computational efficiency.

**Remark**. Because we are modifying the code from Part I, I’ve tried to keep the code compact by only including areas where changes were made. You can either follow along, or use the full KAN notebook.

Recall that the flexibility of our B-splines is determined by the number of learnable coefficients, and therefore the number of basis functions. Furthermore, the number of basis functions is determined by the number of knot points \(G\). Suppose now that we want to include \(G' > G\) knots for a finer granularity on our learnable activations. Ideally, we want to add more knot points while preserving the original shape of the function. In other words, writing \(c'_i\) and \(B'_{i,k}\) for the new coefficients and basis functions, we want

\[ \sum_{i} c'_i B'_{i,k}(x) \approx \sum_{i} c_i B_{i,k}(x) \quad \text{for all } x \text{ in the domain.} \]

We can tensorize this expression with respect to a batch of inputs $(z_1,…,z_b)$

which is of the form $AX = B$. We can thus use least squares to solve for $X$, giving us our new coefficients on our finer set of knot points.
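As a sketch, the least-squares solve can be done with `torch.linalg.lstsq`; the `gelsd` driver is SVD-based and tolerates rank-deficient basis matrices (the function name here is mine):

```python
import torch

def fit_new_coefficients(new_bases: torch.Tensor, old_spline_vals: torch.Tensor) -> torch.Tensor:
    """Solve  new_bases @ c = old_spline_vals  in the least-squares sense.

    new_bases:       (b, n_new_basis) fine-grid bases evaluated at z_1..z_b
    old_spline_vals: (b,) coarse-spline evaluations at the same points
    """
    sol = torch.linalg.lstsq(new_bases, old_spline_vals.unsqueeze(-1), driver="gelsd")
    return sol.solution.squeeze(-1)
```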

I wanted to mention that for the `driver` parameter in `torch.linalg.lstsq`, there are certain solvers like QR decomposition that require full-rank columns on the basis functions. I’ve chosen to avoid these solvers, but there are several ways to go about solving the least-squares problem efficiently.

We can visually evaluate the accuracy of our grid extension algorithm by simply looking at the activations before and after a grid extension.

Pruning network weights is not unique to KANs, but it helps the models become more readable and interpretable. Our implementation of pruning is going to be *extremely inefficient*, as we will mask out activations **after they are calculated**. There is already a large body of work for neural networks dedicated to bringing about performance benefits through pruning *before* the computation, but tensorizing this process efficiently is not clean.

We also need to define a metric for pruning. We can define this function at the high-level KAN module. For every layer, each node is assigned two scores: the input score is the maximum over incoming edges of the mean absolute activation over the training batch, and the output score is defined analogously over outgoing edges.

If \(\text{score}^{(\ell, \text{in})}_{i} < \theta \lor \text{score}^{(\ell, \text{out})}_{i} < \theta\) for some threshold $\theta = 0.01$, then we can prune the node by masking its incoming and outgoing activations. We tensorize this operation as a product of two indicators below.
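A sketch of this scoring and masking for a single hidden layer – the tensor layouts and function name are my assumptions:

```python
import torch

def prune_mask(in_acts: torch.Tensor, out_acts: torch.Tensor,
               theta: float = 0.01) -> torch.Tensor:
    """in_acts:  (batch, width, prev_width) activations flowing into each node
    out_acts: (batch, next_width, width) activations flowing out of each node
    Returns a (width,) 0/1 mask; a node survives only if both scores pass theta."""
    in_score = in_acts.abs().mean(dim=0).amax(dim=1)    # (width,)
    out_score = out_acts.abs().mean(dim=0).amax(dim=0)  # (width,)
    # Product of two indicators, as described in the text
    return (in_score > theta).float() * (out_score > theta).float()
```

Multiplying the incoming and outgoing activation tensors by this mask zeroes out every edge touching a pruned node.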

In practice, you will call the `prune(...)` function after a certain number of training steps or post-training. Our current plotting function does not support these pruned activations, but we add this feature in the Appendix.

A large selling point of the original paper is that KANs can be thought of as a sort of “pseudo-symbolic regression”. In some sense, if you know the original activations before-hand or realize that the activations are converging to a known non-linear function (e.g. $b \sin(x)$), we can choose to fix these activations. There are many ways to implement this feature, but similar to the pruning section, I’ve chosen to favor readability over efficiency. The original paper mentions two features that **are not implemented below**: namely, storing coefficients for affine transformations of known functions (e.g. $a f(b x + c) + d$) and fitting the current B-spline approximation to a known function. The code below allows the programmer to directly fix symbolic functions in the form of univariate Python `lambda` functions. First, we provide a function for a KAN model to fix (or unfix, back to the B-spline) a specific layer’s activation to a specified function.

We first define a `KANSymbolic` module that is analogous to the `KANActivation` module used to compute B-spline activations. Here, we store an array of functions \(\{f_{i,j}(\cdot)\}_{i \in [m], j \in [n]}\) that are applied in the forward pass to form a matrix \(\{f_{i,j}(x_j)\}_{i \in [m], j \in [n]}\). Each function is initialized to be an identity function. Unfortunately, there is not (to my knowledge) an efficient way to perform this operation in the general case where all the symbolic functions are unique.
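A readability-first sketch of such a module, looping over edges rather than vectorizing (the class and method names are mine and differ from the original code):

```python
import torch

class KANSymbolic:
    """Grid of per-edge univariate functions, applied pointwise."""

    def __init__(self, in_dim: int, out_dim: int):
        # Every edge starts as the identity function
        self.fns = [[(lambda v: v) for _ in range(in_dim)] for _ in range(out_dim)]

    def fix_fn(self, i: int, j: int, fn) -> None:
        # Fix the activation on edge (i, j) to a univariate function
        self.fns[i][j] = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_dim) -> (batch, out_dim, in_dim) with entries f_ij(x_j)
        rows = []
        for row in self.fns:
            rows.append(torch.stack([row[j](x[:, j]) for j in range(x.shape[1])], dim=1))
        return torch.stack(rows, dim=1)
```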

We now have to define the symbolic activation logic inside the KAN layer. When computing the output activations, we use a similar trick to the pruning implementation by introducing a mask that is $1$ when the activation should be symbolic and $0$ when the learned B-spline should be used.

We can test our implementation by learning the function \(f(x_1,x_2) = \sin(x_1) + x_2^2\) and plotting the result.

This section will be focused on applying KANs to a standard machine learning problem. The original paper details a series of examples where KANs learn to fit a highly non-linear or compositional function. Of course, while these functions are difficult to learn, the use of learnable univariate functions makes KANs suitable for these specific tasks. I emphasized the similarities between KANs and standard deep learning models throughout this post, so I also wanted to present a deep learning example (even though it doesn’t work very well). We will run through a simple example of training a KAN on the canonical MNIST handwritten digits dataset.

In the interest of reusing the existing train logic we created earlier, we write a function to turn a `torch.Dataset` with MNIST into the dictionary format. *For general applications, I recommend sticking with the torch Dataloader framework*.

Finally, like all previous examples, we can run a training loop over the MNIST dataset. We compute the training loss using the standard binary cross-entropy loss and define the KAN to produce logits from 0-9. Due to restrictions in our `train()` function, we define our test loss as the total number of incorrectly marked samples out of $100$ validation samples.

You may notice that the training is significantly slower even for such a small model. Furthermore, the results here are not as good as expected. I’m confident that with sufficient tuning of the model you can get MNIST to work (there are examples of more sophisticated KAN implementations that achieve this).

I hope this resource was useful to you – whether you learned something new, or gained a certain perspective along the way. I wrote up this annotated blog to clean up my notes on the topic, as I am interested in improving these models from an efficiency perspective. If you find any typos or have feedback about this resource, feel free to reach out!

I may re-visit this section in the future with some more meaningful experiments when I get the time.

The plotting function defined in Network Visualization doesn’t include logic for handling the pruned activation masks and the symbolic activations. We will include this logic separately, or you can follow the rest of the visualization code in the original repository.

It is known that these models currently do not scale well due to both memory and compute inefficiencies. Of course, it is unknown whether scaling these models will be useful, but the authors posit that they are more parameter efficient than standard deep learning models because of the flexibility of their learned univariate functions. As you saw in the MNIST example, it is not easy to scale the model even for MNIST training. I sort of avoided this question before, but I want to highlight a few reasons for these slowdowns.

- We fully materialize a lot of intermediate activations for the sake of demonstration, but even in an optimized implementation, some of these intermediate activations are unavoidable. Generally, materializing intermediate activations means lots of movement between DRAM and the processors, which can cause significant slowdown. There is a repository called KAN-benchmarking dedicated to evaluating different KAN implementations. *I may include an extra section on profiling in the future.*
- Each activation \(\Phi_{i,j}\), or edge in the network, is potentially different. At a machine-instruction level, this means that we cannot take advantage of the SIMD or SIMT parallelism that standard GEMM or GEMV operations enjoy on the GPU. There are alternative implementations of KANs, mentioned earlier, that attempt to get around these issues, but even then they do not scale well compared to MLPs. I suspect the choice of the family of parameterized activations will be extremely important moving forward.

A natural question is whether we have to fix the knot points to be uniformly spaced, or if we can use the data to adjust our knot points. The original paper does not detail this optimization, but their codebase actually includes this feature. If time permits, I may later include a section on this – I think it may be important for performance of KANs with B-splines, but for general KANs maybe not.

Just as a formality, if you want to cite this for whatever reason, use the BibTeX below.

```
@article{zhang2024annotatedkan,
  title  = "Annotated KAN",
  author = "Zhang, Alex",
  year   = "2024",
  month  = "June",
  url    = "https://alexzhang13.github.io/blog/2024/annotated-kan/"
}
```