跳转至

26 为什么说Mamba是Transformer的最强挑战者?

你好,我是独行。

在过去的几年里,Transformer模型在自然语言处理领域占据了主导地位。自从2017年谷歌提出Transformer以来,BERT、GPT-3等基于Transformer的模型取得了巨大的成功。

然而技术的进步从未停止,最近出现了一种新型模型——Mamba,被认为是 Transformer 的最强挑战者。那么,Mamba凭什么能与Transformer一较高下呢?这节课我就来带你看看Mamba的过人之处。

Transformer的局限

Transformer功能很强,但并不完美,尤其是在处理长序列方面,Transformer模型中自注意力机制的计算量会随着上下文长度的增加呈平方级增长,比如上下文长度增加32倍时,计算量可能会增长1000倍,计算效率非常低。为什么会这样?因为Transformer 模型在计算自注意力时,每个输入元素都要与序列中的其他元素进行比较,导致总体计算复杂度为$O(n^2*d)$,其中$n$是序列长度,$d$是元素表示的维度。

为了克服这些缺陷,研究者们开发出了很多注意力机制的高效变体,比如线性注意力(Linear Attention)、稀疏注意力(Sparse Attention)、低秩注意力(Low-Rank Attention)等等,但这往往是以牺牲其有效性为代价的。到目前为止,这些变体都还没有被证明能在不同领域发挥有效作用。为什么这些变体会牺牲注意力的有效性?我来解释下。

目前大模型扩展上下文的几个方式:稀疏注意力、滑动窗口、降采样,这几个方式有共同的缺点,我把它们称为“技术捷径”,就是因为无论采用哪种方式,都是想办法丢掉一部分“不重要的数据”,那么如何评判数据重不重要,这是一个很难的问题,所以很多种情况下会产生误杀。这里也一样,所谓稀疏注意力,就是有选择性地计算部分注意力权重,那么必然会有一定的概率,忽略一些重要的信息或者依赖,进而牺牲性能。

同样,线性注意力,将标准的自注意力机制中的矩阵乘法改为线性计算,从而将计算复杂度从$O(N^{2})$降低到$O(N)$,但是由于省略了softmax操作和全局注意力矩阵,线性注意力可能无法捕捉到所有重要的依赖关系,尤其是长距离依赖,最终导致模型的表达能力不足,可能会导致性能下降。低秩注意力逻辑也类似。

那有没有更好的解法呢?有人提出用RNN解决自注意力机制的计算复杂性问题。RNN之所以计算快,是因为 RNN 只需要考虑先前的隐藏状态和当前输入。它可以避免重新计算所有先前的隐藏状态,正好与Transformer相反,可谓是两个极端。但是RNN的这种计算方式有非常大的弊端,就是当依赖非常长的时候,RNN 往往会随着时间的推移而忘记之前的信息,因为它们只考虑一个先前的状态,或者说对先前信息的记忆会越来越淡,这样带来的问题就是上下文会失效。

那有没有什么办法可以从两个极端方案之间选择一个折中的办法呢?当然有,那就是Mamba!

Mamba 的优势

  1. 基于 S4 架构

S4全称是Structured State Spaces for Sequence Modeling,用于序列建模的结构化状态空间,是一种针对序列建模的高效模型。它结合有选择性的(Selective)状态空间模型(SSM)和深度学习技术,旨在处理长序列的依赖关系,并突破计算瓶颈。S4使用线性递归方程和频域方法,显著提升计算效率,同时通过结构化参数设计捕捉长距离依赖。该模型在自然语言处理、时间序列预测和语音处理等任务中表现优异。S4 的核心优势在于其计算效率和捕捉远程依赖关系的能力,为长序列建模提供了一种创新且强大的解决方案。

状态空间模型,是一类描述动态系统的数学模型,通过这个模型,我们可以理解和预测一个系统是如何随着时间的变化而变化的。举个简单的例子:你在迷宫里移动,寻找出口,状态空间模型就是用来描述你的位置和移动速度是如何随着时间的变化而变化的。有几个关键的概念如下:

  • 状态(State): 状态是系统内部的情况或条件。在我们的例子中,人的状态可以包括当前位置信息和当前移动速度。
  • 观测(Observation): 观测是我们可以实际测量到的东西。在我们的例子中,观测可能是通过GPS测得的人的位置。
  • 状态方程: 状态方程告诉我们系统的状态是如何随时间变化的。比如,人当前位置是根据它的上一个位置和移动速度和方向来决定的。
  • 观测方程: 观测方程告诉我们如何从状态中获得观测值。比如,通过人的实际位置和出口坐标,得到离出口最近的下一步移动方案等等。
  1. 高效性

Mamba在计算效率上表现突出,特别是在处理大规模数据的时候。与Transformer相比,Mamba采用了一些优化技术,比如S4和动态压缩,使计算复杂度更低。Transformer的自注意力机制计算量比较大,而Mamba通过改进这种机制,基于S4,显著减少了计算量,从而能够更快速地进行训练和推理。

你可以对比一下Transformer和Mamba计算复杂度的伪代码示例。

# 伪代码示例:对比Transformer和Mamba的计算复杂度
# Transformer自注意力机制的计算复杂度
def transformer_attention(Q, K, V):
    attention_scores = Q @ K.T  # 计算注意力得分
    attention_weights = softmax(attention_scores)  # 计算注意力权重
    output = attention_weights @ V  # 计算输出
    return output

# Mamba优化后的注意力机制计算复杂度
def mamba_attention(Q, K, V):
    attention_scores = efficient_dot_product(Q, K)  # 优化后的点积计算
    attention_weights = softmax(attention_scores)
    output = attention_weights @ V
    return output
  1. 适应性

Mamba不仅在NLP任务中表现优异,还能处理图像识别等任务。其设计使模型在面对不同类型的数据时,能够自如应对,主要原因包括其高效的架构设计和优化的计算方法。Mamba通过精心设计的层结构和连接方式,有效地提取和处理图像特征,同时使用加速技术提升计算效率。

此外,Mamba支持多任务学习,在图像分类、目标检测和图像分割等任务中表现优异。模型通过正则化技术和数据增强方法提高了鲁棒性和泛化能力,能够适应不同数据集。同时,在一些层与层之间的连接过程中,将上一层的输出和下一层的输入进行权重绑定,达到权重共享的目的,这样就相当于减少了一部分参数,通过这样的优化,使Mamba在保持高性能的同时,减少了计算和存储成本,进一步增强了其实用性。这种多功能性使Mamba在各类应用场景中都有出色表现。

  1. 内存利用

Mamba在内存使用上更加高效,这对于需要在资源有限的设备上运行的模型尤为重要。比如,在嵌入式系统或移动设备上,内存通常是限制因素。Mamba优化了内存分配和使用,具体来说,就是引入了一种动态序列长度调整(DSLA)的新技术,允许网络根据输入序列的复杂性和长度调整其内存大小,这样就可以更有效地使用内存,减少了内存消耗,使它能在这些设备上高效运行。

  1. 训练速度

由于Mamba特殊的网络结构以及内存优化技术,所以Mamba的训练速度比Transformer快得多,这对于需要快速迭代和测试的开发者来说非常重要。快速的训练速度不仅能节省时间,还能加快模型的开发和部署进程。

  1. 性能表现

在许多基准测试中,Mamba展示了与Transformer相媲美甚至更优的性能。特别是在处理长序列数据和复杂任务时,Mamba表现得特别出色。这让它在许多实际应用中成为更具吸引力的选择。我们来看一下Mamba、Transformer以及RNN在训练和推理性能方面的一些比较。

图片

Mamba架构之状态空间模型(SSM)

模型结构

SSM 用于根据某些输入预测它们的下一个状态,在时间$t$,SSM可以表示为:

  • 输入序列$x(t)$,:在迷宫中向左和向下移动;
  • 潜在状态表示$h(t)$:出口距离和人的坐标等;
  • 预测输出序列$y(t)$:再次向左/向右移动以更快到达出口。

图片

SSM描述的是动态系统,例如在 3D 空间中移动的物体,可以通过两个方程根据其在时间$t$的状态进行预测,其核心是两个方程:状态方程和输出方程(或者也可以称为:观测方程)。

$$h’(t)=Ah(t)+Bx(t)$$

$$y(t)=Ch(t)+Dx(t)$$

目标是找到这个状态$h(t)$,以便我们可以从输入计算得出输出序列。

图片

状态方程描述的是,矩阵A和B如何根据输入值和上一个状态值推导当前状态的值。

图片

输出方程描述的是,矩阵C和D如何通过状态值和输入值推导输出值的过程。

图片

注意:矩阵A、B、C 和 D 就是我们常说的参数,它们是可学习的。

将这两个方程合并可视化,我们可以得到以下架构:

图片

把这个架构图稍微细化一下就可以得到完整的处理序列图。

图片

矩阵D实际上没有参与SSM,直接从输入到输出,所以关于矩阵D的连接被称为跳跃连接。而矩阵D没有参与到SSM序列计算,所以我们说SSM是没有跳跃连接的,参考下面阴影部分。

图片

所以,实际上矩阵ABC才是SSM的核心,因此,最开始的示意图可以用下面的图示来表示:

图片

连续信号到离散信号

通过连续信号去计算状态是有一定的难度的,而通常我们的输入是离散的,所以我们需要将模型就行离散化。为了实现离散化,使用了一个叫 Zero-order hold 的技术,工作流程大概是这样的:每当接收到离散信号,先保留其值,直到收到新的离散信号,这样的话,使得SSM可以使用连续信号,示意图如下:

图片

具体指保留多长时间,是一个可学习的参数,称为步长,我们用$\Delta$表示。有了连续的输入信号,我们可以生成连续的输出,并且仅根据输入的时间步长$\Delta$对值进行采样,这些采样值就是我们的离散输出。

图片

数学上,我们可以通过如下方式应用Zero-order hold,使连续 SSM 转变为离散 SSM,该 SSM 由一个公式表示,该公式不再是函数到函数,x(t) → y(t),而是序列到序列,x ₖ → y ₖ 。

图片

连续SSM和离散SSM对比:

图片

使用$k$而不是$t$来表示离散时间步长,以便在我们提到连续与离散 SSM 时更加清晰。

循环表示

离散化 SSM 使我们能够以特定的时间步长而不是连续信号来处理问题,这一点类似于我们前面讲过的循环神经网络RNN,循环方法在这里也适用,在每个时间步,我们计算当前输入$Bx_{k}$如何影响之前的状态$Ah_{k-1}$,然后计算预测输出$Ch_{k}$,计算过程如下:

图片

回想一下RNN里,输入经过隐藏层进行循环计算的过程,通过示意图可以看出这里原理基本是一样的。

图片

所以我们用处理RNN的方式处理离散信号,可以表示为:

图片

展开隐藏层如下:

图片

使用这个方式处理,带来了RNN的天然优势和劣势,那就是训练慢、推理快!

卷积表示

SSM也可以使用卷积的方式表示,比如:

图片

图片

实际在计算输出值的时候,我们看一下SSM是怎么工作的。

图片

进一步移动:

图片

最后一步,得到如下完整计算公式:

图片

这样,将 SSM 表示为卷积,使它可以像卷积神经网络(CNN)一样进行并行训练。然而,由于内核大小固定,它们的推理速度不如 RNN 那样快。

三种方式对比

连续模型、循环模型、卷积模型放在一起如下图所示:

图片

SSM厉害之处在于,在训练的时候,我们可以使用卷积表示,而在推理的时候,可以选择使用循环表示,这一点使Mamba具备了训练和推理都很高效的特性。

图片

该模型被称为线性状态空间层(LSSL),这些表示具有一个重要特性,即线性时间不变性(LTI)。LTI 指出,SSM 的参数 A、B 和 C 在所有时间步长上都是固定的。这意味着对于 SSM 生成的每个 token,矩阵A、B 和 C 都是相同的。

矩阵A的重要性

从上面隐藏层的计算过程看得出来,矩阵A的值贯穿整个状态计算过程。

图片

因此矩阵A的创建自然不能随机产生,我们希望他能包含历史信息,也就是上下文信息,Mamba使用HiPPO创建矩阵A,关于HiPPO的解释,简单来看,就是一种通过多项式来进行历史信息循环记忆的方法,你可以阅读 HiPPO 论文了解更详细的信息。HiPPO矩阵计算公式为:

图片

对角线以上全是0,对角线上是$n+1$,对角线以下计算方法为$(2n+1){1/2}(2k+1)$,类似于这样:

图片

关于HiPPO为什么能记住历史状态,这是经过论证过的,感兴趣的话你可以继续阅读相关数学论文

学习到这里了,基本概念我就向你介绍完了,那我们再来回归正题,SSM和S4有什么关系呢?让我接着向你介绍。

Mamba架构之神奇的S4

S4全称是Structured State Spaces for Sequence Modeling,用于序列建模的结构化状态空间,是一种针对序列建模的高效模型。S4基于有选择性的SSM(关于有选择性的概念我在下面介绍),也就是说S4是SSM的一种具体实现。结合刚刚的分析来看,S4其实就是将HiPPO应用到上文提的卷积和循环表示,用来处理长距离依赖关系。由三部分组成:

  • 状态空间模型SSM;
  • HiPPO;
  • 用于创建循环和卷积表示的离散化。

图片

基于这样的架构,Mamba拥有了多种优势,既有卷积网络的并行训练特性,又有循环网络的快速推理能力,同时还具备Transformer的长距离依赖,可谓是集各种技术的优点与一身。接下来我再向你介绍下什么是有选择性SSM。

注:如果你对S4的源码感兴趣,可以阅读这篇文章《带注释的 S4》

Mamba架构之选择性SSM

选择性其实就是有选择的记录历史信息,SSM 的循环表示会创建一个非常高效的小状态,它会压缩整个历史记录,与不压缩历史记录(通过注意矩阵)的 Transformer 模型相比,它的功能要弱得多,但是 Mamba 的目标是兼具两全其美,小状态和 Transformer 状态一样强大。具体如何进行压缩呢?

在我们刚刚介绍状态空间模型的内容中提到过,矩阵ABC与输入无关,而且在不同的时间序列内,ABC都相同。这样的话,当我们要对输入序列进行选择的时候,无法确认要丢掉哪部分信息,所以我们其实是希望ABC矩阵的值和输入产生关联,这样就可以对输入内容进行判断,进而进行取舍。

Mamba把输入的序列和批量大小进行合并,让矩阵B和C以及步长依赖输入。

图片

对于每个输入的序列,可以有不同的矩阵B和C对应,从而也能感知内容的不同,这种情况下,选择保留什么忽略什么就可以做到了。

Mamba的劣势

虽然Mamba在训练速度等方面表现出色,但也存在一些劣势。

  1. 复杂性:Mamba模型相对于传统的Transformer模型来说,可能具有更高的复杂性,包括架构设计和实现上的复杂性。这可能会增加模型的调试和优化难度。
  2. 通用性:Mamba模型可能更适用于特定的任务或数据集,而不是一般性的应用。在某些场景下,可能需要对模型进行定制或调整,才能达到最佳效果。
  3. 迁移学习:由于Mamba模型可能具有特定的架构和设计,因此在迁移学习或应用到其他领域时,可能需要进行额外的工作来适应新的任务或数据集。
  4. 实验验证:由于Mamba模型相对较新,可能还需要进一步的实验验证和研究来证明其在各种任务和场景中的有效性和稳定性。

我认为Mamba还处于研究试验阶段,个创新思路能否得到学术界以及行业的认可,还需要更多实际应用来证明,毕竟Transformer架构已经大规模落地使用,经得住考验。

今年年初,Mamba被ICLR拒稿,其中一位审稿人提问:有没有训练更大的模型,和10B参数的Transformer比较如何?这就说明它确实还需要一定规模的训练和应用。据我所知,目前生产级的基于Mamba架构的模型不多,如果你感兴趣,可以去看下有52B参数的 jamba

小结

Mamba在多个方面有出色表现,包括高效性、适应性、内存利用、训练速度和性能表现。这些优势使Mamba成为Transformer的强有力竞争者。无论是在学术研究还是工业应用中,Mamba都有潜力带来显著的改进。

但Mamba也同样存在一些缺点,比如它复杂性高、资源需求大、迁移学习难度大、实验验证比较困难。一切都还需要时间去证明,而我们要做的就是时时跟进发展动态,以审慎的态度去了解、去尝试。

思考题

结合今天学习的内容,你来思考一下,Mamba在变成熟的道路上,可能会遇到最大的难题是什么?如何解决?欢迎你在评论区留言,我们一起讨论,如果你觉得这节课的内容对你有帮助的话,也欢迎你分享给其他朋友,在这里我们时时交流AI的最新动态,争取不让一个人掉队!

精选留言(6)
  • 牙小木 👍(6) 💬(2)

    从小学数学直接蹦到微积分了吗

    2024-06-26

  • 张申傲 👍(5) 💬(1)

    虽然有些原理没太理解,但是直观上感觉Mamba相较于Transformer而言,可以支持更长的上下文。现在Transformer在一些文生漫画、文生视频的场景下,还是没办法特别好地解决“时序”问题,导致会出现一些情节不连贯、甚至前后矛盾的情况,这在很大程度上是因为Transformer没法处理特别长的上下文。感觉Mamba在这类场景下可能更有优势~

    2024-06-17

  • Lonely绿豆蛙 👍(3) 💬(1)

    看懂了<50%,不过有个肤浅的疑问:为什么Mamba计算量更小、训推更加高效,反而说缺点之一是资源需求大呢?是因为需要更多的资源用于调参吗?

    2024-07-02

  • zMansi 👍(2) 💬(1)

    这节课好有深度哈,很多知识点不懂。但是看下来mamba需要更多算力来支撑,期待有更多对应的开源产品可以提供试验

    2024-06-12

  • 寒溪 👍(1) 💬(1)

    请教一下老师画图用的什么工具

    2024-06-14

  • 石云升 👍(1) 💬(0)

    在Mamba走向成熟的道路上,可能面临的最大挑战是适应性和通用性。具体来说: a) 任务适应性: 虽然Mamba在某些任务上表现出色,但它可能难以在所有类型的任务上都超越Transformer。不同任务可能需要不同的模型特性,Mamba需要证明它能在广泛的应用场景中保持竞争力。 b) 预训练和迁移学习: Transformer模型(如BERT、GPT等)的一大优势是其强大的预训练和迁移学习能力。Mamba需要开发类似的范式来实现在大规模数据上的预训练,并能够有效地将这些知识迁移到下游任务。 c) 工具生态系统: Transformer模型拥有丰富的工具、库和优化技术。Mamba需要建立类似的生态系统以支持其广泛应用。 d) 训练稳定性: 新的架构可能面临训练不稳定或收敛困难的问题,特别是在扩展到更大规模模型时。 可能的解决方案: a) 混合架构: 开发Mamba和Transformer的混合模型,结合两者的优势。这可能涉及在模型的不同部分使用不同的架构,或者开发能够动态选择最佳计算方法的模型。 b) 改进预训练方法: 设计专门针对Mamba架构的预训练任务和方法,可能需要重新思考自监督学习的范式。 c) 投资工具和框架: 大力投资开发支持Mamba的工具、库和框架,使其易于使用和优化。 d) 持续的理论研究: 深入研究Mamba的理论基础,以更好地理解其性能特征和局限性,从而指导进一步的改进。 e) 跨领域合作: 促进机器学习研究者与各个应用领域专家的合作,以发现Mamba的独特优势和潜在应用场景。 f) 优化训练算法: 开发专门针对Mamba架构的优化器和训练技巧,提高其训练稳定性和效率。

    2024-09-08