原文:
towardsdatascience.com/kolmogorov-arnold-networks-the-latest-advance-in-neural-networks-simply-explained-f083cf994a85
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b639e9d4561750c41296ff98dfce0c64.png
MLP 和 KAN 的比较。图片来自 论文。
四月,一篇名为 KAN: Kolmogorov–Arnold Networks 的论文出现在 arXiv 上。宣布的推文获得了 ~5k 个赞,对于一个论文的宣布来说,这相当具有病毒性。相关的 GitHub 仓库 已经有 7.6k 个星标,并且还在增加。
cdn.embedly.com/widgets/media.html?type=text%2Fhtml&key=a19fcc184b9711e1b4764040d3dc5c07&schema=twitter&url=https%3A//twitter.com/ZimingLiu11/status/1785483967719981538&image=
科尔莫哥洛夫-阿诺德网络(KAN)是神经网络构建块的一个全新类别。它旨在比多层感知器(MLP)更具表现力、更不易过拟合且更可解释。MLP 在深度学习模型中无处不在。例如,我们知道它们被用于 GPT-2、3 和(可能是)4 的变压器块之间。对 MLP 的改进将对机器学习世界产生广泛的影响。
MLP 的概述
MLP 实际上是一个非常古老的架构,其历史可以追溯到 50 年代。其想法是复制大脑的结构;拥有大量相互连接的神经元,这些神经元将信息向前传递,因此得名前馈网络。
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b497fd60dffccaef24130143db2fae1f.png
来自arxiv.org/pdf/2101.03541的 MLP 表示
MLPs 通常会展示出类似于上图那样的图表。对于外行人来说,这非常有用,但在我看来,它并没有传达任何关于真正发生的事情的深层理解。用数学来表示它要容易得多。
假设我们有一些输入x和一些输出y。一个两层的 MLP 将如下所示:
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/239d3dcf06387dfbf4924b75a89b093f.png
其中,W是可学习权重的矩阵,b是偏置向量的向量。函数f是一个非线性函数。从这些方程中可以看出,MLP 是一系列带有非线性项的线性回归模型。这是一个非常基本的设置。
尽管基础,但它极其具有表现力。有数学保证表明多层感知器(MLPs)是通用****逼近器,即:它们可以逼近任何函数,类似于所有函数都可以用泰勒级数表示。
为了训练模型的权重,我们使用反向传播,这得益于自动微分(autodiff)。这里我不会深入探讨,但重要的是要注意,autodiff 可以作用于任何可微分的函数,这将在以后变得很重要。
MLPs 的问题
MLPs 被广泛应用于各种用例,但有一些严重的缺点。
由于它们作为模型非常灵活,可以很好地拟合任何数据。因此,它们很可能过拟合。
模型中通常有很多权重,很难解释这些权重以从数据中得出结论。我们经常说深度学习模型是“黑盒”。
有很多权重也意味着它们可能需要很长时间来训练,GPT-3 的大多数参数都在 MLP 层中。
科尔莫哥洛夫-阿诺德网络
科尔莫哥洛夫-阿诺德表示定理
科尔莫哥洛夫-阿诺德表示定理在目标上与支撑 MLPs 的通用逼近定理相似,但前提不同。它本质上表明任何多元函数都可以通过 1 维非线性函数的加法来表示。例如:向量**v=(x1, x2)**的除法运算可以用对数和指数来替换:
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/2d0f92cf54461dbfffbcc364fa0536a6.png
那么,这究竟有什么用?这实际上实现了什么?
这为我们提供了一个不同但简单的范例来开始构建神经网络架构。作者声称这种架构比使用 MLPs 更可解释、参数效率更高,并且具有更好的泛化能力。在 MLPs 中,非线性函数是固定的,并且在训练过程中永远不会改变。在 KANs 中,没有更多的权重矩阵或偏差,只有拟合到数据的 1 维非线性函数。这些非线性函数随后被相加。然后我们可以堆叠更多层来创建更复杂的函数。
B 样条
在 KANs 中非线性函数的表示方式中,有一些重要的事情需要注意。与 MLPs 中它们被明确定义为例如ReLU(),Tanh(),silu()等不同,在 KANs 中,作者使用样条。这些实际上是分段多项式。它们来自计算机图形学领域,在那里过参数化不是需要担心的事情。
样条曲线解决了在多个点之间平滑插值的问题。如果你熟悉机器学习理论,你会知道要完美地插值n个数据点,你需要一个n-1阶的多项式。问题是高阶多项式可能会变得非常扭曲,看起来并不平滑。
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/d3f988566042a0e23e5d1ad125fde088.png
10 个数据点被一个 9 阶多项式完美拟合。图片由作者提供。
样条通过将分段多项式函数拟合到数据点之间的各个部分来解决这个问题。这里我们使用立方样条。
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ab1c0e30ab678e8d90d8c24431cbd9e2.png
立方样条插值效果更好,但无法泛化。图片由作者提供。
对于立方样条(样条的一种类型),为确保平滑性,在数据点(或节点)的位置对一阶和二阶导数设置约束。数据点两侧的曲线必须在数据点处具有匹配的一阶和二阶导数。
KANs 使用 B 样条,另一种具有局部性(移动一个点不会影响曲线的整体形状)和匹配的二阶导数(也称为 C2 连续性)特性的样条。这以实际上不通过点(除了在端点处)为代价。
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e815af935eab6479f1a5f3c7581facc5.png
5 个数据点的 3 个 B 样条。来自维基百科。注意曲线并不通过数据点。
在机器学习中,尤其是在应用于物理学时,不通过每个数据点是可以接受的,因为我们期望测量会有噪声。
这就是 KANs 的计算图中每个边发生的事情。一维数据被一组 B 样条拟合。
进入 KAN
因此,我们现在在每个计算图的边上都有一个分段参数曲线。在每个节点,这些值都会相加:我们之前看到,我们可以通过这种方式逼近任何函数。
<…/Images/b5754643c2098922cdfe20f82c628fe5.png>
KAN[2, 5, 1](2 个输入,5 个隐藏节点和 1 个输出)。图片来自paper。
要训练此类模型,我们可以使用标准的反向传播。在这种情况下,作者使用了 LBFGS,这是一种二阶优化方法(与一阶的 Adam 相比)。另一个需要注意的细节:在每个边,代表一维函数,都有一个 B 样条,但作者还添加了一个非线性:一个silu函数。
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f1570fb7608cc2a71ff4a84503c1ed88.png
对于这个解释并不清楚,但最可能的原因是梯度消失(至少这是我的猜测)。
让我们使用它
我将使用作者提供的代码,它工作得非常好,并且有很多示例可以帮助更好地理解它。
他们使用合成数据,这些数据是从以下函数生成的:
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b724dd571952e53be575bbd00a14cc59.png
定义模型
model=KAN(width=[2,5,1],grid=5,k=3,seed=0)这定义了 3 个参数:
宽度,其定义方式与 MLP 类似:一个列表,其中每个元素对应一个层,元素值是该层的宽度。在这种情况下,有 3 层;输入维度是 2,有 5 个隐藏维度,输出维度是 1。
网格与 B 样条相关,它描述了数据点之间网格的精细程度。增加这个值可以创建出更多波动的函数。
k是 B 样条的多项式阶数,通常情况下,三次是一个不错的选择,因为三次曲线对于样条来说具有很好的特性。
种子,随机种子:样条权重以高斯噪声随机初始化(就像在常规 MLP 中一样)。
训练
model.train(dataset,opt="LBFGS",steps=20,lamb=0.01,lamb_entropy=10.0)该库的 API 非常直观,我们可以看到我们正在使用 LBFGS 优化器,进行 20 次训练步骤。接下来的两个参数与网络的正则化有关。
训练之后的下一步是修剪模型,这会移除低于相关阈值的所有边和节点,完成这一步后,建议重新训练一下。然后,将每个样条边转换成符号函数(如 log、exp、sin 等).这可以通过手动或自动完成。该库提供了一个非常棒的工具,可以通过model.plot()方法查看模型内部的情况。
# Code to fit symbolic functions to the fitted splinesifmode=="manual":# manual modemodel.fix_symbolic(0,0,0,"sin")model.fix_symbolic(0,1,0,"x²")model.fix_symbolic(1,0,0,"exp")elifmode=="auto":# automatic modelib=["x","x²","x³","x⁴","exp","log","sqrt","sin","abs"]model.auto_symbolic(lib=lib)一旦在每个边放置了符号函数,就会进行最后的重新训练,以确保每个边的仿射参数是合理的。
整个训练过程总结在下方的图中。
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/21db754c74b2cca0eeccbd9915cde300.png
KAN 进行符号回归的一个示例。图来自论文。
完整的训练代码看起来像这样:
# Define the modelmodel=KAN(width=[2,5,1],grid=5,k=3,seed=0)# First trainingmodel.train(dataset,opt="LBFGS",steps=20,lamb=0.01,lamb_entropy=10.0)# Prune edges that have low importancemodel=model.prune()# Retrain the pruned model with no regularisationmodel.train(dataset,opt="LBFGS",steps=50)# Find the symbolic functionsmodel.auto_symbolic(lib=["x","x²","x³","x⁴","exp","log","sqrt","sin","abs"])# Find the afine parameters of the fitted functions without regularisationmodel.train(dataset,opt="LBFGS",steps=50)# Display the resultant equationmodel.symbolic_formula()[0][0]# Print the resultant symbolic function一些思考
不稳定性
模型中有许多超参数可以进行调整,这些调整可以产生非常不同的结果。例如,在上面的例子中:将隐藏神经元的数量从 5 改为 6 意味着 KAN 无法找到正确的函数。
<…/Images/985586d2952a7565d8e58d2c4217eba1.png>
KAN[2,6,1]找到的结果函数。图由作者提供。
这种不稳定性是可以预见的,因为架构是全新的。人们花费了几十年时间才找到最佳方法来调整 MLP 的超参数(如学习率、批量大小、初始化等)。
为什么不进行符号回归?
KANs 的目标是与符号回归以及 MLPs 竞争。将 KANs 与 PySR 库进行比较,我们发现它也能找到训练数据的正确函数形式。另一方面,它比 KANs 更容易出现异常。作者也在他们的论文中提到了这一点。在我的案例中,改变 PySR 模型的随机种子使得得到的方程变得有些不合逻辑。而 KANs 则不太依赖随机性。
你可以在我的GitHub 仓库中找到我所有的实验。
结论
多层感知器(MLP)已经存在很长时间了,它迫切需要升级。我们知道这种变化是可能的,大约 6 年前,在序列建模中无处不在的长短期记忆网络(LSTMs)被转换器(transformers)取代,成为标准语言模型架构的构建块。如果这种变化发生在 MLPs 上,那将是非常令人兴奋的。另一方面,这种架构仍然不够稳定,效果并不十分出色。时间将证明,社区是否能够找到一种方法来克服这种不稳定性,并释放 KANs 的真正潜力,或者 KAN 会被遗忘,成为机器学习的一个小知识点。
我对这种新的架构感到非常兴奋,但同时也感到怀疑。
我确实更倾向于兴奋的一面,因为我身处科学 AI 领域。KAN 是为了我感兴趣的任务而构建的。尽管如此,也不太可能有一些 AI 实验室已经在尝试用 KAN 替换他们 LLMs 中的 MLP 层。
参考文献
arxiv.org/abs/2404.19756
en.wikipedia.org/wiki/Multilayer_perceptron
en.wikipedia.org/wiki/Universal_approximation_theorem
en.wikipedia.org/wiki/B-spline
www.youtube.com/watch?v=YMl25iCCRew&list=PLWfDJ5nla8UpwShx-lzLJqcp575fKpsSO