Rethinking Softmax: Self-Attention with Polynomial Activations

Hemanth Saratchandran, Jianqiao Zheng, Yiping Ji, Wenbo Zhang & Simon Lucey
Australian Institute of Machine Learning
University of Adelaide
Abstract

This paper challenges the conventional belief that softmax attention in transformers is effective primarily because it generates a probability distribution for attention allocation. Instead, we theoretically show that its success lies in its ability to implicitly regularize the Frobenius norm of the attention matrix during training. We then explore alternative activations that regularize the Frobenius norm of the attention matrix, demonstrating that certain polynomial activations can achieve this effect, making them suitable for attention-based architectures. Empirical results indicate that these activations perform comparably to, or better than, softmax across various computer vision and language tasks, suggesting new possibilities for attention mechanisms beyond softmax.

1 Introduction

Transformer architectures (Vaswani et al., 2017) have become the state-of-the-art model architecture for diverse areas such as natural language processing (NLP) (Vaswani et al., 2017; Devlin et al., 2018; Zhuang et al., 2021; Zhen et al., 2022), computer vision (Dosovitskiy et al., 2020; Carion et al., 2020; Liu et al., 2021; Touvron et al., 2021), and robotics (Fu et al., 2024; Maiti et al., 2023; Salzmann et al., 2020). A key component in the transformer architecture is the softmax attention block, enabling transformers to evaluate the importance of individual input elements during output generation. This feature facilitates an efficient method to attend to diverse input elements throughout training, allowing transformers to effectively capture spatial dependencies within sequential data. Unlike traditional recurrent neural networks (RNNs) and convolutional neural networks (CNNs), transformers exhibit the ability to scale to large datasets without a significant degradation in performance. This characteristic has made them an ideal architecture for handling large-scale machine learning tasks.

Softmax is widely recognized for its effectiveness in attention mechanisms due to its ability to produce an attention matrix that meets three key conditions (Vaswani et al., 2017; Dosovitskiy et al., 2020; Zhen et al., 2022; AUEB et al., 2016): (i) non-negativity, (ii) rows that are normalized to sum to 1, and (iii) sparsity. The general consensus is that non-negativity keeps attention weights from becoming negative, which helps the model assign significance to various input elements. The normalization constraint ensures that the attention weights over all input elements sum to 1, so the weights can be interpreted as probabilities. Additionally, sparsity helps the model focus on a select few key elements, thereby enhancing efficiency. It has been argued that these properties are crucial for enabling the attention mechanism to attend to pertinent segments of the input sequence while efficiently filtering out irrelevant details. However, this view of attention has become somewhat axiomatic: it is motivated mostly by empirical results and has little theoretical foundation. Although several studies have explored alternative activations (Shen et al., 2023; Fang et al., 2022; Correia et al., 2019), softmax attention continues to dominate, largely due to its interpretability.

In this paper, we question this view by proposing that the effectiveness of softmax stems from its implicit regularization of the Frobenius norm of the attention matrix during training, preventing attention weights from becoming excessively large or small. We then derive a theoretical framework that produces polynomial activations that deliberately violate one or more of the three conditions mentioned earlier, yet are able to regularize the Frobenius norm of the attention weights during training. Our findings demonstrate that such activations can achieve comparable or even superior performance to softmax across various vision and natural language processing (NLP) tasks, even though they seem to violate our understanding of attention.

We advise the reader that this paper diverges from the usual pursuit of creating cutting-edge transformer architectures for achieving state-of-the-art results on benchmark datasets. Instead, our focus is on critically examining softmax attention to determine whether its effectiveness is a result of true interpretability or a more nuanced regularization mechanism. By questioning established views, we aim to uncover deeper insights into transformer architectures that could lead to broader applications and improved understanding. Nonetheless, we validate our theory on multiple transformer-based tasks, including image classification, segmentation, object detection, and NLP, often achieving results that match or surpass those of softmax attention. Our main contributions are:

  1. We question the widely accepted notion that softmax's effectiveness in attention mechanisms is solely due to its ability to produce normalized sparse attention weights. Instead, we theoretically show that softmax has a regularization effect on attention and argue this plays a crucial role in its success.

  2. We explore activations that deliberately deviate from traditional softmax attention conditions. These activations are found to regularize the Frobenius norm of the attention matrix during training, akin to softmax, and demonstrate comparable or superior performance across various vision and NLP tasks.

2 Related Work

Several studies have explored alternative activations for attention mechanisms in transformers. Shen et al. (2023) investigated ReLU activations, finding them to outperform softmax in tasks with long sequences, such as document translation. Banerjee et al. (2020) examined Taylor series approximations to softmax, which showed superior performance to softmax in image classification. Wang et al. (2021) proposed periodic alternatives to softmax, designed to provide better gradients for attention mechanisms, which achieved better results than softmax on simple networks for image classification. Koohpayegani & Pirsiavash (2024) demonstrated that applying $l^{1}$ normalization to linear attention mechanisms can achieve performance comparable to softmax. Our work differs from all of these in that we identify a clear theoretical relationship between the scale of the Frobenius norm of the self-attention matrix and the input sequence length, and we use this insight to derive activations that can perform on par with softmax.

3 Preliminaries and Notation

In this section we outline the definition of a transformer via the transformer block and set the notation of various mathematical quantities we will be using in future sections. For more details on transformers the reader can consult Vaswani et al. (2017); Dosovitskiy et al. (2020).

Transformer architectures consist of transformer blocks, defined as follows. A transformer block is a mapping $\mathbf{T}:\mathbb{R}^{N\times D}\rightarrow\mathbb{R}^{N\times D}$ defined as

$$\mathbf{T}(x)=\mathbf{F}(\mathbf{A}(x)+x) \tag{3.1}$$

where $\mathbf{F}$ is a feedforward MLP with a residual connection and $\mathbf{A}$ is an attention head.

The attention head $\mathbf{A}$ is defined as follows. It comprises three learnable matrices $Q$, $K\in\mathbb{R}^{D\times d}$ and $V\in\mathbb{R}^{D\times M}$, which produce the query $q=XQ$, key $k=XK$ and value $v=XV$ for an input sequence $X\in\mathbb{R}^{N\times D}$. The attention head $\mathbf{A}(X)$ is then defined by

$$\mathbf{A}(X)=\phi(\mathcal{S}(q,k))\,v \tag{3.2}$$

where $\mathcal{S}$ is a similarity transformation and $\phi$ is an activation function. The most commonly used $\mathcal{S}$ is the dot product $\mathcal{S}(q,k)=qk^{T}$, known as self-attention, which is the one we focus on in this paper. The most common choice of activation $\phi$ is softmax. This leads to the most common form of the attention head, given by

$$\mathbf{A}(X)=\mathbf{softmax}\bigg(\frac{qk^{T}}{\sqrt{d}}\bigg)v=\mathbf{softmax}\bigg(\frac{XQK^{T}X^{T}}{\sqrt{d}}\bigg)XV. \tag{3.3}$$

The function $\mathbf{softmax}$ is the matrix softmax map that applies the usual softmax function row-wise:

$$\mathbf{softmax}\left(\begin{bmatrix}x_{11}&\cdots&x_{1n}\\ \vdots&\ddots&\vdots\\ x_{n1}&\cdots&x_{nn}\end{bmatrix}\right)=\begin{bmatrix}\frac{e^{x_{11}}}{\sum_{j=1}^{n}e^{x_{1j}}}&\cdots&\frac{e^{x_{1n}}}{\sum_{j=1}^{n}e^{x_{1j}}}\\ \vdots&\ddots&\vdots\\ \frac{e^{x_{n1}}}{\sum_{j=1}^{n}e^{x_{nj}}}&\cdots&\frac{e^{x_{nn}}}{\sum_{j=1}^{n}e^{x_{nj}}}\end{bmatrix} \tag{3.4}$$

The factor $\frac{1}{\sqrt{d}}$, as explained in Vaswani et al. (2017), is a scaling that prevents the gradients of softmax from becoming too small. For the theoretical analysis in this paper we will only use the dot-product similarity $qk^{T}$ and call the $N\times N$ matrix $\mathbf{softmax}(qk^{T})$ the softmax self-attention matrix. In the experiments section, Sec. 5, we will empirically validate our theoretical framework on more general softmax attention blocks.

For general transformer architectures, multiple heads $\mathbf{A}_{i}$ for $1\leq i\leq n$ are used. Each attention head is defined by equation 3.3, and the outputs of all attention heads are concatenated before going into the feedforward layer.
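As an illustration of equations 3.3 and 3.4 and of the head concatenation just described, here is a minimal NumPy sketch. It assumes the shapes used above ($X\in\mathbb{R}^{N\times D}$, $Q,K\in\mathbb{R}^{D\times d}$, $V\in\mathbb{R}^{D\times M}$), random weights, and no output projection after concatenation; the function names are hypothetical and this is not the authors' implementation.

```python
import numpy as np

def row_softmax(S):
    """Matrix softmax of equation 3.4: the usual softmax applied to each row."""
    # Subtracting each row's maximum leaves the result unchanged but avoids overflow.
    Z = np.exp(S - S.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def attention_head(X, Q, K, V):
    """One softmax attention head, equation 3.3: softmax(X Q K^T X^T / sqrt(d)) X V."""
    d = Q.shape[1]
    q, k, v = X @ Q, X @ K, X @ V              # shapes (N, d), (N, d), (N, M)
    attn = row_softmax(q @ k.T / np.sqrt(d))   # (N, N) softmax self-attention matrix
    return attn @ v, attn

def multi_head(X, heads):
    """Concatenate the outputs of several heads along the feature axis (no output projection)."""
    return np.concatenate([attention_head(X, Q, K, V)[0] for Q, K, V in heads], axis=-1)

rng = np.random.default_rng(0)
N, D, d, M, n_heads = 8, 16, 4, 4, 3
X = rng.standard_normal((N, D))
heads = [(rng.standard_normal((D, d)), rng.standard_normal((D, d)), rng.standard_normal((D, M)))
         for _ in range(n_heads)]
out, attn = attention_head(X, *heads[0])
print(out.shape, multi_head(X, heads).shape)   # (8, 4) (8, 12)
```

Every row of `attn` is non-negative and sums to 1, which are two of the three softmax properties discussed in the introduction.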

We will need notation for the derivative of the matrix softmax map defined by equation 3.4. Given a matrix $A\in\mathbb{R}^{N\times N}$, we can differentiate the matrix map $\mathbf{softmax}$ at $A$ and obtain the gradient linear map $\nabla\mathbf{softmax}(A):\mathbb{R}^{N\times N}\rightarrow\mathbb{R}^{N\times N}$ defined by the formula

$$\nabla\mathbf{softmax}(A):=\mathbf{J}\mathbf{softmax}(A)^{T} \tag{3.5}$$

where $\mathbf{J}\mathbf{softmax}(A)$ is the Jacobian of $\mathbf{softmax}$ at $A$.

Given a matrix $A\in\mathbb{R}^{n\times m}$, we denote its Frobenius norm by $\|A\|_{F}$. Additionally, we use $\mathbb{E}$ to denote the expectation of a random variable, where the specific random variable being considered will be clear from the context.

4 Theoretical Analysis

4.1 Implicit regularization of Softmax

This section presents a theoretical result showing that the softmax activation imposes control over the Frobenius norm of the self-attention matrix in a way that grows sub-linearly with the input sequence’s token length. Additionally, we demonstrate that the gradient of the softmax with respect to the self-attention matrix also exhibits a similar degree of regularity. While previous work has analyzed the regularity of softmax self-attention through the lens of the Lipschitz constant (Kim et al., 2021; Castin et al., 2023), our theorem offers a novel perspective by directly linking the Frobenius norm regularity to the token length. This provides insights into how self-attention activations should scale with token length to maintain stability during training, especially with gradient descent-based algorithms.

Theorem 4.1.

Let $\mathbf{softmax}:\mathbb{R}^{N\times N}\rightarrow\mathbb{R}^{N\times N}$ be the matrix softmax map defined by equation 3.4 and let $\nabla\mathbf{softmax}(A):\mathbb{R}^{N\times N}\rightarrow\mathbb{R}^{N\times N}$ denote the gradient of $\mathbf{softmax}$ at $A\in\mathbb{R}^{N\times N}$. We then have the following bounds on the Frobenius norms:

$$\|\mathbf{softmax}(A)\|_{F}\leq\sqrt{N} \tag{4.1}$$
$$\|\nabla\mathbf{softmax}(A)\|_{F}\leq 2\sqrt{N}. \tag{4.2}$$

The key implication of Theorem 4.1 is that, during the training of a transformer with softmax self-attention, the Frobenius norm of each softmax self-attention matrix, and of its gradient, remains bounded by a value that grows as $\mathcal{O}(\sqrt{N})$. This guards against excessively large gradients when backpropagating through the weights of the self-attention matrix. The proof hinges on the fact that the row normalization inherent in softmax controls the Frobenius norm: each row is a probability vector, so its Euclidean norm is at most 1, and summing over the $N$ rows gives $\|\mathbf{softmax}(A)\|_{F}^{2}\leq N$. For a detailed proof see appendix A.1.1.
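Both bounds of Theorem 4.1 can be checked numerically. The sketch below (NumPy, random Gaussian inputs at several scales, chosen arbitrarily) builds the Jacobian of the row-wise softmax explicitly. Because output row $i$ depends only on input row $i$, the Jacobian is block diagonal with blocks $\mathrm{diag}(s)-ss^{T}$, where $s$ is a softmax row. The Jacobian is first compared against central finite differences, and then both Frobenius norms are compared against $\sqrt{N}$ and $2\sqrt{N}$. Building the full $N^{2}\times N^{2}$ matrix is only practical for small $N$.

```python
import numpy as np

def row_softmax(A):
    """Matrix softmax of equation 3.4 (row-wise, max-shifted for stability)."""
    Z = np.exp(A - A.max(axis=1, keepdims=True))
    return Z / Z.sum(axis=1, keepdims=True)

def softmax_jacobian(A):
    """N^2 x N^2 Jacobian of the matrix softmax w.r.t. the row-major flattening of A."""
    N = A.shape[0]
    J = np.zeros((N * N, N * N))
    for i, s in enumerate(row_softmax(A)):
        # Output row i depends only on input row i, so J is block diagonal
        # with blocks diag(s) - s s^T.
        J[i * N:(i + 1) * N, i * N:(i + 1) * N] = np.diag(s) - np.outer(s, s)
    return J

# Sanity check of the Jacobian against central finite differences on a 4 x 4 input.
A0 = np.random.default_rng(1).standard_normal((4, 4))
eps = 1e-6
fd = np.stack([(row_softmax(A0 + eps * E) - row_softmax(A0 - eps * E)).ravel() / (2 * eps)
               for E in np.eye(16).reshape(16, 4, 4)], axis=1)
fd_error = np.abs(softmax_jacobian(A0) - fd).max()

# Compare both Frobenius norms with the bounds sqrt(N) and 2 sqrt(N) of Theorem 4.1.
rng = np.random.default_rng(0)
records = []
for N in (4, 8, 16):
    for scale in (0.1, 1.0, 10.0):
        A = scale * rng.standard_normal((N, N))
        fro_s = np.linalg.norm(row_softmax(A))          # 2-D default is the Frobenius norm
        fro_g = np.linalg.norm(softmax_jacobian(A).T)   # gradient = Jacobian^T, same norm
        records.append((N, scale, fro_s, fro_g))
        print(f"N={N:2d} scale={scale:4.1f} |softmax|_F={fro_s:.3f} <= {np.sqrt(N):.3f}, "
              f"|grad|_F={fro_g:.3f} <= {2 * np.sqrt(N):.3f}")
```

Large input scales push rows toward one-hot vectors, where $\|\mathbf{softmax}(A)\|_{F}$ approaches $\sqrt{N}$ and the Jacobian blocks shrink toward zero.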

4.2 Polynomial activations for self-attention

In section 4.1, we demonstrated that softmax implicitly regularizes the Frobenius norm of the self-attention matrix. Building on this, we now show that by scaling specific polynomial activations, a similar regularization effect on the Frobenius norm can be achieved in expectation, closely replicating the impact of softmax.

Theorem 4.2.

Let $X\in\mathbb{R}^{N\times D}$ and $Q$, $K\in\mathbb{R}^{D\times d}$ be i.i.d. random variables distributed according to $X\sim\mathcal{N}(0,\sigma_{x})$ and $Q$, $K\sim\mathcal{N}(0,\sigma_{t})$. We have the following bound on the expected Frobenius norm of powers of the $N\times N$ matrix $XQK^{T}X^{T}$ for $p\geq 1$, where, as for an activation, the power is applied entrywise:

$$\mathbb{E}\bigg\|\bigg(\frac{XQK^{T}X^{T}}{\sqrt{d}}\bigg)^{p}\bigg\|_{F}\leq\mathcal{O}(N) \tag{4.3}$$

By scaling such an activation by $\frac{1}{\sqrt{N}}$ we obtain an $\mathcal{O}(\sqrt{N})$ bound.

Corollary 4.3.

Assume the same conditions as in Theorem 4.2. Then

$$\mathbb{E}\bigg\|\frac{1}{\sqrt{N}}\bigg(\frac{XQK^{T}X^{T}}{\sqrt{d}}\bigg)^{p}\bigg\|_{F}\leq\mathcal{O}(\sqrt{N}). \tag{4.4}$$
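The growth rates in Theorem 4.2 and Corollary 4.3 can be probed with a small Monte Carlo experiment. The sketch below assumes $\sigma_{x}$ and $\sigma_{t}$ are standard deviations (the notation $\mathcal{N}(0,\sigma)$ leaves this open), uses arbitrary sizes $D=d=16$, applies the power entrywise, and reuses the same $Q$, $K$ draws for every $N$ so that only the sequence length changes. If the expected norm is $\mathcal{O}(N)$, the ratio $\mathbb{E}\|\cdot\|_{F}/N$ should stay roughly flat; equivalently, the $\frac{1}{\sqrt{N}}$-scaled version divided by $\sqrt{N}$ is the same ratio.

```python
import numpy as np

def mean_fro_norm(N, p, D=16, d=16, sigma_x=1.0, sigma_t=1.0, trials=20):
    """Monte Carlo estimate of E||(X Q K^T X^T / sqrt(d))^p||_F, power taken entrywise.

    sigma_x and sigma_t are treated as standard deviations (an assumption).
    """
    total = 0.0
    for t in range(trials):
        w = np.random.default_rng(1000 + t)       # same Q, K for every N
        Q = w.normal(0.0, sigma_t, size=(D, d))
        K = w.normal(0.0, sigma_t, size=(D, d))
        X = np.random.default_rng(t).normal(0.0, sigma_x, size=(N, D))
        S = X @ Q @ K.T @ X.T / np.sqrt(d)
        total += np.linalg.norm(S ** p)             # Frobenius norm of the entrywise power
    return total / trials

ratios = {}
for p in (1, 3):
    for N in (16, 64, 256):
        ratios[(p, N)] = mean_fro_norm(N, p) / N    # roughly flat in N if the norm is O(N)
        print(f"p={p} N={N:3d}  E||S^p||_F / N = {ratios[(p, N)]:.4g}")
```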

Corollary 4.3 establishes that activations of the form $\phi(x):=\frac{1}{\sqrt{N}}x^{p}$ provide a level of regularization, in expectation, similar to that of softmax when applied to the self-attention matrix. The proof of Theorem 4.2 can be found in appendix A.1.2. The next property we want to prove is analogous to the gradient bound in Theorem 4.1. The parameters of the self-attention matrix are the queries $Q$ and keys $K$ (Vaswani et al., 2017), so during the training of a transformer $Q$ and $K$ are the only parts of the self-attention matrix that get updated. We therefore compute a regularity result with respect to the $Q$ and $K$ derivatives.

Theorem 4.4.

Let $X\in\mathbb{R}^{N\times D}$ and $Q$, $K\in\mathbb{R}^{D\times d}$ be i.i.d. random variables distributed according to $X\sim\mathcal{N}(0,\sigma_{x})$ and $Q$, $K\sim\mathcal{N}(0,\sigma_{t})$. Then, for $p\geq 1$, the expected norm of the derivative of the matrix $\frac{(XQK^{T}X^{T})^{p}}{\sqrt{d}}$ with respect to the parameter matrix $Q$ satisfies

$$\mathbb{E}\bigg\|\frac{\partial}{\partial Q}\bigg(\frac{(XQK^{T}X^{T})^{p}}{\sqrt{d}}\bigg)\bigg\|\leq\mathcal{O}(N) \tag{4.5}$$

The above theorem then suggests that if we scale the polynomial $x\rightarrow x^{p}$ by $\frac{1}{\sqrt{N}}$, the $Q$ derivative will have $\mathcal{O}(\sqrt{N})$ growth.

Corollary 4.5.

Assume the same conditions as in Theorem 4.4. Then

$$\mathbb{E}\bigg\|\frac{1}{\sqrt{N}}\frac{\partial}{\partial Q}\bigg(\frac{(XQK^{T}X^{T})^{p}}{\sqrt{d}}\bigg)\bigg\|\leq\mathcal{O}(\sqrt{N}). \tag{4.6}$$

An analogous estimate holds for derivatives with respect to the K𝐾Kitalic_K matrix. The proof of theorem 4.4 can be found in appendix A.1.2.
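Theorem 4.4 can be examined in the same way. Writing $B=XQK^{T}X^{T}$ and taking the power entrywise, the derivative is the four-index tensor $\partial f_{ij}/\partial Q_{ab}=p\,B_{ij}^{p-1}X_{ia}(XK)_{jb}/\sqrt{d}$, so its squared Frobenius norm collapses to $\frac{p^{2}}{d}\sum_{ij}B_{ij}^{2(p-1)}\|X_{i}\|^{2}\|(XK)_{j}\|^{2}$. The sketch below assumes the unspecified norm in equation 4.5 is this Frobenius norm, checks the closed form against finite differences, and then reports the norm divided by $N$ (averaged over a few random $X$, with $Q$ and $K$ fixed and standard normal) for several sequence lengths. Sizes and seeds are arbitrary.

```python
import numpy as np

def f(X, Q, K, p):
    """(X Q K^T X^T)^p / sqrt(d) from Theorem 4.4, with the power taken entrywise."""
    return (X @ Q @ K.T @ X.T) ** p / np.sqrt(Q.shape[1])

def dq_fro_norm(X, Q, K, p):
    """Frobenius norm of df/dQ, whose entries are p * B_ij^(p-1) * X_ia * (XK)_jb / sqrt(d)."""
    d = Q.shape[1]
    B = X @ Q @ K.T @ X.T
    row_x = (X ** 2).sum(axis=1)          # ||X_i||^2 for each row i
    row_c = ((X @ K) ** 2).sum(axis=1)    # ||(XK)_j||^2 for each row j
    return np.sqrt(np.sum((p ** 2 / d) * B ** (2 * (p - 1)) * np.outer(row_x, row_c)))

# Check the closed form against central finite differences on a tiny problem.
rng = np.random.default_rng(0)
X, Q, K = rng.standard_normal((3, 4)), rng.standard_normal((4, 2)), rng.standard_normal((4, 2))
eps = 1e-5
cols = [(f(X, Q + eps * E, K, 3) - f(X, Q - eps * E, K, 3)).ravel() / (2 * eps)
        for E in np.eye(Q.size).reshape(Q.size, *Q.shape)]
fd_norm = np.linalg.norm(np.stack(cols, axis=1))
rel_err = abs(fd_norm - dq_fro_norm(X, Q, K, 3)) / fd_norm

# Growth in N with Q, K held fixed (D = d = 16), averaged over 10 random inputs X.
Q, K = rng.standard_normal((16, 16)), rng.standard_normal((16, 16))
growth = {}
for p in (1, 3):
    for N in (16, 64, 256):
        vals = [dq_fro_norm(np.random.default_rng(t).standard_normal((N, 16)), Q, K, p)
                for t in range(10)]
        growth[(p, N)] = np.mean(vals) / N     # roughly flat in N if the norm is O(N)
        print(f"p={p} N={N:3d}  ||df/dQ||_F / N = {growth[(p, N)]:.4g}")
```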

Corollaries 4.3 and 4.5 suggest that polynomial activations of the form $\phi(x)=\frac{1}{\sqrt{N}}x^{p}$, with $p\geq 1$, can achieve performance comparable to softmax when applied to self-attention matrices. In section 5, we empirically compare these activations to softmax and observe that they match or outperform softmax on a variety of transformer tasks. We focus on $p=1$ and $p=3$, as these polynomials clearly violate key properties of softmax-based attention: normalized rows, positivity, and sparsity. For larger values of $p$, performance declines because $\phi(x)=\frac{1}{\sqrt{N}}x^{p}$ has smaller gradients around $0$ when $p$ is large, which makes training more difficult.
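A minimal sketch of an attention head that swaps softmax for the scaled polynomial $\phi(x)=\frac{1}{\sqrt{N}}x^{p}$, applied entrywise to $qk^{T}/\sqrt{d}$ as in equation 3.3. It is written in NumPy with random weights and hypothetical names, and is not the authors' training code. It deliberately enforces none of the softmax properties: entries can be negative and rows do not sum to 1.

```python
import numpy as np

def poly_attention_head(X, Q, K, V, p=3, scale=None):
    """Attention head with phi(x) = scale * x^p (entrywise) in place of softmax.

    By default scale = 1/sqrt(N), the choice suggested by Corollaries 4.3 and 4.5.
    """
    N, d = X.shape[0], Q.shape[1]
    if scale is None:
        scale = 1.0 / np.sqrt(N)
    S = (X @ Q) @ (X @ K).T / np.sqrt(d)   # dot-product similarity, as in equation 3.3
    attn = scale * S ** p                  # entrywise polynomial activation, no normalization
    return attn @ (X @ V), attn

rng = np.random.default_rng(0)
N, D, d = 256, 64, 64
X = rng.standard_normal((N, D))
Q, K, V = (rng.standard_normal((D, d)) / np.sqrt(D) for _ in range(3))
out, attn = poly_attention_head(X, Q, K, V, p=3)   # for N = 256 the scale is 1/16
print(out.shape)                                    # (256, 64)
print(attn.min(), attn.sum(axis=1)[:3])             # negative entries, rows not summing to 1
```

Passing `scale=1.0` recovers the unscaled activations $x^{p}$ that Section 5 compares against.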

5 Experiments

In this section, we validate the theory from section 4 on a variety of transformer tasks. We perform the empirical validation on two primary activations from section 4, namely a cubic polynomial activation $\phi(x)=x^{3}$ and a linear polynomial $\phi(x)=x$. The goal is to show that by suitably scaling these activations using the theory in section 4, we can achieve competitive performance compared to softmax. For the rest of this section we will simply denote these activations by $x^{3}$ and $x$.

Figure 1: Training ViT-Tiny with the activation $\phi(x)=x^{3}$ with different sequence lengths and different scales. As the sequence length gets smaller, the log scale needed to obtain good accuracy decreases, validating the theory from section 4.2.

5.1 Image classification

5.1.1 ViT-Tiny on Tiny-Imagenet

In this section we test the theory from section 4 on the ViT-Tiny architecture (Steiner et al., 2021) trained from scratch on the Tiny-Imagenet dataset (Le & Yang, 2015).

Our first experiment tests how the Top-1 accuracy of a ViT-Tiny trained on Tiny-Imagenet changes as we vary the input sequence length and the scale predicted in corollaries 4.3 and 4.5 when using the activation $x^{3}$. According to the theory developed in section 4, the Frobenius norm scales as $\mathcal{O}(\sqrt{N})$ when we scale $x^{3}$ by $\frac{1}{\sqrt{N}}$. Thus, as the sequence length decreases, the amount of scaling needed, measured on a log scale, should also decrease.

Figure 1 shows the results of this experiment. We considered four different input sequence lengths: 256, 64, 16 and 8. We trained several ViT-Tiny architectures with a variety of scalings of the form $\mathcal{O}(\frac{1}{\sqrt{N}})$, where $N$ ranged from below to above the sequence length. As figure 1 shows, as the sequence length got smaller, the amount of scaling needed for good accuracy (shown on a log scale on the x-axis) also got smaller, verifying the theory in section 4.2.
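For reference, the scale predicted by the theory for each of the four sequence lengths is a division by $\sqrt{N}$. The short computation below lists that divisor and its base-10 logarithm; the base of the log axis in Figure 1 is not stated, so base 10 is an assumption made only for this illustration.

```python
import math

predicted = {}
for N in (256, 64, 16, 8):
    divisor = math.sqrt(N)                      # phi(x) = x^3 / sqrt(N)
    predicted[N] = (divisor, math.log10(divisor))
    print(f"N={N:3d}  divide by sqrt(N) = {divisor:.3f}  log10 = {math.log10(divisor):.3f}")
```

The divisor shrinks monotonically with $N$, from 16 at $N=256$ to about 2.83 at $N=8$, which is the trend the experiment looks for.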

The second experiment compared the activations $x^{3}$ and $x$, along with the scaled versions $\frac{1}{16}x^{3}$ and $\frac{1}{16}x$, against softmax using the ViT-Tiny architecture on Tiny-Imagenet. Since the sequence length is 256 ($\sqrt{256}=16$), we took $\frac{1}{16}$ as the scale of the polynomial activations. The experiment used a patch size of 4, 3 attention heads, and 12 layers, as described in Steiner et al. (2021). The results in table 1 show $\frac{1}{16}x^{3}$ outperforming softmax, while the unscaled version performed poorly. Similarly, $\frac{1}{16}x$ performed competitively, whereas the unscaled $x$ suffered a significant drop in performance.

Figure 2 displays the Frobenius norm of the self-attention matrix during training for five activations in layers 2 and 12 of ViT-Tiny, averaged across all heads. The norms for $x^{3}$ and $x$ are higher than for softmax, but scaling by $\frac{1}{16}$ reduces them to more stable levels, improving training stability. Similarly, figure 3 shows the Frobenius norm of the Jacobian, where scaling also brings the norms closer to those of softmax, suggesting more stable gradients. Further plots for other layers are in appendix A.2.2.

| Activation | softmax | $\frac{x^{3}}{16}$ | $x^{3}$ | $\frac{x}{16}$ | $x$ |
|---|---|---|---|---|---|
| Top-1 accuracy (%) | 50.26 | 50.5 | 45.3 | 47.9 | 31.78 |

Table 1: Comparison of Top-1 accuracy on Tiny-Imagenet between softmax and polynomial activations. The cubic activation outperforms softmax when the right scale of $\frac{1}{16}$ is applied. Similarly, the linear activation is competitive only with the right scale.
Figure 2: Frobenius norm of the self-attention matrix with five different activations in layers 2 and 12 of the ViT-Tiny architecture during training.
Figure 3: Frobenius norm of the Jacobian of the self-attention matrix with five different activations in layers 2 and 12 of the ViT-Tiny architecture during training.

5.1.2 Larger vision transformers on ImageNet-1k

For this experiment we carried out image classification on the ImageNet-1k dataset using several vision transformers from the literature. We found that a scale of $\frac{1}{14}$ worked best for both $x^3$ and $x$.
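The text states the $1/\sqrt{N}$ rule explicitly only for Tiny-ImageNet. The sketch below assumes square inputs split into non-overlapping patches, ignores the class token, and assumes 224×224 ImageNet inputs (consistent with the 196 = 14×14 patches of figure 8); under these assumptions the same rule happens to recover the $\frac{1}{14}$ used here. This correspondence is our reading, not a claim from the experiments, and `polynomial_scale` is a hypothetical helper.

```python
import math

def polynomial_scale(image_size: int, patch_size: int) -> float:
    """Hypothetical helper: 1/sqrt(N) for N square, non-overlapping patches (no class token)."""
    tokens_per_side = image_size // patch_size
    n_tokens = tokens_per_side ** 2
    return 1.0 / math.sqrt(n_tokens)

print(polynomial_scale(64, 4))    # Tiny-ImageNet, patch 4: N = 16 * 16 = 256, scale 1/16
print(polynomial_scale(224, 16))  # ImageNet-1k at 224 px, patch 16: N = 14 * 14 = 196, scale 1/14
```

For windowed attention such as Swin (7×7 windows, so 49 tokens per window) the same formula would give $\frac{1}{7}$, so there the shared $\frac{1}{14}$ is best read as the empirical choice reported above.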

We train all models on the ImageNet-1k dataset from scratch and report Top-1 accuracy on the validation set. We use the PyTorch (Paszke et al., 2019) and timm (Wightman, 2019) libraries to train our models, with setups similar to He et al. (2022) and Liu et al. (2021). To show that our approach generalizes, we evaluated it on the following four transformer architectures:

- ViT: Dosovitskiy et al. (2020) is the pioneering work that interprets an image as a sequence of patches and processes it with a standard Transformer encoder as used in NLP. This simple yet scalable strategy works surprisingly well when coupled with pre-training on large datasets. We use ViT-Small (patch size 16, embedding dimension 384, 6 heads, 12 layers) and ViT-Base (patch size 16, embedding dimension 768, 12 heads, 12 layers).

- DeiT: Touvron et al. (2021) is a well-known transformer based on ViT. It shares the ViT architecture but converges faster thanks to a better training recipe; we do not use the distillation token introduced in DeiT. We use DeiT-Small (patch size 16, embedding dimension 384, 6 heads, 12 layers) and DeiT-Base (patch size 16, embedding dimension 768, 12 heads, 12 layers).

- Swin Transformer: Liu et al. (2021) produces a hierarchical feature representation and proposes shifted-window self-attention, which is shown to be effective and efficient on vision problems. We use Swin-Small with 96 channels and Swin-Base with 128 channels in the hidden layers of the first stage. The window size is set to $M = 7$ by default, the query dimension of each head is $d = 32$, and the layer numbers are $\{2, 2, 18, 2\}$ for all experiments.

- XCiT: Xiong et al. (2021) is a vision transformer architecture that differs from the standard ViT in two components. First, each block contains a Local Patch Interaction module: a depth-wise 3×3 convolution followed by Batch Normalization, GELU, and another depth-wise 3×3 convolution. Second, it uses cross-covariance attention, where the attention map is derived from the cross-covariance matrix computed over the key and query projections of the token features. We use XCiT-S12 with a patch size of 16 and XCiT-M24 with a patch size of 24.
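To make the contrast with dot-product attention concrete, here is a single-head NumPy sketch of cross-covariance attention as we understand the XCiT design: queries and keys are $\ell_2$-normalised per channel across tokens, and the softmax acts on a $d \times d$ channel matrix rather than an $N \times N$ token matrix. The learnable temperature is fixed to a constant; multi-head splitting, projections, and the Local Patch Interaction module are omitted; all names and sizes are hypothetical.

```python
import numpy as np

def cross_covariance_attention(q, k, v, temperature=1.0):
    """Single-head cross-covariance attention (sketch).

    q, k, v: arrays of shape (N, d). Returns (output of shape (N, d), attention of shape (d, d)).
    """
    # L2-normalise every feature channel across the N tokens.
    q_hat = q / np.linalg.norm(q, axis=0, keepdims=True)
    k_hat = k / np.linalg.norm(k, axis=0, keepdims=True)
    logits = temperature * (q_hat.T @ k_hat)  # (d, d) channel-to-channel similarities
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    attn = e / e.sum(axis=1, keepdims=True)  # softmax over key channels
    out = v @ attn.T  # each token's value channels are remixed by attn
    return out, attn

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 196, 48))  # N = 196 tokens, d = 48 (hypothetical)
out, attn = cross_covariance_attention(q, k, v)
print(out.shape, attn.shape)
```

Because this attention map is $d \times d$, its size does not depend on the number of tokens, which is one reason an analysis of the $N \times N$ dot-product attention matrix does not carry over to it automatically.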

The results are shown in table 2. The activation $\frac{1}{14}x^3$ matched or outperformed softmax on the ViT, Swin, and XCiT models, while on the DeiT models it trailed softmax and the scaled linear activation $\frac{1}{14}x$ achieved the highest accuracy. Further ablations with different scales and activations can be found in appendix A.2.1.

| Activation | ViT-Base | ViT-Small | DeiT-Base | DeiT-Small | Swin-Base | Swin-Small | XCiT-Medium | XCiT-Small |
|---|---|---|---|---|---|---|---|---|
| softmax | 79.6 | 80.2 | 78.9 | 79.6 | 83.0 | 83.3 | 81.1 | 81.2 |
| $\frac{x^3}{14}$ | 79.6 | 80.5 | 77.4 | 78.3 | 83.2 | 83.4 | 81.2 | 82.1 |
| $x^3$ | 77.8 | 78.6 | 76.1 | 76.3 | 79.7 | 79.9 | 78.1 | 78.3 |
| $\frac{x}{14}$ | 76.9 | 77.8 | 79.6 | 79.8 | 79.4 | 79.5 | 79.2 | 79.3 |
| $x$ | 73.2 | 73.9 | 77.7 | 77.9 | 77.8 | 77.9 | 76.4 | 76.6 |

Table 2: Comparison of models pre-trained on ImageNet-1k with different activation functions. We report top-1 classification accuracy (%).

Figure 4 plots the Frobenius norm of the self-attention matrix during training for the five activations in layers 2 and 12 of the ViT-Small architecture, averaged over all heads within each layer. Scaling the activations $x^3$ and $x$ by $\frac{1}{14}$ keeps the Frobenius norm of the self-attention matrix at a scale comparable to that of softmax. Similarly, figure 5 plots the Frobenius norm of the Jacobian of the self-attention matrix during training for layers 2 and 12, averaged over all heads; the same scaling keeps the Jacobian norm comparable to that of softmax throughout training. Plots for the other layers are given in appendix A.2.2.

Figure 4: Frobenius norm of the self-attention matrix with five different activations in layers 2 and 12 of the ViT-Small architecture during training.
Figure 5: Frobenius norm of the Jacobian of the self-attention matrix with five different activations in layers 2 and 12 of the ViT-Small architecture during training.

5.1.3 Visualizing self-attention with ViT-Base

We plotted heat maps of the self-attention matrices of ViT-Base trained with the $\frac{1}{14}x^3$ and softmax activations, for two layer and head pairs after convergence, averaging over a fixed training batch of size 128. Figure 6 shows layer 2, head 8, highlighting the differences in attention patterns; the $\frac{1}{14}x^3$ matrix contains both positive and negative values. Similarly, figure 7 shows distinct patterns for each activation in layer 12, head 6. Overall, ViT-Base with $\frac{1}{14}x^3$ exhibits notably different self-attention patterns from softmax.

We also visualized which regions of an image the self-attention matrix targets. Using an image from the ImageNet-1k validation set, we extracted the class token's attention weights over the 196 patches, reshaped them into a $(14, 14)$ grid, and mapped the grid back to the original image size with nearest-neighbor interpolation. Figure 8 shows the input image, while figure 9 shows the resulting maps for the different activations in layer 12, head 6 of ViT-Base after convergence, highlighting their distinct focus areas.
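The visualization procedure can be sketched in NumPy as below. We assume the class token sits at index 0, that its attention row (the class token as query) is the one used, that patches are ordered row-major, and that the input is 224×224 so each of the 14×14 patches covers 16×16 pixels; `cls_attention_map` is a hypothetical name. Nearest-neighbor interpolation by an integer factor amounts to repeating each patch value over its pixel block.

```python
import numpy as np

def cls_attention_map(attn, grid=14, image_size=224):
    """Per-pixel map of the class token's attention over the patches (sketch).

    attn: (1 + grid**2, 1 + grid**2) attention matrix with the class token at index 0.
    """
    cls_row = attn[0, 1:]  # class-token query attending to each patch
    patch_map = cls_row.reshape(grid, grid)  # row-major patch order (assumed)
    factor = image_size // grid  # 224 // 14 = 16 pixels per patch side
    # Nearest-neighbor upsampling by an integer factor: repeat each value over its block.
    return np.repeat(np.repeat(patch_map, factor, axis=0), factor, axis=1)

rng = np.random.default_rng(0)
logits = rng.standard_normal((197, 197))  # random stand-in for one head's logits
e = np.exp(logits - logits.max(axis=1, keepdims=True))
attn = e / e.sum(axis=1, keepdims=True)
heat = cls_attention_map(attn)  # (224, 224), ready to overlay on the image
print(heat.shape)
```

For the polynomial activation the same code applies unchanged; the map may then contain negative values, so a diverging colormap is a natural choice when plotting it.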

Figure 6: Heat maps of the self-attention matrix in layer 2, head 8, of a ViT-Base architecture, comparing $\frac{1}{14}x^3$ (left) and softmax (right) activations after training. The stark difference in self-attention patterns between the two activations is evident, showing distinct distributions across input tokens.
Figure 7: Heat maps of the self-attention matrix in layer 12, head 6, of a ViT-Base architecture, comparing $\frac{1}{14}x^3$ (left) and softmax (right) activations after training. The contrast in self-attention patterns between the two activations is clearly visible.
Figure 8: Fish image from the validation set of ImageNet-1k broken up into $196 = 14 \times 14$ patches.
Figure 9: Comparing how the $\frac{1}{14}x^3$ and softmax self-attention matrices focus on different parts of the fish image from figure 8 in layer 12, head 6, after the model has converged.

5.2 Object Detection and Instance Segmentation

In this section, to examine the transfer-learning ability of our models, we apply our approach to object detection and instance segmentation by fine-tuning our ImageNet-pretrained XCiT model on these tasks. Our experiments are conducted on COCO 2017 (Lin et al., 2014), which has 118K training images and 5K validation images covering 80 categories. We integrate the XCiT architecture as the backbone of the Mask R-CNN (He et al., 2017) detector with a Feature Pyramid Network (FPN). Because XCiT has a columnar (non-hierarchical) design, we adapt it for FPN compatibility by extracting features from several layers of XCiT-S12. These features all have a stride of 16, and their resolutions are then adjusted to strides of [4, 8, 16, 32]: downsampling is done with max pooling, while upsampling uses a single transposed convolution layer. The model is trained for 36 epochs using the AdamW optimizer with a learning rate of $10^{-4}$, a weight decay of 0.05, and a batch size of 16. In table 3 we report experiments on XCiT-S12 with 16×16 patches using the activations $\frac{1}{14}x^3$, $\frac{1}{14}x$, and softmax. Since the unscaled activations $x^3$ and $x$ could not be trained well on this task, we do not report them.
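The stride adjustment can be illustrated by the shape arithmetic below. This is a simplified sketch: a fixed nearest-neighbor repeat stands in for the learned transposed convolution, max pooling uses a 2×2 window, the channel count and input size are hypothetical, and which four XCiT-S12 layers feed the pyramid is not specified here.

```python
import numpy as np

def max_pool2x2(x):
    """2x2 max pooling with stride 2 on a (C, H, W) array with even H and W."""
    c, h, w = x.shape
    return x.reshape(c, h // 2, 2, w // 2, 2).max(axis=(2, 4))

def upsample_nearest(x, factor):
    """Stand-in for a learned transposed convolution: repeat each pixel factor x factor times."""
    return np.repeat(np.repeat(x, factor, axis=1), factor, axis=2)

def stride16_to_pyramid(features):
    """Map four stride-16 feature maps (one per selected layer) to strides [4, 8, 16, 32]."""
    f4, f8, f16, f32 = features
    return [upsample_nearest(f4, 4), upsample_nearest(f8, 2), f16, max_pool2x2(f32)]

rng = np.random.default_rng(0)
c, h, w = 8, 50, 84  # e.g. an 800 x 1344 input at stride 16; channel count is hypothetical
feats = rng.standard_normal((4, c, h, w))
pyramid = stride16_to_pyramid(feats)
print([p.shape for p in pyramid])
```

The spatial sizes scale as $16/s$ for target stride $s$, which is the only property the FPN needs from the backbone.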

| Activation | $AP^b$ | $AP^b_{50}$ | $AP^b_{75}$ | $AP^m$ | $AP^m_{50}$ | $AP^m_{75}$ |
|---|---|---|---|---|---|---|
| softmax | 44.9 | 66.1 | 48.9 | 40.1 | 63.1 | 42.8 |
| $\frac{x^3}{14}$ | 44.8 | 66.3 | 49 | 40.1 | 63.1 | 42.9 |
| $\frac{x}{14}$ | 44.8 | 66.2 | 49 | 40.1 | 63.1 | 42.8 |

Table 3: COCO object detection and instance segmentation performance on the mini-val set. All backbones are pretrained on ImageNet-1k and use the Mask R-CNN model. $AP^b$: average precision for bounding-box predictions; $AP^b_{50/75}$: average precision at an Intersection over Union (IoU) threshold of 0.50/0.75 for bounding-box predictions; $AP^m$: average precision for mask predictions; $AP^m_{50/75}$: average precision at an IoU threshold of 0.50/0.75 for mask predictions.

5.3 Natural language processing (NLP)

To assess the effectiveness of our approach on NLP tasks, we trained models on five benchmarks from the Long Range Arena (LRA) suite (Tay et al., 2020): ListOps, Text Classification, Retrieval, Image Classification, and Pathfinder. We evaluated the activations $\frac{x^3}{14}$ and $\frac{x}{14}$ against softmax; since $x^3$ and $x$ did not train effectively on their own, only results for these scaled activations and softmax are presented. Our implementation followed the guidelines of Xiong et al. (2021). The results are summarized in table 4.

| Activation | ListOps | Text | Retrieval | Image | Pathfinder |
|---|---|---|---|---|---|
| softmax | 37.1 | 63.8 | 79.8 | 39.9 | 72.9 |
| $\frac{x^3}{14}$ | 37.5 | 63.3 | 80.9 | 37.2 | 68.7 |
| $\frac{x}{14}$ | 34.3 | 62.9 | 81.5 | 39.0 | 69.1 |

Table 4: Comparison of transformer models with different activation functions on NLP tasks. We report accuracy (%) on the LRA benchmarks.

6 Limitations

While our work introduces novel activations that challenge the conventional softmax approach, there are some limitations to address. Our theoretical framework is primarily designed for dot-product self-attention and may not immediately extend to other attention mechanisms, although our empirical results showed competitive performance against softmax across different architectures. Additionally, we observed that while our activations performed well on vision tasks, their performance was less consistent on NLP tasks, suggesting that a more refined theoretical approach may be needed for these applications.

7 Conclusion

This work challenges the traditional view that transformer activations for attention must produce probability distributions. We introduced a theoretical framework analyzing the Frobenius norm of the self-attention matrix, which suggests key scaling properties for activations in attention mechanisms. We proved that specific polynomial activations, which behave very differently from softmax, satisfy these properties. Through extensive experiments across vision and NLP tasks, we demonstrated that these alternative activations not only compete with but sometimes outperform softmax, offering a fresh perspective on attention mechanisms in transformers.

References

  • AUEB et al. (2016) Titsias RC AUEB et al. One-vs-each approximation to softmax for scalable estimation of probabilities. Advances in Neural Information Processing Systems, 29, 2016.
  • Banerjee et al. (2020) Kunal Banerjee, Rishi Raj Gupta, Karthik Vyas, Biswajit Mishra, et al. Exploring alternatives to softmax function. arXiv preprint arXiv:2011.11538, 2020.
  • Carion et al. (2020) Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, and Sergey Zagoruyko. End-to-end object detection with transformers. In European conference on computer vision, pp.  213–229. Springer, 2020.
  • Castin et al. (2023) Valérie Castin, Pierre Ablin, and Gabriel Peyré. Understanding the regularity of self-attention with optimal transport. arXiv preprint arXiv:2312.14820, 2023.
  • Correia et al. (2019) Gonçalo M Correia, Vlad Niculae, and André FT Martins. Adaptively sparse transformers. arXiv preprint arXiv:1909.00015, 2019.
  • Devlin et al. (2018) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • Dosovitskiy et al. (2020) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
  • Fang et al. (2022) Haishuo Fang, Ji-Ung Lee, Nafise Sadat Moosavi, and Iryna Gurevych. Transformers with learnable activation functions. arXiv preprint arXiv:2208.14111, 2022.
  • Fu et al. (2024) Daocheng Fu, Xin Li, Licheng Wen, Min Dou, Pinlong Cai, Botian Shi, and Yu Qiao. Drive like a human: Rethinking autonomous driving with large language models. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp.  910–919, 2024.
  • He et al. (2017) Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. Mask r-cnn. In Proceedings of the IEEE international conference on computer vision, pp.  2961–2969, 2017.
  • He et al. (2022) Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, and Ross Girshick. Masked autoencoders are scalable vision learners. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp.  16000–16009, 2022.
  • Kim et al. (2021) Hyunjik Kim, George Papamakarios, and Andriy Mnih. The lipschitz constant of self-attention. In International Conference on Machine Learning, pp.  5562–5571. PMLR, 2021.
  • Koohpayegani & Pirsiavash (2024) Soroush Abbasi Koohpayegani and Hamed Pirsiavash. Sima: Simple softmax-free attention for vision transformers. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp.  2607–2617, 2024.
  • Le & Yang (2015) Ya Le and Xuan Yang. Tiny imagenet visual recognition challenge. CS 231N, 7(7):3, 2015.
  • Lin et al. (2014) Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, and C Lawrence Zitnick. Microsoft coco: Common objects in context. In Computer Vision–ECCV 2014: 13th European Conference, Zurich, Switzerland, September 6-12, 2014, Proceedings, Part V 13, pp.  740–755. Springer, 2014.
  • Liu et al. (2021) Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, and Baining Guo. Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF international conference on computer vision, pp.  10012–10022, 2021.
  • Maiti et al. (2023) Abhisek Maiti, Sander Oude Elberink, and George Vosselman. Transfusion: Multi-modal fusion network for semantic segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  6536–6546, 2023.
  • Paszke et al. (2019) Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, et al. Pytorch: An imperative style, high-performance deep learning library. Advances in neural information processing systems, 32, 2019.
  • Salzmann et al. (2020) Tim Salzmann, Boris Ivanovic, Punarjay Chakravarty, and Marco Pavone. Trajectron++: Multi-agent generative trajectory forecasting with heterogeneous data for control. arXiv preprint arXiv:2001.03093, 2, 2020.
  • Shen et al. (2023) Kai Shen, Junliang Guo, Xu Tan, Siliang Tang, Rui Wang, and Jiang Bian. A study on relu and softmax in transformer. arXiv preprint arXiv:2302.06461, 2023.
  • Steiner et al. (2021) Andreas Steiner, Alexander Kolesnikov, Xiaohua Zhai, Ross Wightman, Jakob Uszkoreit, and Lucas Beyer. How to train your vit? data, augmentation, and regularization in vision transformers. arXiv preprint arXiv:2106.10270, 2021.
  • Tay et al. (2020) Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. Long range arena: A benchmark for efficient transformers. arXiv preprint arXiv:2011.04006, 2020.
  • Touvron et al. (2021) Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Hervé Jégou. Training data-efficient image transformers & distillation through attention. In International conference on machine learning, pp.  10347–10357. PMLR, 2021.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Wang et al. (2021) Shulun Wang, Feng Liu, and Bin Liu. Escaping the gradient vanishing: Periodic alternatives of softmax in attention mechanism. IEEE Access, 9:168749–168759, 2021.
  • Wightman (2019) Ross Wightman. Pytorch image models. https://github.com/rwightman/pytorch-image-models, 2019.
  • Xiong et al. (2021) Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, and Vikas Singh. Nyströmformer: A nyström-based algorithm for approximating self-attention. Proceedings of the AAAI Conference on Artificial Intelligence, 2021.
  • Zhen et al. (2022) Q Zhen, W Sun, H Deng, D Li, Y Wei, B Lv, J Yan, L Kong, and Y Zhong. cosformer: rethinking softmax in attention. In International Conference on Learning Representations, 2022.
  • Zhuang et al. (2021) Liu Zhuang, Lin Wayne, Shi Ya, and Zhao Jun. A robustly optimized bert pre-training approach with post-training. In Proceedings of the 20th chinese national conference on computational linguistics, pp.  1218–1227, 2021.

Appendix A Appendix

A.1 Theoretical analysis

A.1.1 Proofs for theorems in section 4.1

In this section we give the proof of theorem 4.1.

Proof of theorem 4.1.

We will start by proving the first inequality in theorem 4.1. Given a matrix $A = (a_{ij}) \in \mathbb{R}^{N \times N}$, we have that

\[
\mathbf{softmax}\left(\begin{bmatrix} a_{11} & \cdots & a_{1N} \\ \vdots & & \vdots \\ a_{N1} & \cdots & a_{NN} \end{bmatrix}\right) = \begin{bmatrix} \frac{e^{a_{11}}}{\sum_{j=1}^{N} e^{a_{1j}}} & \cdots & \frac{e^{a_{1N}}}{\sum_{j=1}^{N} e^{a_{1j}}} \\ \vdots & & \vdots \\ \frac{e^{a_{N1}}}{\sum_{j=1}^{N} e^{a_{Nj}}} & \cdots & \frac{e^{a_{NN}}}{\sum_{j=1}^{N} e^{a_{Nj}}} \end{bmatrix}. \tag{A.1}
\]

By definition of the Frobenius norm we then see that

\begin{align*}
\|\mathbf{softmax}(A)\|_F^2 &= \left(\frac{1}{\sum_{j=1}^{N} e^{a_{1j}}}\right)^{2}\left(e^{2a_{11}} + \cdots + e^{2a_{1N}}\right) + \cdots + \left(\frac{1}{\sum_{j=1}^{N} e^{a_{Nj}}}\right)^{2}\left(e^{2a_{N1}} + \cdots + e^{2a_{NN}}\right) \tag{A.2}\\
&\leq \left[\left(\frac{1}{\sum_{j=1}^{N} e^{a_{1j}}}\right)\left(e^{a_{11}} + \cdots + e^{a_{1N}}\right)\right]^{2} + \cdots + \left[\left(\frac{1}{\sum_{j=1}^{N} e^{a_{Nj}}}\right)\left(e^{a_{N1}} + \cdots + e^{a_{NN}}\right)\right]^{2} \tag{A.3}\\
&= 1 + \cdots + 1 \tag{A.4}\\
&= N \tag{A.5}
\end{align*}

where the inequality in (A.3) uses the fact that for non-negative numbers $a$ and $b$ we always have $a^2 + b^2 \leq (a + b)^2$, and hence, by induction, $x_1^2 + \cdots + x_N^2 \leq (x_1 + \cdots + x_N)^2$ for non-negative $x_1, \dots, x_N$; the equality (A.4) holds because each row of $\mathbf{softmax}(A)$ sums to 1.

It then immediately follows that $\|\mathbf{softmax}(A)\|_F \leq \sqrt{N}$, and this proves the first inequality in the statement of theorem 4.1.
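A quick numerical sanity check of this bound, assuming nothing beyond the row-wise softmax of (A.1): the ratio $\|\mathbf{softmax}(A)\|_F/\sqrt{N}$ should never exceed 1. The bound is attained when every row is one-hot, while uniform rows give $\|\mathbf{softmax}(A)\|_F = 1$, so sharper attention pushes the norm towards $\sqrt{N}$. The matrix size and the spread values are arbitrary choices for illustration.

```python
import numpy as np

def row_softmax(a):
    """Row-wise softmax of a square matrix, as in (A.1)."""
    e = np.exp(a - a.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
N = 196
ratios = []
for spread in (0.1, 1.0, 10.0, 100.0):  # larger spreads push rows towards one-hot
    A = spread * rng.standard_normal((N, N))
    ratios.append(np.linalg.norm(row_softmax(A)) / np.sqrt(N))  # Frobenius norm / sqrt(N)
print(ratios)

uniform = row_softmax(np.zeros((N, N)))  # every entry equals 1/N, so the norm is 1
print(np.linalg.norm(uniform))
```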

We move on to prove the second inequality in the statement of theorem 4.1. For this, let us write each entry of the matrix on the right of equation A.1 as follows:

\[
F_{kl} = \frac{e^{a_{kl}}}{\sum_{j=1}^{N} e^{a_{kj}}}. \tag{A.6}
\]

By applying the chain rule we then have the following derivative formulas

\begin{align*}
\frac{\partial F_{ij}}{\partial a_{ij}} &= F_{ij} - F_{ij}^{2} \tag{A.7}\\
\frac{\partial F_{ij}}{\partial a_{ik}} &= -F_{ij}F_{ik} \quad \text{for any } k \neq j \tag{A.8}\\
\frac{\partial F_{ij}}{\partial a_{kl}} &= 0 \quad \text{for any } k \neq i \text{ and any } l. \tag{A.9}
\end{align*}

We can then express the gradient as

\[
\nabla \mathbf{softmax}(A) = \begin{bmatrix} \nabla F_{11} \\ \vdots \\ \nabla F_{1N} \\ \nabla F_{21} \\ \vdots \\ \nabla F_{NN} \end{bmatrix} \tag{A.10}
\]

where, listing only the partial derivatives with respect to the $i$-th row $x_{i1},\dots,x_{iN}$ (all others vanish by (A.9)),

$$
\nabla F_{ij}=\begin{bmatrix}-F_{ij}F_{i1} & -F_{ij}F_{i2} & \cdots & F_{ij}-F_{ij}^{2} & \cdots & -F_{ij}F_{iN}\end{bmatrix}, \qquad\text{(A.11)}
$$

with the diagonal entry $F_{ij}-F_{ij}^{2}$ in position $j$.

From these computations we see that

$$
\|\nabla\mathbf{softmax}(A)\|_F^{2}=\|\nabla F_{11}\|_F^{2}+\cdots+\|\nabla F_{NN}\|_F^{2}. \qquad\text{(A.12)}
$$

We bound each row sum $\|\nabla F_{i1}\|_F^{2}+\cdots+\|\nabla F_{iN}\|_F^{2}$ separately and then add up the bounds. For the $i$-th row we have

$$
\begin{aligned}
\|\nabla F_{i1}\|_F^{2}+\cdots+\|\nabla F_{iN}\|_F^{2}
&= |F_{i1}-F_{i1}^{2}|^{2}+|F_{i1}F_{i2}|^{2}+\cdots+|F_{i1}F_{iN}|^{2} &&\text{(A.13)}\\
&\quad+|F_{i2}F_{i1}|^{2}+|F_{i2}-F_{i2}^{2}|^{2}+\cdots+|F_{i2}F_{iN}|^{2} &&\text{(A.14)}\\
&\quad+\cdots &&\text{(A.15)}\\
&\quad+|F_{iN}F_{i1}|^{2}+|F_{iN}F_{i2}|^{2}+\cdots+|F_{iN}-F_{iN}^{2}|^{2} &&\text{(A.16)}\\
&\le (F_{i1})^{2}\big(|1-F_{i1}|+|F_{i2}|+\cdots+|F_{iN}|\big)^{2} &&\text{(A.17)}\\
&\quad+(F_{i2})^{2}\big(|F_{i1}|+|1-F_{i2}|+\cdots+|F_{iN}|\big)^{2} &&\text{(A.18)}\\
&\quad+\cdots &&\text{(A.19)}\\
&\quad+(F_{iN})^{2}\big(|F_{i1}|+|F_{i2}|+\cdots+|1-F_{iN}|\big)^{2}. &&\text{(A.20)}
\end{aligned}
$$

The inequality in (A.17) through (A.20) follows by factoring $(F_{ij})^{2}$ out of the $j$-th line and using $\sum_k g_k^{2}\le\big(\sum_k|g_k|\big)^{2}$. Next, since every $F_{ik}\ge 0$ and $F_{i1}+\cdots+F_{iN}=1$, we have $1-F_{ij}=F_{i1}+\cdots+\widehat{F_{ij}}+\cdots+F_{iN}$, where $\widehat{F_{ij}}$ means that $F_{ij}$ is omitted from the sum. Hence the bracket multiplying $(F_{ij})^{2}$ equals $2(1-F_{ij})$. Because $0\le 1-F_{ij}\le 1$, its square satisfies $4(1-F_{ij})^{2}\le 4(1-F_{ij})$. This gives the bound

$$
\begin{aligned}
\|\nabla F_{i1}\|_F^{2}+\cdots+\|\nabla F_{iN}\|_F^{2}
&\le 4F_{i1}^{2}\big(\widehat{F_{i1}}+F_{i2}+\cdots+F_{iN}\big) &&\text{(A.21)}\\
&\quad+\cdots &&\text{(A.22)}\\
&\quad+4F_{iN}^{2}\big(F_{i1}+F_{i2}+\cdots+\widehat{F_{iN}}\big) &&\text{(A.23)}\\
&\le 4\big(F_{i1}^{2}+\cdots+F_{iN}^{2}\big) &&\text{(A.24)}\\
&\le 4\big(F_{i1}+\cdots+F_{iN}\big)=4. &&\text{(A.25)}
\end{aligned}
$$
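The chain of inequalities from (A.13) through (A.25) can be checked numerically for a single softmax row. The sketch below is our own illustration, and `row_gradient_sq_sum` and `bound_chain` are hypothetical helpers. It evaluates the exact row sum and each successive upper bound on random rows, ranging from near-uniform to near one-hot. It then asserts that no step of the proof decreases the value.

```python
import numpy as np

def row_gradient_sq_sum(f):
    # Exact ||grad F_i1||^2 + ... + ||grad F_iN||^2 for one softmax row f, as in (A.13)-(A.16)
    total = 0.0
    for j in range(len(f)):
        g = -f[j] * f              # off-diagonal entries -F_ij F_ik, see (A.11)
        g[j] = f[j] - f[j] ** 2    # diagonal entry F_ij - F_ij^2
        total += np.sum(g ** 2)
    return total

def bound_chain(f):
    # Right-hand sides of (A.17)-(A.20), (A.21)-(A.23), (A.24) and (A.25), in order
    brackets = np.abs(1 - f) + (f.sum() - f)   # |1 - F_ij| + sum over k != j of F_ik
    return [
        np.sum(f ** 2 * brackets ** 2),
        np.sum(4 * f ** 2 * (1 - f)),
        4 * np.sum(f ** 2),
        4.0,
    ]

rng = np.random.default_rng(1)
chains = []
for scale in (0.1, 1.0, 10.0):   # small scale: near-uniform row; large scale: near one-hot row
    logits = scale * rng.normal(size=16)
    f = np.exp(logits - logits.max())
    f /= f.sum()
    chains.append([row_gradient_sq_sum(f)] + bound_chain(f))

for values in chains:
    # each step of the proof may only enlarge the quantity (small tolerance for rounding)
    assert all(a <= b + 1e-12 for a, b in zip(values, values[1:]))
    print(["%.4f" % v for v in values])
```

In such runs the exact row sum sits far below 4. A uniform row, for example, gives only $(N-1)/N^{2}$, so the bound is comfortable but loose.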

Summing these bounds over the $N$ rows $i=1,\dots,N$ gives

$$
\|\nabla\mathbf{softmax}(A)\|_F^{2}\le 4N \qquad\text{(A.26)}
$$

and this implies

$$
\|\nabla\mathbf{softmax}(A)\|_F\le 2\sqrt{N}. \qquad\text{(A.27)}
$$

This finishes the proof of Theorem 4.1.
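The final bound (A.27) can be checked on concrete matrices. The sketch below is our own and uses illustrative sizes, logit scales and helper names. It assembles the full $N^{2}\times N^{2}$ Jacobian of the row-wise softmax from (A.7) through (A.9). That Jacobian is block diagonal, with blocks $\operatorname{diag}(f_i)-f_if_i^{T}$ for the rows $f_i$ of $F$. The sketch then compares its Frobenius norm with $2\sqrt{N}$.

```python
import numpy as np

def row_softmax(A):
    # Assumption: softmax is applied to each row of A separately
    Z = np.exp(A - A.max(axis=1, keepdims=True))
    return Z / Z.sum(axis=1, keepdims=True)

def softmax_jacobian(A):
    # Jacobian of vec(softmax(A)) w.r.t. vec(A), using row-major vectorization.
    # Row i of F depends only on row i of A (A.9), so the Jacobian is block diagonal;
    # the block entries diag(f) - f f^T are exactly (A.7) and (A.8).
    F = row_softmax(A)
    N = A.shape[0]
    J = np.zeros((N * N, N * N))
    for i in range(N):
        f = F[i]
        J[i * N:(i + 1) * N, i * N:(i + 1) * N] = np.diag(f) - np.outer(f, f)
    return J

rng = np.random.default_rng(2)
worst_ratio = 0.0
for N in (2, 8, 32):
    for scale in (0.01, 1.0, 100.0):   # near-uniform rows up to near one-hot rows
        A = scale * rng.normal(size=(N, N))
        norm = np.linalg.norm(softmax_jacobian(A))   # Frobenius norm of a 2-D array
        worst_ratio = max(worst_ratio, norm / (2 * np.sqrt(N)))
print(worst_ratio)  # (A.27) guarantees this ratio never exceeds 1
```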

A.1.2 Proofs for theorems in Section 4.2

In this section we give the proofs of Theorems 4.2 and 4.4.

Proof of Theorem 4.2.

We split the matrix product $XQK^{T}X^{T}$ and treat it as a product of two matrices. Suppose $\mathbf{A}\in\mathbb{R}^{N\times D}$ has i.i.d. entries drawn from $\mathcal{N}(0,\sigma_1^{2})$. Suppose also that $\mathbf{B}\in\mathbb{R}^{D\times N}$, independent of $\mathbf{A}$, has i.i.d. entries drawn from $\mathcal{N}(0,\sigma_2^{2})$, and let $\mathbf{C}=\mathbf{A}\mathbf{B}$. Each entry of $\mathbf{C}$ is the inner product of a row of $\mathbf{A}$ with a column of $\mathbf{B}$. Since $\|\mathbf{C}\|_F^{2}=\sum_{i,j}c_{ij}^{2}$ and expectation is linear, it suffices to compute $\mathbb{E}(c_{ij}^{2})$ for each entry. We treat the entry $c_{11}$ in the first row and first column. For $p=1$ we compute

$$
\begin{aligned}
\mathbb{E}(c_{11}^{2}) &= \mathbb{E}\Big(\big(\sum_{i=1}^{D}a_{1i}b_{i1}\big)^{2}\Big) &&\text{(A.28)}\\
&= \mathbb{E}\Big(\sum_{i=1}^{D}a_{1i}^{2}b_{i1}^{2}+\sum_{i=1}^{D}\sum_{j=1,\,j\neq i}^{D}a_{1i}b_{i1}a_{1j}b_{j1}\Big)\\
&= \sum_{i=1}^{D}\mathbb{E}(a_{1i}^{2})\,\mathbb{E}(b_{i1}^{2})+\sum_{i=1}^{D}\sum_{j=1,\,j\neq i}^{D}\mathbb{E}(a_{1i})\,\mathbb{E}(b_{i1})\,\mathbb{E}(a_{1j})\,\mathbb{E}(b_{j1})\\
&= D\sigma_1^{2}\sigma_2^{2}+0.
\end{aligned}
$$

Every entry of $\mathbf{C}$ has the same distribution as $c_{11}$. Summing over all $N^{2}$ entries therefore gives $\mathbb{E}(\|\mathbf{C}\|_F^{2})=N^{2}D\sigma_1^{2}\sigma_2^{2}$, which proves the $p=1$ case.
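The following Monte Carlo sketch of the $p=1$ result is our own. Its sizes $N=6$ and $D=16$ and its standard deviations $\sigma_1=0.5$ and $\sigma_2=2$ are illustrative choices. It samples many independent pairs $\mathbf{A},\mathbf{B}$ with i.i.d. Gaussian entries. It then compares the empirical means of $c_{11}^{2}$ and $\|\mathbf{C}\|_F^{2}$ with $D\sigma_1^{2}\sigma_2^{2}$ and $N^{2}D\sigma_1^{2}\sigma_2^{2}$.

```python
import numpy as np

rng = np.random.default_rng(3)
N, D = 6, 16
sigma1, sigma2 = 0.5, 2.0
trials = 20000

# One independent pair (A, B) per trial, entries i.i.d. N(0, sigma^2)
A = rng.normal(0.0, sigma1, size=(trials, N, D))
B = rng.normal(0.0, sigma2, size=(trials, D, N))
C = A @ B   # batched matrix products, one C = AB per trial

mc_entry = np.mean(C[:, 0, 0] ** 2)              # estimate of E(c_11^2)
mc_frob = np.mean(np.sum(C ** 2, axis=(1, 2)))   # estimate of E(||C||_F^2)

# Monte Carlo estimate next to the theoretical value
print(mc_entry, D * sigma1**2 * sigma2**2)
print(mc_frob, N**2 * D * sigma1**2 * sigma2**2)
```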

For $p>1$ we proceed in the same way, now with $\mathbf{C}=(\mathbf{A}\mathbf{B})^{p}$, where the $p$-th power is taken entrywise, so that $c_{11}=\big(\sum_{i=1}^{D}a_{1i}b_{i1}\big)^{p}$. The key observation concerns the expansion of $c_{11}^{2}$ into monomials. Any monomial in which some product $a_{1i}b_{i1}$ appears to an odd power has expectation $0$, because the entries are independent zero-mean Gaussians. Only the monomials in which every factor appears to an even power contribute. For the first entry $c_{11}$ of $\mathbf{C}$ we compute:

$$
\begin{aligned}
\mathbb{E}(c_{11}^{2}) &= \mathbb{E}\Big(\big(\sum_{i=1}^{D}a_{1i}b_{i1}\big)^{2p}\Big) &&\text{(A.29)}\\
&= \mathbb{E}\Big(\sum_{i=1}^{D}a_{1i}^{2p}b_{i1}^{2p}+\sum_{i=1}^{D}\sum_{j=1,\,j\neq i}^{D}a_{1i}^{2p-2}b_{i1}^{2p-2}a_{1j}^{2}b_{j1}^{2}+\cdots\Big).
\end{aligned}
$$

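The expansion in (A.29) is written schematically, with multinomial coefficients omitted. Its first sum has only $D$ terms, while its second has $D(D-1)$ terms. More generally, the terms involving $k$ distinct indices number on the order of $D^{k}$. Because every index must appear to an even power, $k\le p$, with equality exactly when $p$ distinct indices each appear squared. For large $D$ we therefore keep only this $\mathcal{O}(D^{p})$ group of terms. Counting them amounts to choosing $p$ indices out of $D$, and each such monomial carries the multinomial coefficient $\frac{(2p)!}{(2!)^{p}}=\frac{(2p)!}{2^{p}}$: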

$$
\begin{aligned}
\mathbb{E}(c_{11}^{2}) &\approx \frac{(2p)!}{2^{p}}\,\mathbb{E}\Big(\sum_{\{i_1,\dots,i_p\}\subseteq\{1,\dots,D\}}\ \prod_{k=1}^{p}a_{1,i_k}^{2}\,b_{i_k,1}^{2}\Big) &&\text{(A.30)}\\
&= \binom{D}{p}\frac{(2p)!}{2^{p}}\,\sigma_1^{2p}\sigma_2^{2p}\\
&= \frac{D!}{(D-p)!}\,\frac{(2p)!}{p!\,2^{p}}\,\sigma_1^{2p}\sigma_2^{2p}\\
&= \frac{D!}{(D-p)!}\,\frac{(2p)!}{(2p)!!}\,\sigma_1^{2p}\sigma_2^{2p}\\
&= \frac{D!}{(D-p)!}\,(2p-1)!!\,\sigma_1^{2p}\sigma_2^{2p}.
\end{aligned}
$$
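The case $p=2$ shows what the approximation in (A.30) discards. This is our own worked instance, under the same assumptions on $\mathbf{A}$ and $\mathbf{B}$. Write $y_i=a_{1i}b_{i1}$, and recall that $\mathbb{E}(a^{4})=3\sigma^{4}$ for a zero-mean Gaussian with variance $\sigma^{2}$. Only the monomials $y_i^{4}$ and $y_i^{2}y_j^{2}$ ($i\neq j$) survive the expectation:

$$
\begin{aligned}
\mathbb{E}(c_{11}^{2}) &= \mathbb{E}\Big(\big(\sum_{i=1}^{D}y_i\big)^{4}\Big)
= \sum_{i=1}^{D}\mathbb{E}(y_i^{4}) + 3\sum_{i=1}^{D}\sum_{j=1,\,j\neq i}^{D}\mathbb{E}(y_i^{2})\,\mathbb{E}(y_j^{2})\\
&= 9D\,\sigma_1^{4}\sigma_2^{4} + 3D(D-1)\,\sigma_1^{4}\sigma_2^{4}
= 3D(D+2)\,\sigma_1^{4}\sigma_2^{4}.
\end{aligned}
$$

The term retained by (A.30) is $\frac{D!}{(D-2)!}\,3!!\,\sigma_1^{4}\sigma_2^{4}=3D(D-1)\,\sigma_1^{4}\sigma_2^{4}$. It agrees with the exact value to leading order $D^{2}$. The omitted $9D\,\sigma_1^{4}\sigma_2^{4}$ comes from the terms $y_i^{4}$.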

Since $\frac{D!}{(D-p)!}=D(D-1)\cdots(D-p+1)\le D^{p}$, the leading term of $\mathbb{E}(c_{11}^{2})$ is bounded above by $D^{p}(2p-1)!!\,\sigma_1^{2p}\sigma_2^{2p}$. Every entry of $\mathbf{C}$ has the same distribution as $c_{11}$. Summing over the $N^{2}$ entries therefore shows that, for fixed $D$ and $p$, $\mathbb{E}(\|\mathbf{C}\|_F^{2})$ grows as $\mathcal{O}(N^{2})$. Equivalently, the root-mean-square Frobenius norm $\sqrt{\mathbb{E}(\|\mathbf{C}\|_F^{2})}$ grows as $\mathcal{O}(N)$.
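The counting behind (A.30) can also be checked exactly for small cases. The sketch below is our own. It assumes standard Gaussians ($\sigma_1=\sigma_2=1$) and uses hypothetical helper names. It computes $\mathbb{E}\big((\sum_{i=1}^{D}a_{1i}b_{i1})^{2p}\big)$ exactly from the moments $\mathbb{E}(y^{k})$ of $y=ab$, using the binomial convolution rule for sums of independent variables. It then compares the result with the retained term $\frac{D!}{(D-p)!}(2p-1)!!$.

```python
from math import comb, factorial, perm

def double_factorial(n):
    # n!! with the conventions (-1)!! = 0!! = 1
    out = 1
    while n > 1:
        out *= n
        n -= 2
    return out

def moments_of_product(max_order):
    # y = a * b with a, b independent standard normals, so E[y^k] = (E[a^k])^2,
    # where E[a^k] = (k - 1)!! for even k and 0 for odd k
    return [double_factorial(k - 1) ** 2 if k % 2 == 0 else 0 for k in range(max_order + 1)]

def exact_moment_of_sum(D, order):
    # E[(y_1 + ... + y_D)^order] for i.i.d. y_i, built up one summand at a time via
    # E[(S + y)^n] = sum_k C(n, k) E[S^k] E[y^(n - k)]   (S and y independent)
    m_y = moments_of_product(order)
    m_s = [1] + [0] * order   # moments of the empty sum S = 0
    for _ in range(D):
        m_s = [sum(comb(n, k) * m_s[k] * m_y[n - k] for k in range(n + 1))
               for n in range(order + 1)]
    return m_s[order]

def leading_term(D, p):
    # The term kept in (A.30): D! / (D - p)! * (2p - 1)!!
    return perm(D, p) * double_factorial(2 * p - 1)

for p in (1, 2, 3):
    for D in (4, 64, 1000):
        exact = exact_moment_of_sum(D, 2 * p)
        print(p, D, exact, leading_term(D, p), leading_term(D, p) / exact)
```

The exact value is never smaller than the retained term, because every discarded monomial has nonnegative expectation. For fixed $p$, the ratio approaches 1 as $D$ grows.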

Proof of Theorem 4.4.

We first treat the case $p=1$, proceeding as in the proof of Theorem 4.2.
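Two ingredients of the derivation that follows can be probed numerically: the derivative $\partial(x_i^{T}QK^{T}x_j)/\partial Q=x_ix_j^{T}K$ used in (A.32), and the expectation in (A.31). The sketch below is our own and does three things. First, it checks the derivative by finite differences. Second, it confirms that the double sum in (A.32) factorizes as $\|X\|_F^{2}\,\|XK\|_F^{2}$, since each $x_i(K^{T}x_j)^{T}$ has rank one. Third, it compares a Monte Carlo estimate of the expectation with $dND(ND+2)\,\sigma_x^{4}\sigma_k^{2}$. That closed form is our own evaluation of (A.38). It assumes that $X$ and $K$ have independent $\mathcal{N}(0,\sigma_x^{2})$ and $\mathcal{N}(0,\sigma_k^{2})$ entries, and the sizes and variances in the code are illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)
N, D, d = 5, 3, 2            # illustrative sizes: N tokens, width D, key/query width d
X = rng.normal(size=(N, D))  # row i is the token x_i^T
K = rng.normal(size=(D, d))
Q = rng.normal(size=(D, d))

def score(Q, i, j):
    # (i, j) entry of X Q K^T X^T, i.e. x_i^T Q K^T x_j
    return X[i] @ Q @ K.T @ X[j]

# (A.32): the gradient of x_i^T Q K^T x_j with respect to Q is the D x d matrix x_i x_j^T K
i, j, h = 1, 3, 1e-6
grad_fd = np.zeros_like(Q)
for a in range(D):
    for b in range(d):
        E = np.zeros_like(Q)
        E[a, b] = h
        grad_fd[a, b] = (score(Q + E, i, j) - score(Q - E, i, j)) / (2 * h)
fd_err = np.max(np.abs(grad_fd - np.outer(X[i], X[j]) @ K))

# x_i (K^T x_j)^T is rank one, so ||x_i x_j^T K||_F^2 = ||x_i||^2 ||K^T x_j||^2 and the
# double sum over (i, j) factorizes into ||X||_F^2 * ||X K||_F^2
direct = sum(np.linalg.norm(np.outer(X[a], X[b]) @ K) ** 2 for a in range(N) for b in range(N))
factored = np.linalg.norm(X) ** 2 * np.linalg.norm(X @ K) ** 2

# Monte Carlo estimate of the expectation in (A.31).
# Assumption (ours): X ~ i.i.d. N(0, sx^2), K ~ i.i.d. N(0, sk^2), independent of each other.
sx, sk, trials = 0.7, 1.3, 100_000
Xs = rng.normal(0.0, sx, size=(trials, N, D))
Ks = rng.normal(0.0, sk, size=(trials, D, d))
mc = np.mean(np.sum(Xs ** 2, axis=(1, 2)) * np.sum((Xs @ Ks) ** 2, axis=(1, 2)))
closed = d * N * D * (N * D + 2) * sx**4 * sk**2   # our own evaluation of (A.38)
print(fd_err, abs(direct - factored), mc / closed)
```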

$$
\begin{aligned}
\mathbb{E}\Big(\Big\|\frac{\partial\, XQK^{T}X^{T}}{\partial Q}\Big\|_F^{2}\Big)
&= \sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\Big(\Big\|\frac{\partial\, x_i^{T}QK^{T}x_j}{\partial Q}\Big\|_F^{2}\Big) &&\text{(A.31)}\\
&= \sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\big(\|x_i x_j^{T}K\|_F^{2}\big) &&\text{(A.32)}\\
&= \sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\Big(\sum_{k=1}^{D}\sum_{l=1}^{d}\big(x_{ik}\sum_{m=1}^{D}x_{jm}k_{ml}\big)^{2}\Big) &&\text{(A.33)}\\
&= \sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\Big(\sum_{k=1}^{D}\sum_{l=1}^{d}x_{ik}^{2}\big(\sum_{m=1}^{D}x_{jm}k_{ml}\big)^{2}\Big) &&\text{(A.34)}\\
&= \sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\Big(\sum_{k=1}^{D}\sum_{l=1}^{d}x_{ik}^{2}\big(\sum_{m=1}^{D}x_{jm}^{2}k_{ml}^{2}+\sum_{m=1}^{D}\sum_{n=1,\,n\neq m}^{D}x_{jm}k_{ml}x_{jn}k_{nl}\big)\Big) &&\text{(A.35)}\\
&= \sum_{i=1}^{N}\sum_{j=1}^{N}\sum_{k=1}^{D}\sum_{l=1}^{d}\Big(\sum_{m=1}^{D}\mathbb{E}(x_{ik}^{2}x_{jm}^{2}k_{ml}^{2})+\sum_{m=1}^{D}\sum_{n=1,\,n\neq m}^{D}\mathbb{E}(x_{ik}^{2}x_{jm}k_{ml}x_{jn}k_{nl})\Big) &&\text{(A.36)}\\
&= \sum_{i=1}^{N}\sum_{j=1}^{N}\sum_{k=1}^{D}\sum_{l=1}^{d}\Big(\sum_{m=1}^{D}\mathbb{E}(x_{ik}^{2}x_{jm}^{2}k_{ml}^{2})+0\Big) &&\text{(A.37)}\\
&= \sum_{i=1}^{N}\sum_{j=1}^{N}\sum_{k=1}^{D}\sum_{l=1}^{d}\sum_{m=1}^{D}\mathbb{E}(x_{ik}^{2}x_{jm}^{2}k_{ml}^{2}) &&\text{(A.38)}
\end{aligned}
$$
=i=1Nk=1Dl=1d𝔼(xik2xik2kkl2)+i=1Nj=1,jiNk=1Dl=1dm=1,mkD𝔼(xik2xjm2kml2)absentsuperscriptsubscript𝑖1𝑁superscriptsubscript𝑘1𝐷superscriptsubscript𝑙1𝑑𝔼superscriptsubscript𝑥𝑖𝑘2superscriptsubscript𝑥𝑖𝑘2superscriptsubscript𝑘𝑘𝑙2superscriptsubscript𝑖1𝑁superscriptsubscriptformulae-sequence𝑗1𝑗𝑖𝑁superscriptsubscript𝑘1𝐷superscriptsubscript𝑙1𝑑superscriptsubscriptformulae-sequence𝑚1𝑚𝑘𝐷𝔼superscriptsubscript𝑥𝑖𝑘2superscriptsubscript𝑥𝑗𝑚2superscriptsubscript𝑘𝑚𝑙2\displaystyle=\sum_{i=1}^{N}\sum_{k=1}^{D}\sum_{l=1}^{d}\mathop{\mathbb{E}}(x_% {ik}^{2}x_{ik}^{2}k_{kl}^{2})+\sum_{i=1}^{N}\sum_{j=1,j\neq i}^{N}\sum_{k=1}^{% D}\sum_{l=1}^{d}\sum_{m=1,m\neq k}^{D}\mathop{\mathbb{E}}(x_{ik}^{2}x_{jm}^{2}% k_{ml}^{2})= ∑ start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT ∑ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT ∑ start_POSTSUBSCRIPT italic_l = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT blackboard_E ( italic_x start_POSTSUBSCRIPT italic_i italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT italic_x start_POSTSUBSCRIPT italic_i italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT italic_k start_POSTSUBSCRIPT italic_k italic_l end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) + ∑ start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT ∑ start_POSTSUBSCRIPT italic_j = 1 , italic_j ≠ italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT ∑ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT ∑ start_POSTSUBSCRIPT italic_l = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT ∑ start_POSTSUBSCRIPT italic_m = 1 , italic_m ≠ italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT blackboard_E ( italic_x start_POSTSUBSCRIPT italic_i italic_k end_POSTSUBSCRIPT 
start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT italic_x start_POSTSUBSCRIPT italic_j italic_m end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT italic_k start_POSTSUBSCRIPT italic_m italic_l end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) (A.39)
=NDd3σx4σw2+N(N1)D(D1)dσx4σw2absent𝑁𝐷𝑑3superscriptsubscript𝜎𝑥4superscriptsubscript𝜎𝑤2𝑁𝑁1𝐷𝐷1𝑑superscriptsubscript𝜎𝑥4superscriptsubscript𝜎𝑤2\displaystyle=NDd3\sigma_{x}^{4}\sigma_{w}^{2}+N(N-1)D(D-1)d\sigma_{x}^{4}% \sigma_{w}^{2}= italic_N italic_D italic_d 3 italic_σ start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 4 end_POSTSUPERSCRIPT italic_σ start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT + italic_N ( italic_N - 1 ) italic_D ( italic_D - 1 ) italic_d italic_σ start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 4 end_POSTSUPERSCRIPT italic_σ start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT (A.40)
N2D2dσx4σw2.absentsuperscript𝑁2superscript𝐷2𝑑superscriptsubscript𝜎𝑥4superscriptsubscript𝜎𝑤2\displaystyle\approx N^{2}D^{2}d\sigma_{x}^{4}\sigma_{w}^{2}.≈ italic_N start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT italic_D start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT italic_d italic_σ start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 4 end_POSTSUPERSCRIPT italic_σ start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT . (A.41)
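The $p=1$ computation can be checked numerically. The sketch below is ours, not the authors' code, and assumes (as the $\sigma_x$, $\sigma_w$ notation suggests) that the entries of $X$ and $K$ are i.i.d. $\mathcal{N}(0,\sigma_x^2)$ and $\mathcal{N}(0,\sigma_w^2)$. For $p=1$ the Jacobian does not depend on $Q$, so only $X$ and $K$ are sampled. It uses the fact that $x_ix_j^TK$ is the outer product of $x_i$ and $K^Tx_j$, so its squared Frobenius norm is $\|x_i\|^2\|K^Tx_j\|^2$. Under these assumptions the exact expectation is $d\sigma_w^2\sigma_x^4(3ND+N^2D^2-ND)=d\sigma_w^2\sigma_x^4(N^2D^2+2ND)$, whose leading term is (A.41).

```python
import random


def jacobian_sq_norm_linear(X, K):
    """Sum over (i, j) of ||d(x_i^T Q K^T x_j)/dQ||_F^2 = ||x_i x_j^T K||_F^2.

    X is N x D and K is D x d, both as lists of rows. Each term is the squared
    norm of the outer product x_i (K^T x_j)^T, i.e. ||x_i||^2 * ||K^T x_j||^2,
    so the double sum factors into two single sums.
    """
    D, d = len(K), len(K[0])
    x_sq = [sum(v * v for v in row) for row in X]  # ||x_i||^2
    kx_sq = []  # ||K^T x_j||^2
    for row in X:
        kx = [sum(row[m] * K[m][l] for m in range(D)) for l in range(d)]
        kx_sq.append(sum(v * v for v in kx))
    return sum(a * b for a in x_sq for b in kx_sq)


def exact_expectation(N, D, d, sigma_x, sigma_w):
    """ND index pairs with (j, m) = (i, k) contribute E[x^4] = 3 sigma_x^4;
    the other N^2 D^2 - ND pairs contribute sigma_x^4. Each term also carries
    E[k_ml^2] = sigma_w^2 and is summed over l = 1..d."""
    return d * sigma_w**2 * sigma_x**4 * (3 * N * D + (N * N * D * D - N * D))


def leading_order_p1(N, D, d, sigma_x, sigma_w):
    """The approximation N^2 D^2 d sigma_x^4 sigma_w^2 from (A.41)."""
    return N**2 * D**2 * d * sigma_x**4 * sigma_w**2


def monte_carlo(N, D, d, sigma_x, sigma_w, trials=20000, seed=0):
    """Average jacobian_sq_norm_linear over i.i.d. Gaussian draws of X and K."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        X = [[rng.gauss(0.0, sigma_x) for _ in range(D)] for _ in range(N)]
        K = [[rng.gauss(0.0, sigma_w) for _ in range(d)] for _ in range(D)]
        total += jacobian_sq_norm_linear(X, K)
    return total / trials
```

For example, with $N=8$, $D=4$, $d=3$ and unit variances, `exact_expectation` gives 3264.0 and `leading_order_p1` gives 3072.0; the ratio is $1+2/(ND)$, which tends to 1 as $ND$ grows, and `monte_carlo` should land within a few percent of the exact value.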

When $p>1$, with the power $(\cdot)^{p}$ applied entrywise, we proceed in the same way.

$$\mathbb{E}\Big(\Big\|\frac{\partial(XQK^{T}X^{T})^{p}}{\partial Q}\Big\|_F^{2}\Big)=\sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\Big(\Big\|\frac{\partial(x_i^{T}QK^{T}x_j)^{p}}{\partial Q}\Big\|_F^{2}\Big) \tag{A.42}$$
$$=\sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\Big(\Big\|p(x_i^{T}QK^{T}x_j)^{p-1}\frac{\partial\, x_i^{T}QK^{T}x_j}{\partial Q}\Big\|_F^{2}\Big) \tag{A.43}$$
$$=\sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\big(\|p(x_i^{T}QK^{T}x_j)^{p-1}x_ix_j^{T}K\|_F^{2}\big) \tag{A.44}$$
$$=\sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\Big(p^{2}(x_i^{T}QK^{T}x_j)^{2p-2}\sum_{k=1}^{D}\sum_{l=1}^{d}\Big(x_{ik}\sum_{m=1}^{D}x_{jm}k_{ml}\Big)^{2}\Big). \tag{A.45}$$
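Steps (A.43) and (A.44) apply the chain rule to the scalar $s_{ij}=x_i^TQK^Tx_j$. The minimal sketch below, which is not from the paper, compares the analytic gradient $p\,s_{ij}^{p-1}x_i(K^Tx_j)^T$ with central finite differences for a single pair $(i,j)$; the function names and test values are our own.

```python
def score(xi, xj, Q, K):
    """s_ij = x_i^T Q K^T x_j, with Q and K stored as D x d lists of rows."""
    D, d = len(Q), len(Q[0])
    qx = [sum(xi[k] * Q[k][l] for k in range(D)) for l in range(d)]  # Q^T x_i
    kx = [sum(xj[m] * K[m][l] for m in range(D)) for l in range(d)]  # K^T x_j
    return sum(a * b for a, b in zip(qx, kx))


def analytic_grad(xi, xj, Q, K, p):
    """d(s_ij^p)/dQ = p s_ij^(p-1) x_i (K^T x_j)^T, a D x d matrix."""
    D, d = len(Q), len(Q[0])
    s = score(xi, xj, Q, K)
    kx = [sum(xj[m] * K[m][l] for m in range(D)) for l in range(d)]
    return [[p * s ** (p - 1) * xi[k] * kx[l] for l in range(d)] for k in range(D)]


def numeric_grad(xi, xj, Q, K, p, h=1e-5):
    """Central finite differences of s_ij^p with respect to each entry q_kl."""
    D, d = len(Q), len(Q[0])
    grad = [[0.0] * d for _ in range(D)]
    for k in range(D):
        for l in range(d):
            Q_plus = [row[:] for row in Q]
            Q_minus = [row[:] for row in Q]
            Q_plus[k][l] += h
            Q_minus[k][l] -= h
            f_plus = score(xi, xj, Q_plus, K) ** p
            f_minus = score(xi, xj, Q_minus, K) ** p
            grad[k][l] = (f_plus - f_minus) / (2 * h)
    return grad
```

Agreement of the two gradients for several values of $p$ confirms that the derivative of the entrywise power only contributes the scalar factor $p\,s_{ij}^{p-1}$, which is why (A.45) carries $p^2 s_{ij}^{2p-2}$ in front of the $p=1$ expression.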

Writing the bilinear form in coordinates,

$$(x_i^{T}QK^{T}x_j)^{2p-2}=\Big(\sum_{l=1}^{d}\Big(\sum_{k=1}^{D}x_{ik}q_{kl}\Big)\cdot\Big(\sum_{m=1}^{D}x_{jm}k_{ml}\Big)\Big)^{2p-2} \tag{A.46}$$
$$=\Big(\sum_{l=1}^{d}\sum_{k=1}^{D}\sum_{m=1}^{D}x_{ik}q_{kl}x_{jm}k_{ml}\Big)^{2p-2} \tag{A.47}$$
$$=\Big(\sum_{k=1}^{D}\sum_{m=1}^{D}x_{ik}x_{jm}\sum_{l=1}^{d}q_{kl}k_{ml}\Big)^{2p-2} \tag{A.48}$$
$$=\Big(\sum_{k=1}^{D}\sum_{m=1}^{D}x_{ik}x_{jm}a_{km}\Big)^{2p-2}, \tag{A.49}$$

where $a_{km}=\sum_{l=1}^{d}q_{kl}k_{ml}$ is the $(k,m)$ entry of $QK^{T}$. Let $z_{ij}=\sum_{k=1}^{D}\sum_{m=1}^{D}x_{ik}x_{jm}a_{km}=x_i^{T}QK^{T}x_j$. Then

$$\mathbb{E}\Big(\Big\|\frac{\partial(XQK^{T}X^{T})^{p}}{\partial Q}\Big\|_F^{2}\Big)=p^{2}\sum_{i=1}^{N}\sum_{j=1}^{N}\mathbb{E}\Big(z_{ij}^{2p-2}\sum_{k=1}^{D}\sum_{l=1}^{d}\Big(x_{ik}\sum_{m=1}^{D}x_{jm}k_{ml}\Big)^{2}\Big) \tag{A.50}$$
$$=p^{2}\sum_{i=1}^{N}\sum_{j=1}^{N}\sum_{k=1}^{D}\sum_{l=1}^{d}\Big(\sum_{m=1}^{D}\mathbb{E}(z_{ij}^{2p-2}x_{ik}^{2}x_{jm}^{2}k_{ml}^{2}) \tag{A.51}$$
$$\qquad+\sum_{m=1}^{D}\sum_{n=1,\,n\neq m}^{D}\mathbb{E}(z_{ij}^{2p-2}x_{ik}^{2}x_{jm}k_{ml}x_{jn}k_{nl})\Big) \tag{A.52}$$
$$=p^{2}\sum_{i=1}^{N}\sum_{j=1}^{N}\sum_{k=1}^{D}\sum_{l=1}^{d}\Big(\sum_{m=1}^{D}\mathbb{E}(z_{ij}^{2p-2}x_{ik}^{2}x_{jm}^{2}k_{ml}^{2}) \tag{A.53}$$
$$\qquad+\sum_{m=1}^{D}\sum_{n=1,\,n\neq m}^{D}\mathbb{E}(z_{ij}^{2p-3}x_{ik}^{2}x_{jm}^{2}k_{ml}^{2}x_{jn}^{2}k_{nl}^{2})\Big) \tag{A.54}$$
$$\approx p^{2}N^{2}Dd\Big(D\cdot D^{2p-2}d^{p-1}(2p-3)!!\,\sigma_x^{4p}\sigma_w^{4p-2}+0\Big) \tag{A.55}$$
$$=p^{2}N^{2}D^{2p}d^{p}(2p-3)!!\,\sigma_x^{4p}\sigma_w^{4p-2}, \tag{A.56}$$

showing that the expected squared Frobenius norm of the gradient scales as $N^{2}$, so the gradient itself is of order $\mathcal{O}(N)$ in the sequence length $N$. This completes the proof.
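The approximation in (A.55) and (A.56) rests on two heuristics: the cross terms are dropped, and $z_{ij}$ is treated as approximately Gaussian, so that $\mathbb{E}(z_{ij}^{2p-2})\approx(2p-3)!!\,(\mathbb{E}z_{ij}^{2})^{p-1}$. If $x_i$, $x_j$ ($i\neq j$), $Q$ and $K$ are independent with i.i.d. zero-mean Gaussian entries (our assumption), then $\mathbb{E}(z_{ij}^{2})=D^{2}d\,\sigma_x^{4}\sigma_w^{4}$ exactly, which produces the factor $D^{2p-2}d^{p-1}$ in (A.55). The hypothetical helpers below evaluate the leading-order expression $p^{2}N^{2}D^{2p}d^{p}(2p-3)!!\,\sigma_x^{4p}\sigma_w^{4p-2}$, keeping the $p^{2}$ prefactor from (A.50), and check the second moment of $z_{ij}$ by Monte Carlo. They illustrate the scaling argument and are not the authors' code.

```python
import random


def double_factorial(n):
    """n!! with the convention (-1)!! = 0!! = 1."""
    result = 1
    while n > 1:
        result *= n
        n -= 2
    return result


def leading_order(N, D, d, p, sigma_x=1.0, sigma_w=1.0):
    """p^2 N^2 D^(2p) d^p (2p-3)!! sigma_x^(4p) sigma_w^(4p-2)."""
    return (p**2 * N**2 * D**(2 * p) * d**p * double_factorial(2 * p - 3)
            * sigma_x**(4 * p) * sigma_w**(4 * p - 2))


def z_second_moment_mc(D, d, sigma_x, sigma_w, trials=40000, seed=0):
    """Monte Carlo estimate of E[z_ij^2] for i != j, where
    z_ij = sum_{k,m} x_ik x_jm a_km and a_km = sum_l q_kl k_ml."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        xi = [rng.gauss(0.0, sigma_x) for _ in range(D)]
        xj = [rng.gauss(0.0, sigma_x) for _ in range(D)]
        Q = [[rng.gauss(0.0, sigma_w) for _ in range(d)] for _ in range(D)]
        K = [[rng.gauss(0.0, sigma_w) for _ in range(d)] for _ in range(D)]
        # A = Q K^T, so A[k][m] = a_km
        A = [[sum(Q[k][l] * K[m][l] for l in range(d)) for m in range(D)]
             for k in range(D)]
        z = sum(xi[k] * xj[m] * A[k][m] for k in range(D) for m in range(D))
        total += z * z
    return total / trials
```

For $p=1$, `leading_order` reduces to the $N^{2}D^{2}d\,\sigma_x^{4}\sigma_w^{2}$ of (A.41), and for every $p$ doubling $N$ multiplies it by exactly 4, which is the $N^{2}$ growth behind the $\mathcal{O}(N)$ statement above.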

A.2 Experiments

A.2.1 Ablations

This section presents ablations of the experiments reported in the main paper.

Scale ablations:

The theory developed in Section 4.2 suggests scaling the polynomial activations by $\frac{1}{\sqrt{N}}$ to obtain a complexity bound of $\mathcal{O}(\sqrt{N})$. More generally, we can scale the polynomial activations by any factor of order $\mathcal{O}(\frac{1}{\sqrt{N}})$ and check whether this improves accuracy. Figure 10 reports this ablation for the $x^{3}$ activation on both the ViT-Base and ViT-Small architectures. In general, we found that scales between $\frac{1}{8}$ and $\frac{1}{25}$ performed very well.
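To make the scaled activation concrete, here is a minimal sketch of a single attention head with $\phi(s)=s^{3}/c$. It assumes that $\phi$ replaces softmax elementwise on the raw scores $XQK^{T}X^{T}$, with no row normalization and no extra $1/\sqrt{d}$ temperature, and that the result multiplies the values $XV$; these details, and the name `V` for the value weights, are our assumptions rather than something this appendix states. The denominator `denom` plays the role of the x-axis in Figure 10, and its default $\sqrt{N}$ is the theory's choice.

```python
import math


def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def poly_attention(X, Q, K, V, power=3, denom=None):
    """Return (phi(S), phi(S) X V) with S = X Q (X K)^T and
    phi(s) = s**power / denom applied elementwise (no row normalization).

    denom defaults to sqrt(N), where N is the number of tokens (rows of X).
    """
    N = len(X)
    if denom is None:
        denom = math.sqrt(N)
    scores = matmul(matmul(X, Q), transpose(matmul(X, K)))  # N x N
    attn = [[s ** power / denom for s in row] for row in scores]
    return attn, matmul(attn, matmul(X, V))
```

Sweeping `denom` over values such as 8 to 25 with everything else fixed is the kind of experiment Figure 10 summarizes.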

Figure 10: Top-1 accuracy (%) as the scale of the $x^{3}$ activation varies, for ViT-Base (left) and ViT-Small (right). For readability, the x-axis shows the denominator of the scale, so moving right along the x-axis corresponds to a smaller scale on $x^{3}$.
Activation ablations:

Our theory primarily compared the activations $\frac{1}{\sqrt{N}}x^{3}$ and $\frac{1}{\sqrt{N}}x$ with softmax, where $N$ is the sequence length of the input. In this section we compare our activations with a wider range of alternatives: several activations used in the literature, and the exponential function $\mathrm{Exp}=e^{-x}$, which can be thought of as softmax without the row normalization. Table 5 shows the results.

Activation   ViT-Base   ViT-Small
softmax      79.6       80.2
$x^3/14$     79.6       80.5
$x/14$       76.9       77.8
$x^2/14$     78.7       79.9
ReLU         77.3       77.5
ELU          78.2       78.5
GELU         78.2       78.4
Tanh         77.1       77.3
Exp          77.9       78.0
Table 5: Top-1 accuracy (%) with various activations on the ViT-Base and ViT-Small architectures.
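For reference, the activations in Table 5 can be written as plain functions. This is a sketch under our own assumptions: ELU uses $\alpha=1$, GELU uses the exact erf form, and Exp is written as $e^{-x}$ exactly as stated above. Softmax is the only row-wise operation, so it is kept separate from the elementwise dictionary. The denominator 14 is consistent with $1/\sqrt{N}$ for $N=196$ patch tokens (a $224\times224$ image cut into $16\times16$ patches), although the appendix does not state this explicitly.

```python
import math

# Elementwise activations compared in Table 5 (names are our own labels).
ACTIVATIONS = {
    "x^3/14": lambda x: x ** 3 / 14,
    "x/14": lambda x: x / 14,
    "x^2/14": lambda x: x ** 2 / 14,
    "ReLU": lambda x: max(0.0, x),
    "ELU": lambda x: x if x > 0 else math.expm1(x),  # alpha = 1 assumed
    "GELU": lambda x: 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))),
    "Tanh": math.tanh,
    "Exp": lambda x: math.exp(-x),  # as written in the text
}


def apply_elementwise(name, scores):
    """Apply one activation to every entry of a score matrix (list of rows)."""
    f = ACTIVATIONS[name]
    return [[f(s) for s in row] for row in scores]


def softmax_rows(scores):
    """Row-wise softmax baseline; subtracting the row max avoids overflow."""
    out = []
    for row in scores:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out
```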

A.2.2 Frobenius norm computations

In Section 5.1 we showed plots of the Frobenius norm of the self-attention matrix and of its Jacobian for differently scaled $x^{3}$ and $x$ activations, along with softmax. These plots were for a ViT-Tiny architecture trained on the Tiny-ImageNet dataset. Figure 11 shows the Frobenius norm of the self-attention matrix of ViT-Tiny for every layer during training, averaged over the heads within each layer. Figure 12 shows the corresponding Frobenius norm of the Jacobian of the self-attention matrix during training, likewise averaged over the heads within each layer.

Figures 13 and 14 show similar results for the ViT-Small architecture.
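The per-layer curves in Figures 11 to 14 are averages of a per-head quantity. A sketch of that bookkeeping, assuming the polynomial activation $\phi(s)=c\,s^{p}$ and taking the Jacobian with respect to the query weights $Q$ as in (A.42): by (A.44), $\|\partial\phi(s_{ij})/\partial Q\|_F^{2}=(c\,p\,s_{ij}^{p-1})^{2}\|x_i\|^{2}\|K^{T}x_j\|^{2}$, so both norms can be computed without automatic differentiation. The helper names and the `(Q, K)` head layout are hypothetical, and softmax would need a different Jacobian.

```python
import math


def head_norms(X, Q, K, p=3, c=1.0):
    """Return (||phi(S)||_F, ||d phi(S)/dQ||_F) for one head, where
    phi(s) = c * s**p and S = X Q K^T X^T. X is N x D; Q, K are D x d."""
    D, d = len(Q), len(Q[0])
    XQ = [[sum(x[k] * Q[k][l] for k in range(D)) for l in range(d)] for x in X]
    XK = [[sum(x[m] * K[m][l] for m in range(D)) for l in range(d)] for x in X]
    x_sq = [sum(v * v for v in x) for x in X]        # ||x_i||^2
    kx_sq = [sum(v * v for v in row) for row in XK]  # ||K^T x_j||^2
    attn_sq, jac_sq = 0.0, 0.0
    for i, qi in enumerate(XQ):
        for j, kj in enumerate(XK):
            s = sum(a * b for a, b in zip(qi, kj))  # s_ij
            attn_sq += (c * s ** p) ** 2
            jac_sq += (c * p * s ** (p - 1)) ** 2 * x_sq[i] * kx_sq[j]
    return math.sqrt(attn_sq), math.sqrt(jac_sq)


def layer_average(X, heads, p=3, c=1.0):
    """Average both norms over the heads of one layer; heads = [(Q, K), ...]."""
    norms = [head_norms(X, Q, K, p, c) for Q, K in heads]
    n = len(norms)
    return sum(a for a, _ in norms) / n, sum(b for _, b in norms) / n
```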

Figure 11: Frobenius norm of the self-attention matrix for differently scaled activations on ViT-Tiny during training (zoom in for better viewing).
Figure 12: Frobenius norm of the Jacobian of the self-attention matrix for differently scaled activations on ViT-Tiny during training (zoom in for better viewing).
Figure 13: Frobenius norm of the self-attention matrix for differently scaled activations on ViT-Small during training (zoom in for better viewing).
Figure 14: Frobenius norm of the Jacobian of the self-attention matrix for differently scaled activations on ViT-Small during training (zoom in for better viewing).