A Mean-Field Analysis of Neural Stochastic Gradient Descent-Ascent for Functional Minimax Optimization
Abstract
This paper studies minimax optimization problems defined over infinite-dimensional function classes of overparameterized two-layer neural networks. In particular, we consider the minimax optimization problem stemming from estimating linear functional equations defined by conditional expectations, where the objective functions are quadratic in the functional spaces. We address (i) the convergence of the stochastic gradient descent-ascent algorithm and (ii) the representation learning of the neural networks. We establish convergence under the mean-field regime by considering the continuous-time and infinite-width limit of the optimization dynamics. Under this regime, the stochastic gradient descent-ascent corresponds to a Wasserstein gradient flow over the space of probability measures defined over the space of neural network parameters. We prove that the Wasserstein gradient flow converges globally to a stationary point of the minimax objective at a sublinear rate of $\mathcal{O}(T^{-1} + \alpha^{-1})$, and additionally finds the solution to the functional equation when the regularizer of the minimax objective is strongly convex. Here $T$ denotes the time and $\alpha$ is a scaling parameter of the neural networks. In terms of representation learning, our results show that the feature representation induced by the neural networks is allowed to deviate from the initial one by the magnitude of $\mathcal{O}(\alpha^{-1})$, measured in terms of the Wasserstein distance. Finally, we apply our general results to concrete examples including policy evaluation, nonparametric instrumental variable regression, asset pricing, and adversarial Riesz representer estimation.
1 Introduction
Minimax optimization problems are ubiquitous in machine learning, statistics, economics, and other fields. Examples include generative adversarial networks (GANs) (Goodfellow et al., 2020; Salimans et al., 2016), adversarial training (Ganin et al., 2016; Madry et al., 2017), robust optimization (Ben-Tal et al., 2009; Levy et al., 2020), and zero-sum games (Xie et al., 2020b; Zhao et al., 2022). The goal in minimax optimization is to find a solution to the problem $\min_{f \in \mathcal{F}} \max_{g \in \mathcal{G}} \mathcal{L}(f, g)$, where $\mathcal{L}$ is a bivariate objective function, and $\mathcal{F}$ and $\mathcal{G}$ are the feasible sets of the decision variables $f$ and $g$. In modern machine learning applications, $\mathcal{F}$ and $\mathcal{G}$ are often function classes flexibly parameterized by neural networks, and the objective can be approximated using data. The minimax optimization problem is often solved using first-order optimization algorithms. Although these algorithms are hugely successful in diverse applications, there is not yet a global convergence theory for the popular first-order algorithms that solve general minimax optimization problems with neural networks.
In this work, we study the convergence of first-order algorithms for solving minimax optimization problems where $f$ and $g$ are both flexibly parameterized by two-layer neural networks, and the objective functional is quadratic in $f$ and $g$ up to regularization:
(1.1) |
where is a convex regularizer that penalizes the complexity of . Here the expectation is taken with respect to the joint distribution of random variables , is a function of , and takes and a function as its input and is linear in . The objective function (1.1) arises from solving a linear functional conditional moment equation of the form if and only if . Here is a vector containing all the endogenous variables and contains all the exogenous/predetermined variables. This problem has ample applications, including policy evaluation (Cai et al., 2019; Duan et al., 2020; Jin et al., 2021; Chen and Qi, 2022; Ramprasad et al., 2022), nonparametric instrumental variable regression (Blundell et al., 2007; Chen and Pouzo, 2012; Chen and Christensen, 2018; Xu et al., 2020), and asset pricing (Chen and Ludvigson, 2009; Chen et al., 2014, 2024). The minimax objective in (1.1) arises when we solve the conditional moment equation via adversarial estimation (Uehara et al., 2020; Duan et al., 2021; Chernozhukov et al., 2020; Liao et al., 2020; Wai et al., 2020; Bennett et al., 2019), which introduces a dual function and transforms equation solving into a minimax optimization.
We study the infinite-dimensional minimax optimization in (1.1) over the space of overparameterized two-layer neural networks. Specifically, a neural network is represented by $f(\cdot; \boldsymbol{\theta}) = \frac{\alpha}{N} \sum_{i=1}^{N} \phi(\cdot; \theta_i)$, where $N$ is the number of neurons, $\phi(\cdot; \theta_i)$ denotes the $i$-th neuron, $\boldsymbol{\theta} = (\theta_1, \ldots, \theta_N)$ are the network parameters, and $\alpha$ is a scaling parameter. We aim to solve the minimax optimization in (1.1) with both $f$ and $g$ represented by overparameterized two-layer neural networks, which is favorable especially when the input is a high-dimensional vector. To solve this problem, we consider the arguably simplest first-order algorithm, stochastic gradient descent-ascent (SGDA), where the parameters of $f$ and $g$ are simultaneously updated using stochastic gradients of the objective functional. Specifically, we aim to address the following two questions:
• Does SGDA with overparameterized neural networks converge to some solution?
• Does SGDA learn data-dependent features that yield a statistically accurate solution?
Answering these questions involves two intricate challenges in terms of optimization and representation learning using neural networks. First, the minimax objective is nonconvex-nonconcave with respect to the neural network parameters of $f$ and $g$, so it is unclear whether first-order algorithms converge. Second, the representation of the neural network evolves during the course of optimization, and it is unclear how to track and assess the data-dependent features learned by the neural networks. While there are some existing works on neural network optimization using the technique of neural tangent kernel (NTK) (Jacot et al., 2018; Du et al., 2018; Cai et al., 2019; Xu and Gu, 2020; Wang et al., 2022), such an approach suggests that the feature representation of the neural networks is fixed throughout training and is only determined by the initialization of the network parameters. Despite being an elegant theoretical framework, the NTK approach is limited in its ability to capture the representation learning aspect of neural network optimization. To show that the neural network optimization algorithms learn useful data-dependent features, in addition to establishing convergence, we need to show that (i) the algorithm approximately finds a proper solution concept, e.g., a stationary point or a local or global optimizer of the minimax objective function, and (ii) the representation of the neural networks moves from the initialization by a considerable amount.
In this paper, we tackle both challenges by leveraging the framework of mean-field analysis of overparameterized neural networks (Chizat and Bach, 2018; Mei et al., 2018, 2019; Zhang et al., 2020; Lu et al., 2020b; Zhang et al., 2021b; Sirignano and Spiliopoulos, 2020b, a, 2022; Chen et al., 2020b; Fang et al., 2021b). In particular, we focus on the continuous-time and infinite-width limit of the SGDA algorithm, where the stepsize goes to zero and the width goes to infinity. From the mean-field lens, a neural network can be identified with a probability measure by writing $f(\cdot; \mu) = \alpha \int \phi(\cdot; \theta) \, \mathrm{d}\mu(\theta)$, where $\mu$ is the empirical distribution of $\{\theta_i\}_{i=1}^{N}$ and $\alpha$ is the scaling parameter of the neural network. Thus, parameter updates of SGDA can be regarded as updates of the probability measure $\mu$. From this perspective, we prove that in the continuous-time and infinite-width limit, SGDA corresponds to a gradient flow of the minimax objective in the Wasserstein space, i.e., the space of probability measures over the parameter space equipped with the Wasserstein-2 distance. Besides, by defining a proper potential function that characterizes the stationary point of the minimax objective, we prove that the Wasserstein gradient flow converges to a stationary point at a sublinear rate of $\mathcal{O}(T^{-1} + \alpha^{-1})$, where $T$ is the time horizon and $\alpha$ is a scaling parameter of the neural network. Moreover, we prove that the Wasserstein distance between the parameter distribution found by SGDA and its initialization is $\mathcal{O}(\alpha^{-1})$, which shows that the representation of the neural networks is allowed to move from the initialization by a considerable amount. Such a behavior is not captured by the NTK analysis, in which the representation is shown to be fixed at the initialization. Furthermore, when the regularization on $f$ satisfies a version of strong convexity, we prove that the Wasserstein gradient flow converges to the global optimizer at a sublinear rate.
To the best of our knowledge, our work provides the first theoretical analysis of an optimization algorithm solving functional conditional moment equations using neural networks with representation learning. We apply our general theory to four important examples: policy evaluation, nonparametric instrumental variables regression, asset pricing, and adversarial Riesz representer estimation. In these examples, we prove that the SGDA algorithm finds the global solution with overparameterized neural networks. Moreover, SGDA learns data-dependent features that enable these statistically accurate estimators.
1.1 Related Works
Minimax Optimization. Our work is closely related to the literature on first-order methods for solving minimax optimization problems. These works establish the convergence rate or iteration complexity of first-order methods under various assumptions on the objective function. In particular, most of the existing works focus on finite-dimensional parameter spaces and one of the following objective functions: (i) convex-concave (Lin et al., 2020b; Ibrahim et al., 2019; Ouyang and Xu, 2021; Alkousa et al., 2019; Luo et al., 2021; Xie et al., 2020a; Han et al., 2024; Li et al., 2023; Jin et al., 2022), (ii) nonconvex-concave (Jin et al., 2019; Lin et al., 2020a; Lu et al., 2020a; Ostrovskii et al., 2021b; Zhao, 2023; Huang et al., 2022; Luo et al., 2020; Zhang et al., 2021a; Nouiehed et al., 2019; Thekumparampil et al., 2019), and (iii) nonconvex-nonconcave (Li et al., 2022; Diakonikolas et al., 2021; Ostrovskii et al., 2021a; Yang et al., 2022; Grimmer et al., 2022; Hajizadeh et al., 2024; Grimmer et al., 2023; Yang et al., 2020).
Our work can be viewed as an extension of convex-concave minimax optimization to the infinite-dimensional functional space. In particular, our objective is a regularized quadratic functional with respect to the input functions, which is then restricted to the class of overparameterized neural networks. Note that the objective of interest is in fact nonconvex-nonconcave in the neural network parameter space. Compared with the work on general nonconvex-nonconcave minimax optimization problems, our setting has a better underlying structure in the functional space in terms of convexity. This structure enables us to lift the network parameter updates to the Wasserstein space and analyze the gradient flow in the space of distributions. Our approach leverages the hidden convexity-concavity behind the seemingly nonconvex-nonconcave objective function and thus achieves better results in terms of algorithm convergence and complexity.
Mean-field Analysis in Deep Learning. Our work is closely related to the recent study of neural network training via gradient-based methods. One line of research establishes the convergence of gradient-based algorithms for training overparameterized neural networks under the “lazy training” regime, where the neural networks behave similarly to random kernel functions. Such a regime is also known as the neural tangent kernel regime (Jacot et al., 2018; Allen-Zhu et al., 2019a, b; Chen et al., 2020a; Frei and Gu, 2021; Zou and Gu, 2019; Du et al., 2018, 2019; Arora et al., 2019a, b; Huang and Yau, 2020). Our work is more related to another line of research based on the perspective of mean-field approximation (Mei et al., 2018, 2019; Chizat and Bach, 2018; Sirignano and Spiliopoulos, 2020b, a, 2022; Chen et al., 2020b; Fang et al., 2021b; Chen et al., 2019). Under the mean-field view, the neural network parameters are identified with a distribution over the parameter space. As a result, the evolution of parameters by gradient-based updates is captured by a differential equation that governs the evolution of the corresponding distribution. By elevating the training dynamics to the infinite-dimensional Wasserstein space, the optimization objective often enjoys a benign landscape, which admits a more tractable analysis and yields global convergence. See, e.g., Zhang et al. (2020, 2021b); Fang et al. (2021b); Lu et al. (2020b); Fang et al. (2019); Chizat (2022); Hu et al. (2021); Nitanda et al. (2022) and the references therein. Also, see Fang et al. (2021a) for a recent survey.
Our work is especially related to the mean-field analysis of the Neural Temporal Difference (TD) (Zhang et al., 2020) and the Neural Actor-Critic (AC) (Zhang et al., 2021b) in reinforcement learning. These previous works have provided an analysis of the global convergence of the TD and AC algorithms with two-layer overparameterized neural networks. The optimization problem in these two tasks is the minimization of an objective where only one neural network is involved. In contrast to these works, we focus on minimax optimization, which requires neural network parameterization of both the primal function and the dual function. This brings new challenges to the analysis, as the gradient dynamics of the primal and dual neural networks give rise to a coupled system of PDEs. To the best of our knowledge, our paper is the first to apply the mean-field limit to study the convergence of algorithms in solving the general form of functional conditional moment equations using neural networks.
Adversarial Estimation. Our work is also related to the literature on adversarial estimation, a method that solves a functional conditional moment equation by introducing a dual function and reformulating the original problem into a minimax optimization. Our work studies this type of minimax optimization with overparameterized neural networks. Thus, our work is more related to the study of adversarial estimation within neural network function classes (Dikkala et al., 2020; Chernozhukov et al., 2020; Bennett et al., 2019; Xu et al., 2021). Compared with our work, these studies focus on statistical errors pertinent to neural networks, assuming the optimization problem is solved perfectly. We instead study the optimization algorithm and establish the convergence of stochastic gradient-descent-ascent with neural networks.
Several previous works have also explored the convergence of optimization dynamics in adversarial estimation with neural networks. In particular, Neural GTD (Wai et al., 2020) and Neural SEM (Liao et al., 2020) analyze the convergence of off-policy evaluation and structural equation model estimation, respectively, with overparameterized two-layer neural networks. However, their analyses are based on the idea of neural tangent kernel (NTK), where the employed neural network has a fixed representation during training, and the representation is completely determined by the initialization. In contrast, our work adopts the mean-field approach, which enables learning a data-dependent representation.
2 Preliminaries
The functional conditional moment equations cover many important examples in statistics, machine learning, economics, and causal inference. In this section, we first introduce the general formulation of the functional conditional moment equations and then reformulate them into a minimax optimization problem. Then, we present a few concrete examples of functional conditional moment equations such as policy evaluation, nonparametric instrumental variables regression, asset pricing, and Riesz representer estimation. Finally, we introduce the background of mean-field neural networks and Wasserstein space, which are essential for the convergence analysis of the SGDA algorithm.
2.1 Functional Conditional Moment Equations
In this section, we introduce the general formulation of functional conditional moment equations. Let be a vector that includes all the endogenous variables, let denote all the exogenous variables, and let denote the joint distribution of . We let denote the expectation taken with respect to the joint distribution of and denote the conditional expectation using the conditional distribution of given . Let be a subset of variables that may contain both the endogenous and exogenous variables, and let denote a Hilbert space of measurable functions of with finite second moment. Let denote a class of functions defined on . In a functional conditional moment equation problem, we aim to find a function that solves the following functional equation involving the conditional distribution of given over :
(2.1) |
where is a known functional.
For any function and any , we define a functional as
(2.2) |
In other words, the conditional moment equation problem in (2.1) boils down to finding a function such that is a zero function on . Therefore, an equivalent way to solve in (2.1) is by solving (Ai and Chen, 2003; Chen and Pouzo, 2012). To control the complexity of the function class , Ai and Chen (2003) propose to use flexible sieve spaces that become dense in as the sieve dimension grows to infinity with data sample size , and propose the so-called sieve minimum distance criterion. In particular, Ai and Chen (2003) allow for two-layer NNs, splines, wavelets, Fourier series, and all kinds of polynomial sieves to approximate functions in . Alternatively, Chen and Pouzo (2012) propose the following penalized (or regularized) minimum distance criterion:
(2.3) |
where is a regularization parameter and is a regularizer on the function . They allow the regularizer to be any convex or lower-semicompact functional. In the minimum distance approach, for any fixed , the authors first estimate by the following least squares criterion:
Furthermore, we assume that the functional is affine in , which captures several important applications in machine learning and causal inference listed in Section 2.2. Specifically, we define , where stands for the zero function on . Then for any two functions and any , we have
(2.4) |
Solving (2.1) with Overparameterized Neural Networks. In the sequel, we aim to solve the problem in (2.1) based on i.i.d. data points sampled from , with being a class of overparameterized neural networks. In this case, it is possible that (2.1) does not have a solution within . Furthermore, for the choice of regularizer, we consider the following specific form of :
(2.5) |
where for any given , is a convex functional of that maps each function to a scalar. Moreover, satisfies
(2.6) | ||||
(2.7) |
Equation (2.6) requires that is a non-negative functional of that is equal to if and only if . Equation (2.7) requires that the functional derivative of with respect to , is linear in . One example of is the -regularizer of the following type, . Here is a subset of variables that contains values from both the endogenous variables and exogenous variables .
Minimax Estimation. To solve the optimization problem in (2.3), we first transform it into an unconditional moment formulation by introducing a dual function. By Fenchel duality, we can rewrite the objective function as follows,
(2.8) | ||||
The formulation in (2.8) leads to the following minimax optimization problem:
(2.9) |
We note that is a convex-concave functional with respect to function and . We denote by the unique saddle point of (2.9). Here the uniqueness of comes from the convexity of regularization , and implies the uniqueness of . Without the regularization, i.e., , the saddle point of (2.9) is and .
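As a toy illustration of this convex-concave structure, the following sketch runs simultaneous gradient descent-ascent on a scalar saddle problem. It is a minimal sketch under stated assumptions and not the neural algorithm analyzed in this paper: the primal and dual functions are replaced by scalars `x` and `y`, the conditional moment by the affine residual `a * x - b`, and the regularizer by `0.5 * lam * x**2`. The names `gda`, `a`, `b`, `lam`, and `step` are hypothetical. The dual variable enters through the Fenchel identity $\frac{1}{2}\delta^2 = \max_u \{u\delta - \frac{1}{2}u^2\}$, so maximizing over `y` recovers the penalized squared residual, whose minimizer is `a * b / (a**2 + lam)`.

```python
def gda(a=2.0, b=1.0, lam=0.5, step=0.05, iters=2000):
    """Simultaneous gradient descent-ascent on the toy saddle objective
    L(x, y) = y * (a * x - b) - 0.5 * y**2 + 0.5 * lam * x**2
    (hypothetical scalar stand-in for a regularized quadratic minimax problem)."""
    x, y = 0.0, 0.0
    for _ in range(iters):
        grad_x = a * y + lam * x      # dL/dx
        grad_y = (a * x - b) - y      # dL/dy
        # descent step in x, ascent step in y, both using the old iterate
        x, y = x - step * grad_x, y + step * grad_y
    return x, y


a, b, lam = 2.0, 1.0, 0.5
x_hat, y_hat = gda(a, b, lam)

# Closed-form saddle point: y* = a * x* - b and x* = a * b / (a**2 + lam).
x_star = a * b / (a ** 2 + lam)
y_star = a * x_star - b
print(x_hat, x_star, y_hat, y_star)
```

In this scalar problem the objective is strongly convex in `x` and strongly concave in `y`, which is why a small constant step suffices; the neural setting studied in this paper is nonconvex-nonconcave in the parameters and requires the mean-field analysis instead.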
2.2 Examples of Functional Conditional Moment Equation
In this section, we discuss several important applications of the functional conditional moment equation, which serve as running examples of this paper.
Policy Evaluation. We consider a Markov decision process given by , where is the state space, is the action space, is the transition kernel, is the reward function, is the discount factor. Given a policy , an agent interacts with the environment in the following manner. At a state , the agent takes an action and receives a reward . Then, the agent transits to the next state . We denote the transition kernel induced by policy by for any . In policy evaluation, we aim to estimate the value function defined as follows,
where the expectation is taken with respect to and for . By the Bellman equation (Sutton and Barto, 2018), it holds for any that
(2.10) |
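The Bellman equation can be checked numerically on a small example. The sketch below assumes a finite state space, so that the policy-induced transition kernel becomes a row-stochastic matrix `P` and the expected one-step reward under the policy becomes a vector `r`; the three-state values of `P`, `r`, and `gamma` are hypothetical. Under these assumptions the Bellman equation reads `V = r + gamma * P @ V`, a linear system that can be solved directly.

```python
import numpy as np

gamma = 0.9                                # discount factor (hypothetical value)
P = np.array([[0.5, 0.5, 0.0],
              [0.1, 0.6, 0.3],
              [0.0, 0.2, 0.8]])            # hypothetical policy-induced transition matrix
r = np.array([1.0, 0.0, 2.0])              # hypothetical expected reward in each state

# Solve (I - gamma * P) V = r, i.e. the tabular Bellman equation.
V = np.linalg.solve(np.eye(3) - gamma * P, r)

# The Bellman residual r + gamma * P V - V vanishes at the value function.
residual = r + gamma * P @ V - V
print(V, np.abs(residual).max())
```

The vanishing residual is the finite-state analogue of the conditional moment restriction: the expected one-step error, conditioned on the current state, is zero for every state.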
Corresponding to the Bellman equation in (2.10), let denote the joint distribution of the state-action tuple under policy . Then the value function satisfies the following functional conditional moment equation,
(2.11) |
We notice that (2.11) is a special case of the functional conditional moment equation in (2.1) by setting the exogenous variable to be the current state , the endogenous variable to be the next state and the function to be estimated to be defined on the state space . In this case, the functional is , where is the reward function. We remark that the reason function can be evaluated simultaneously on and is that both and are variables defined on . Following the same derivation of (2.8), policy evaluation can be formulated as the following minimax optimization problem,
Nonparametric Instrumental Variables Regression. The nonparametric instrumental variables (NPIV) model is widely used in statistics and economics. The model is described by the equation
$$Y = f(X) + \varepsilon, \qquad \mathbb{E}[\varepsilon \mid Z] = 0,$$
where $Y$ is an observed outcome, $X$ is the endogenous variable, $Z$ is the exogenous variable, and $f$ is the true model that characterizes the relationship between $X$ and $Y$, which is also the function we want to estimate. In this model, $\varepsilon$ is a noise term that is possibly correlated with the endogenous variable $X$ but uncorrelated with the exogenous variable $Z$. The NPIV model fits into the framework of the functional conditional moment equation by plugging the model equation into the condition on $\varepsilon$,
$$\mathbb{E}[\, Y - f(X) \mid Z \,] = 0. \tag{2.12}$$
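To see why the conditional moment restriction, rather than a direct regression of the outcome on the endogenous variable, identifies the model, consider the following simulation. It is a minimal sketch under a strong simplifying assumption that is not made in this paper: the unknown function is linear, `f(x) = beta * x`, so that the conditional moment implies the single unconditional moment `E[z * (y - beta * x)] = 0`. The data-generating process and all names (`beta`, `u`, `beta_ols`, `beta_iv`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
beta = 1.5                                   # hypothetical true slope, f(x) = beta * x

z = rng.normal(size=n)                       # exogenous variable (instrument)
u = rng.normal(size=n)                       # unobserved confounder
x = z + u + 0.5 * rng.normal(size=n)         # endogenous variable, depends on u
eps = u + 0.5 * rng.normal(size=n)           # noise: correlated with x, independent of z
y = beta * x + eps

# Least squares ignores the correlation between x and eps and is biased.
beta_ols = (x @ y) / (x @ x)

# The instrument-based estimate solves the sample moment mean(z * (y - b * x)) = 0.
beta_iv = (z @ y) / (z @ x)

print(beta_ols, beta_iv)
```

In this design the least squares estimate concentrates around `beta + 1 / 2.25`, while the moment-based estimate concentrates around `beta`. The nonparametric problem replaces the single instrument moment with a whole class of test functions of the exogenous variable, which is the role played by the dual function in the minimax formulation.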
We notice that (2.12) is a special case of functional conditional moment equation in (2.4) by identifying , with the endogenous and exogenous variable respectively and setting the functional as . Following the same derivation of (2.8), the problem of NPIV is equivalent to the following minimax optimization problem,
Asset Pricing. Asset pricing refers to the process of determining the fair value of financial assets. This field is fundamental in finance and underpins much of the work in investment, portfolio management, and risk assessment. The semiparametric Consumption Capital Asset Pricing Model (CCAPM) is a foundational model in asset pricing that describes the relationship between systematic risk and expected asset returns, and also incorporates the influence of the consumption preference of investors over time. Moreover, CCAPM can be characterized through a functional conditional moment equation (Chen et al., 2014; Chen and Ludvigson, 2009). To describe the model, let denote the consumption level at time , the consumption growth. The marginal utility of consumption at time is given by , where is the discount factor, is the nonparametric structural demand function, which is an unknown positive function of our interest and is defined on , the space of consumption growth. The unknown function can be understood as a taste shifter that describes how the marginal utility of consumption changes with the state of the economy in terms of consumption growth.
Now, consider the growth-return tuple for with joint distribution , where is the consumption growth at the current time , and is the consumption growth at the next time . is a modified return observed in this period, which is a known function of the actual return and the consumption growth at time . We consider the scenario where the time series of consumption growth follows a time-homogeneous Markov chain with a smooth transition kernel. That is, both conditional transition probabilities and admit a smooth density function. The CCAPM model captures the behavior of through the following equation:
(2.13) |
where the modified return can be further expressed as , is the rate of time preference. We focus on a setting where is a compact set, and the modified return is bounded for all . We notice that (2.13) is a special case of the functional conditional moment equation in (2.4). We can identify the exogenous variable with , the consumption growth at the current time , and the endogenous variable with , the consumption growth at the next time . In this scenario, we identify the space with , the space of consumption growth and the function to be estimated is defined on . The functional is , where again denotes the modified return. Similar to the scenario of policy evaluation, the reason function can be evaluated simultaneously on and is that both and are variables defined on . Following the same derivation of (2.8), the problem of asset pricing through CCAPM is equivalent to the following minimax optimization problem,
Adversarial Riesz Representer Estimation. Many problems in statistics, causal inference, and finance involve the task of learning a continuous linear functional in the following form,
(2.14) |
where function , is defined on a function space , and is a random vector of which we have access to observations and represents the source of randomness in the functional. Moreover, suppose that this continuous linear functional is also mean-square continuous, i.e., continuous with respect to the $L^2$ norm, which is often the case. Then it can be written in a more convenient form. Formally speaking, for such a linear functional , there exists a function such that for any ,
(2.15) |
The function here is called the Riesz representer of the linear functional , and the equation (2.15) is known as the Riesz representation theorem. Information about the Riesz representation of such linear functional is crucial to numerous applications and learning tasks. Therefore, we aim to estimate by exploiting the relationship characterized by the equation. We have the following trivial observation that the true Riesz representer can be recovered by solving the following equation,
(2.16) |
Of course, will solve the equation above, and therefore the true Riesz representer is achieved. We remark that this is indeed a special case since the expectation taken in (2.16) is unconditioned. In the equation, we only involve the endogenous variable , which also indicates that the exogenous variable coincides with . While special, the problem still fits in the framework discussed here. By setting , we recovered the intractable formulation of Riesz representer estimation.
However, unlike the previous examples where we have access to observations of each term in the equation, here we have no direct access to values of , making the problem seemingly intractable. Fortunately, the alternative formulation of the original problem as a minimax optimization problem resolves this difficulty. When written in the minimax formulation, we will again see the linear functional show up in the equation in the form of (2.14), which can be approximated using empirical values calculated from accessible observations of the random vector . Following the same derivation of (2.8) and the definition of Riesz representer in (2.15), the problem of adversarial Riesz representer estimation is equivalent to the following minimax optimization problem,
(2.17) |
Again, we stress that in (2.17), the absence of is due to the fact that both the endogenous and exogenous variables are described by , and the objective is computationally tractable since we have access to both observations of and .
2.3 Mean-Field Neural Network and Wasserstein Space
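In the sequel, we consider functions in the neural network function class. Consider a neuron function $\phi(x; \theta)$ that takes an input $x$ from a given state space and a parameter $\theta \in \mathbb{R}^D$ and outputs a real value. For $\boldsymbol{\theta} = (\theta_1, \ldots, \theta_N)$ where $\theta_i \in \mathbb{R}^D$, we can define an overparameterized two-layer neural network function using the neuron function $\phi$,
$$f(x; \boldsymbol{\theta}) = \frac{\alpha}{N} \sum_{i=1}^{N} \phi(x; \theta_i).$$
For such a form, we can further consider the infinite-width limit $N \to \infty$. In this limit, the neural network function becomes a mean-field neural network and can be parameterized by a probability measure $\mu$ over the parameter space, $f(x; \mu) = \alpha \int \phi(x; \theta) \, \mathrm{d}\mu(\theta)$.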
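The infinite-width limit can be observed numerically. The sketch below assumes a specific neuron, `cos(<theta, x>)`, and a standard Gaussian parameter distribution; both are illustrative choices and not the neurons used in this paper. They are convenient because the mean-field limit is then available in closed form, namely `exp(-0.5 * ||x||^2)`. The names `two_layer`, `limit`, and `errors` are hypothetical, and the scaling parameter is set to `alpha = 1`.

```python
import numpy as np


def two_layer(x, theta, alpha=1.0):
    """(alpha / N) * sum_i cos(<theta_i, x>) for theta of shape (N, D)."""
    return alpha * np.mean(np.cos(theta @ x))


rng = np.random.default_rng(0)
D = 3
x = np.array([0.5, -0.2, 0.1])

# Mean-field limit: the integral of cos(<theta, x>) under N(0, I_D).
limit = np.exp(-0.5 * x @ x)

errors = {}
for N in (100, 10_000, 200_000):
    theta = rng.normal(size=(N, D))          # N neurons drawn i.i.d. from N(0, I_D)
    errors[N] = abs(two_layer(x, theta) - limit)

print(errors)
```

The finite-width output depends on the parameters only through their empirical distribution, so its deviation from the limit is a Monte Carlo error that shrinks as the width grows.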
When considering such a limit, the optimization problem over the neural network function class is turned from a finite-dimensional problem over the parameter space into an infinite-dimensional problem over the space of probability measures. Therefore, we will need to track the convergence of probability measures over the Wasserstein space when analyzing the convergence of algorithms.
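We now introduce the background knowledge of the Wasserstein space. Let $\mathcal{P}_p(\mathbb{R}^D)$ be the space of all the probability measures over the $D$-dimensional Euclidean space $\mathbb{R}^D$ with finite $p$-th order moments. The Wasserstein-$p$ distance between two probability measures $\mu, \nu \in \mathcal{P}_p(\mathbb{R}^D)$ is defined as follows,
$$W_p(\mu, \nu) = \Big( \inf_{\gamma \in \Gamma(\mu, \nu)} \int \|x - y\|^p \, \mathrm{d}\gamma(x, y) \Big)^{1/p}, \tag{2.18}$$
where the infimum is taken over all the couplings $\gamma$ of $\mu$ and $\nu$, that is, over the set $\Gamma(\mu, \nu)$ of joint distributions whose marginal distributions with respect to $x$ and $y$ are $\mu$ and $\nu$, respectively. We call $(\mathcal{P}_p(\mathbb{R}^D), W_p)$ the Wasserstein-$p$ space. For any $1 \le p \le q$, due to the relation that $(\mathbb{E}\|X - Y\|^p)^{1/p} \le (\mathbb{E}\|X - Y\|^q)^{1/q}$, we have that $W_p(\mu, \nu) \le W_q(\mu, \nu)$ for two measures $\mu, \nu \in \mathcal{P}_q(\mathbb{R}^D)$. In this paper, we focus on the cases when . Without further clarification, we refer to the distance with $p = 2$ as the Wasserstein distance in the sequel.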
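For intuition, the Wasserstein distance can be computed exactly in a simple special case. The sketch below assumes two empirical measures on the real line with the same number of equally weighted atoms; in one dimension the optimal coupling for any order `p >= 1` matches sorted samples, so no optimization is needed. The function name `wasserstein_1d` is hypothetical, and the general multivariate case requires solving an optimal transport problem instead.

```python
import numpy as np


def wasserstein_1d(x, y, p=2):
    """Wasserstein-p distance between two equal-size empirical measures on the real line."""
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    # In one dimension the monotone (sorted) coupling is optimal for p >= 1.
    return np.mean(np.abs(x - y) ** p) ** (1.0 / p)


a = np.array([0.0, 1.0, 3.0])
b = np.array([6.0, 1.0, 2.0])

w1 = wasserstein_1d(a, b, p=1)   # (1 + 1 + 3) / 3
w2 = wasserstein_1d(a, b, p=2)   # sqrt((1 + 1 + 9) / 3)
print(w1, w2)
```

The two values also illustrate the monotonicity of the distance in its order: the order-1 distance never exceeds the order-2 distance between the same pair of measures.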
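The Wasserstein-2 space $(\mathcal{P}_2(\mathbb{R}^D), W_2)$ can be viewed as an infinite-dimensional Riemannian manifold (Villani, 2008). Formally, the tangent space at a point $\mu \in \mathcal{P}_2(\mathbb{R}^D)$ is defined as
$$\mathrm{Tan}_{\mu}\big(\mathcal{P}_2(\mathbb{R}^D)\big) = \overline{\big\{ \nabla \xi : \xi \in C_c^{\infty}(\mathbb{R}^D) \big\}}^{L^2(\mu)}.$$
Then, for any absolutely continuous curve $\rho: [a, b] \to \mathcal{P}_2(\mathbb{R}^D)$ on the Wasserstein-2 space, there exists a family of vector fields $v_t \in \mathrm{Tan}_{\rho_t}\big(\mathcal{P}_2(\mathbb{R}^D)\big)$ such that the continuity equation
$$\partial_t \rho_t = -\mathrm{div}(\rho_t v_t) \tag{2.19}$$
holds in the sense of distributions. For any two absolutely continuous curves $\rho, \tilde{\rho}: [a, b] \to \mathcal{P}_2(\mathbb{R}^D)$, we define the inner product between $\partial_t \rho_t$ and $\partial_t \tilde{\rho}_t$ for any $t \in [a, b]$ as follows,
$$\langle \partial_t \rho_t, \partial_t \tilde{\rho}_t \rangle_{\rho_t} = \int \langle v_t, \tilde{v}_t \rangle \, \mathrm{d}\rho_t, \tag{2.20}$$
where $\langle \cdot, \cdot \rangle$ is the inner product over $\mathbb{R}^D$, and $v_t$ and $\tilde{v}_t$ satisfy the continuity equation in (2.19) for $\rho_t$ and $\tilde{\rho}_t$, respectively. Note that (2.20) yields a Riemannian metric over $\mathcal{P}_2(\mathbb{R}^D)$. Furthermore, the Riemannian metric induces a norm $\|\partial_t \rho_t\|_{\rho_t} = \langle \partial_t \rho_t, \partial_t \rho_t \rangle_{\rho_t}^{1/2}$.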
3 Algorithms
In this section, we introduce the stochastic gradient descent-ascent algorithm (SGDA) and its mean-field limit, which is characterized by the continuity equation.
Stochastic Gradient Descent-Ascent Algorithm. We solve the minimax optimization problem in (2.9) via SGDA. Recall that in the minimax objective, we have two functions simultaneously involved, where the primal function represents the true model of interest and the dual function represents an adversarial player. Specifically, we parameterize both and with neural networks with width and parameters and
(3.1) |
where we use bold symbols and to denote the whole parameter used by each neural net and unbold symbols and to denote the parameter used by each neuron. Here, , are the functions for neurons. In particular, we can recover the general setting of two-layer neural networks parameterization for and when we choose to be the following specific form,
where , are activation functions with input and respectively and parameters . We note that it is not necessary to choose the same width for and , and the activation functions need not have the same parameter dimension . Here we use the same width and parameter dimension to keep the notation simple, as these choices do not affect the validity of the results presented in this paper.
Besides, we have also introduced a scaling factor $\alpha$ in (3.1). Setting the scaling parameter $\alpha = \sqrt{N}$ in (3.1) recovers the neural tangent kernel regime (Jacot et al., 2018). Setting the parameter $\alpha = 1$ recovers the mean-field regime (Mei et al., 2018, 2019). In a discrete-time finite-width (DF) scenario, at the $k$-th iteration, the primal function and adversarial player are updated as follows,
DF-GD: | ||||
DF-GA: | (3.2) |
where denotes the state of the parameters at iteration , is the step-size, and the data samples are collected by independently sampling from the data distribution . When are two-layered neural networks with width , we can plug in the form for as is described in (3.1). The update for the parameter of -th neuron at -th iteration can be further specified to the following,
(3.3) |
where and , denotes the variation of with respect to . Here, is the neural network scaling parameter and is the stepsize scale. Both and show up in (3) due to the finite width parameterization of two-layered neural networks described in (3.1).
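As an illustration of the per-neuron update, the following is a minimal sketch of one finite-width SGDA step. It rests on several assumptions that are not taken from the text: scalar inputs, a toy neuron `tanh(b) * tanh(w * x)`, and a toy NPIV-style sample objective `(f(x) - y) * g(z) - g(z)**2 / 2 + lam * f(x)**2 / 2` standing in for (2.9). The names `neuron`, `network`, `sample_loss`, `sgda_step`, `eta`, `eps`, and `lam` are hypothetical, and the per-neuron step `eta * eps * alpha` is an assumed form of the scaling described above.

```python
import math

def neuron(x, theta):
    """Toy bounded neuron tanh(b) * tanh(w * x) for scalar x; theta = (b, w)."""
    b, w = theta
    return math.tanh(b) * math.tanh(w * x)

def neuron_grad(x, theta):
    """Gradient of the toy neuron with respect to theta = (b, w)."""
    b, w = theta
    tb, tw = math.tanh(b), math.tanh(w * x)
    return ((1.0 - tb * tb) * tw, tb * (1.0 - tw * tw) * x)

def network(x, params, alpha):
    """Two-layer network (alpha / N) * sum_i neuron(x; theta_i)."""
    return alpha * sum(neuron(x, th) for th in params) / len(params)

def sample_loss(thetas, omegas, sample, alpha, lam):
    """Assumed toy objective on one sample (x, z, y); not the paper's (2.9)."""
    x, z, y = sample
    f, g = network(x, thetas, alpha), network(z, omegas, alpha)
    return (f - y) * g - 0.5 * g * g + 0.5 * lam * f * f

def sgda_step(thetas, omegas, sample, alpha, eta, eps, lam):
    """One simultaneous step: descent on thetas, ascent on omegas."""
    x, z, y = sample
    f, g = network(x, thetas, alpha), network(z, omegas, alpha)
    dl_df = g + lam * f        # derivative of the toy loss in f(x)
    dl_dg = (f - y) - g        # derivative of the toy loss in g(z)
    step = eta * eps * alpha   # assumed per-neuron step after time rescaling
    new_thetas = []
    for th in thetas:
        gb, gw = neuron_grad(x, th)
        new_thetas.append((th[0] - step * dl_df * gb, th[1] - step * dl_df * gw))
    new_omegas = []
    for om in omegas:
        gb, gw = neuron_grad(z, om)
        new_omegas.append((om[0] + step * dl_dg * gb, om[1] + step * dl_dg * gw))
    return new_thetas, new_omegas
```

Both players are updated from the same iterate, so the step is simultaneous rather than alternating. With a small stepsize scale, the primal update alone lowers the sample loss and the dual update alone raises it, which is the first-order behavior expected of descent-ascent.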
For a given space , let denote a set of functions defined on . For a functional defined over the function class , , its variation at is a function , such that for any test function ,
(3.4) |
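As a worked illustration of this definition, consider a quadratic functional. The symbols $F$, $f$, $h$, $\rho$, $p$, and $\tau$ below are notation chosen for this example only and are not taken from (3.4); $\rho$ is a generic reference measure with Lebesgue density $p$.

```latex
\begin{align*}
F(f) &= \frac{1}{2}\int f(x)^2 \,\rho(\mathrm{d}x), \\
\frac{\mathrm{d}}{\mathrm{d}\tau} F(f + \tau h)\Big|_{\tau = 0}
  &= \int f(x)\, h(x)\,\rho(\mathrm{d}x)
   = \int f(x)\, p(x)\, h(x)\,\mathrm{d}x .
\end{align*}
% Pairing test functions against \rho gives        \delta F / \delta f = f.
% Pairing test functions against Lebesgue gives    \delta F / \delta f = f \cdot p.
```

The two expressions for the variation differ only by the density factor $p$, so the choice of inner product changes the representation of the variation but not the directional derivative itself.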
We initialize the parameters with and , where denotes the standard Gaussian distribution in . In addition, to keep track of the evolution of the parameter distribution, we denote the empirical distribution of and at the th iteration by,
where is the Dirac mass function.
Mean-Field (MF) Limit. To analyze the convergence of SGDA for solving functional conditional moment equations with neural networks, we study the mean-field limit (Mei et al., 2018, 2019) of the discrete-time dynamics described in (3). By the mean-field limit, we mean both an infinite-width limit, i.e., for the neural network width, and a continuous-time limit, i.e., for the stepsize scale in (3). In what follows, we introduce this infinite-width, continuous-time limit of the SGDA dynamics. For and independently sampled respectively from , we can write the infinite-width limit of the neural networks used in (3.1) as
(3.5) |
From now on, we denote by the distribution of and the distribution of for the infinite-width and continuous limit of the neural networks at time . For notational simplicity, we overload the notation of the objective function in (2.9) via . This is to further emphasize the dependence of objective on when we parameterize the function pair using distributions on the parameter space. By Otto’s calculus (Villani, 2008), the mean-field limit of the update direction takes the following form,
(3.6) |
Here is the inner product on with respect to the Lebesgue measure. Recall that is the data distribution of the random variables . We denote by the density of with respect to the Lebesgue measure on , and we use to represent the inner product on with respect to the probability distribution . That is to say, for any two functions , .
In the sequel, we will also slightly abuse this notation and use to denote the inner product on sub-spaces of , with the measure being the marginals of on these sub-spaces. In (3), and are the variations of and over under , where the test functions are chosen over the function class . In the same way, and respectively denote the variations of the objective with respect to the distributions and under , following the definition in (3.4) with the test function chosen over . We remark that one can also define the variation under , which differs from the variation under only by a fixed multiplicative factor corresponding to the density of the marginals of . Then, the mean-field limit of the SGDA update in (3) is characterized by the continuity equation, which is a system of PDEs given by,
(3.7) |
where , denotes the divergence with respect to , respectively. Note that the initialization and are the same as the initialization of the discrete-time dynamics in (3), i.e. , are taken to be the distribution of standard Gaussian random variables in .
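The continuity equation has a simple particle interpretation: if every parameter moves along a velocity field, the law of the parameters solves the associated transport PDE. The sketch below illustrates this under deliberately simple assumptions that are not taken from (3.7): a one-dimensional parameter, a fixed velocity field `v(theta) = -theta` instead of the SGDA field, and the sign convention `d/dt mu_t + div(mu_t * v) = 0`. Under that PDE the second moment obeys `d/dt m2 = -2 * m2`, which the forward-Euler particle scheme reproduces. The names `transport` and `second_moment` are hypothetical.

```python
import math

def transport(particles, velocity, dt, steps):
    """Forward-Euler transport: theta <- theta + dt * velocity(theta)."""
    for _ in range(steps):
        particles = [th + dt * velocity(th) for th in particles]
    return particles

def second_moment(particles):
    """Empirical second moment of the particle cloud."""
    return sum(th * th for th in particles) / len(particles)

def velocity(theta):
    """Assumed toy velocity field; the gradient flow of theta**2 / 2."""
    return -theta

# Deterministic initial cloud on a grid over [-2, 2].
particles0 = [-2.0 + 4.0 * i / 200 for i in range(201)]
horizon, dt = 1.0, 1e-3
particles_end = transport(particles0, velocity, dt, int(round(horizon / dt)))

# Ratio of second moments; the PDE predicts exp(-2 * horizon).
ratio = second_moment(particles_end) / second_moment(particles0)
predicted = math.exp(-2.0 * horizon)
```

In the SGDA setting the velocity field additionally depends on the current pair of distributions, so the particles interact through the network output; the sketch omits this coupling to isolate the transport mechanism.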
4 Main Results
In this section, we introduce the main theoretical results on the stochastic gradient descent-ascent dynamics. We first present the assumptions in §4.1. Then in §4.2 we show that the SGDA dynamics converge to a mean-field limit when the network width goes to infinity and the stepsize scale goes to zero. Finally, in §4.3 we prove that the mean-field limiting dynamics converge to a globally optimal solution of the primal objective under proper assumptions. Moreover, we show that the mean-field dynamics learn a data-dependent representation that is away from the initial representation.
4.1 Assumptions
We consider two types of assumptions in this work. The first type of assumption is about the function class in which we search for solutions to the minimax optimization problem. In this category, Assumption 4.1 and Assumption 4.2 discuss the richness and regularity of the two-layered neural network function class. The second type of assumption is about the feasible class of problems to apply our framework. In this category, Assumption 4.3 discusses several technical assumptions on the data space and the regularity/smoothness of the functionals.
We start with the discussion of the two-layered neural network function class. Consider the neuron function and with the following form,
(4.1) |
where , contains the parameters in the output layer and the hidden layer, is an odd re-scaling function and is the activation function. Note that such a form of activation function satisfies the condition of universal function approximation theorem (Theorem 3.1 in Pinkus (1999)) if is not a polynomial. For notational simplicity, we write . The re-scaling function is introduced to ensure that the value of the neural network is upper bounded. When , the function class induced by the neural network in (3.5) is equivalent to the following class,
(4.2) |
where . This captures a rich function class due to the universal function approximation theorem (Barron, 1993; Pinkus, 1999). We remark that we introduce the re-scaling function in (4.1) to avoid the study of the space of probability measures over , which has a boundary and thus lacks regularity in the study of optimal transport. Moreover, note that a scaling hyperparameter is introduced in the definition of the mean-field neural nets in (3.5). When , this causes an effect of overparameterization. In brief, controls the error between the and optimizer according to Theorem 4.7. Furthermore, the overparameterization scale has an influence through Lemma 4.6, which shows that the Wasserstein distance between the Gaussian initialization and the optimal distribution is upper-bounded by . Next, we impose the following regularity assumptions on the neural network functions and .
Assumption 4.1 (Regularity of Neural Networks).
We assume that there exist absolute constants , and such that
where denotes the hessian with respect to and respectively, denotes the vector norm, and denotes the matrix Frobenius norm. Moreover, we assume that the rescaling function is odd and its range satisfies that .
Assumption 4.1 is satisfied by a broad class of neuron functions. For example, it is satisfied when we set the activation function and rescaling function .
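To make the boundedness requirements concrete, the following sketch evaluates one candidate neuron and checks its value and gradient bounds numerically. The specific choices are assumptions made for illustration: `tanh` is used for both the rescaling function and the activation, and `bounded_neuron`, `bounded_neuron_grad`, and `bound` are hypothetical names. Only the value and gradient bounds are checked; the Hessian bound in Assumption 4.1 is not.

```python
import math
import random

def bounded_neuron(x, b, w, bound=1.0):
    """Neuron bound * tanh(b) * tanh(<w, x>): odd in b, magnitude at most `bound`."""
    pre = sum(wi * xi for wi, xi in zip(w, x))
    return bound * math.tanh(b) * math.tanh(pre)

def bounded_neuron_grad(x, b, w, bound=1.0):
    """Gradient with respect to (b, w), returned as a flat list."""
    pre = sum(wi * xi for wi, xi in zip(w, x))
    tb, tp = math.tanh(b), math.tanh(pre)
    grad_b = bound * (1.0 - tb * tb) * tp
    grad_w = [bound * tb * (1.0 - tp * tp) * xi for xi in x]
    return [grad_b] + grad_w

rng = random.Random(0)
x = [0.3, -0.8]   # a data point from a compact set; squared norm 0.73
worst_value, worst_grad = 0.0, 0.0
for _ in range(1000):
    b = rng.gauss(0.0, 3.0)
    w = [rng.gauss(0.0, 3.0) for _ in x]
    worst_value = max(worst_value, abs(bounded_neuron(x, b, w)))
    grad = bounded_neuron_grad(x, b, w)
    worst_grad = max(worst_grad, math.sqrt(sum(gi * gi for gi in grad)))
```

For this choice the value is bounded by `bound` and the gradient norm by `bound * sqrt(1 + |x|^2)`, uniformly over the parameters, which is the kind of parameter-independent control the assumption asks for when the data space is compact.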
We also make the following assumption regarding the realizability of the saddle point solution to (2.9).
Assumption 4.2 (Realizability).
In general, problem (2.9) may not admit a saddle point within the given neural network function class. Assumption 4.2 is therefore introduced to guarantee that the discussion in this paper is meaningful. By the universal approximation theorem (Barron, 1993; Pinkus, 1999), the function class defined in (4.2) captures a rich class of functions, so this assumption is mild and does not materially restrict the applicability of our results.
We impose the following assumptions on the integrability of the functionals and and their variations, as well as the compactness of the data spaces and .
Assumption 4.3 (Data regularity and Functional Integrability).
(i) For the data space , , we assume that is compact, in the sense that there exists a positive constant such that for any data tuple , it satisfies that . Moreover, the data distribution admits a positive, smooth density with respect to the Lebesgue measure on .
(ii) For the functionals and , there exists a positive constant such that
(iii) We assume that as a linear functional of is upper-bounded by a constant multiple of the values of . That is to say, there exists as a part of the data tuple and a positive constant such that
(iv) We assume that the variations of the minimax objective with respect to and are continuous functions defined on and . That is to say,
Item (i) of Assumption 4.3 restricts our scenarios to data spaces with bounded values and smooth densities for technical reasons. Items (ii) and (iii) of Assumption 4.3 are integrability conditions that we additionally require to rule out improper functionals that may have singularities with exploding values. Item (iv) is a smoothness condition that requires the variations of the minimax objective, averaged over the data, to be continuous on their respective spaces. A sufficient condition for item (iv) to hold is that the variation of and with respect to averaged under the marginal of on is continuous, and we use this condition to verify item (iv) in practice. These are general and reasonable assumptions that are satisfied by various applications in machine learning, causal inference, and statistics.
4.2 Convergence of SGDA dynamics to the Mean-Field Limit
In the following proposition, we show that the empirical distribution of the parameters and converges weakly to the mean-field limit in (3.7) as the width goes to infinity and the stepsize scale goes to zero. Let , where is the PDE solution to the continuous-time deterministic dynamics in (3.7) and corresponds to the empirical distribution of , which is the -th iterate of the discrete-time stochastic dynamics in (3) with stepsize scale . The following proposition shows that the PDE solution in (3.7) approximates the discrete-time stochastic gradient descent-ascent dynamics in (3) well.
Proposition 4.4 (Convergence of SGDA to Mean-Field Limit).
Proof.
See §B for a detailed proof. ∎
The proof of Proposition 4.4 is based on the propagation of chaos (Mei et al., 2018, 2019; Araújo et al., 2019; Zhang et al., 2020; Sznitman, 1991). Proposition 4.4 allows us to convert the discrete-time SGDA dynamics over a finite-dimensional parameter space into their continuous-time, infinite-dimensional counterpart in Wasserstein space, where training is amenable to analysis because the infinitely wide neural networks and in (3.5) are linear in and respectively.
4.3 Global Optimality and Convergence of the Mean-Field Limit
In this section, we present our main results, which characterize the global optimality and convergence of the mean-field neural networks parameterized by the parameter distribution . The argument has two steps. We first show that finding a stationary point of the Wasserstein gradient flow defined in (3.7) suffices to solve the minimax optimization problem in (2.9); we then characterize the convergence of to the stationary point. Before presenting these two steps, we clarify the notion of stationarity for the Wasserstein gradient flow with the following definition.
Definition 4.5 (Stationary point of Wasserstein Gradient Flow).
A distribution pair is called a stationary point of the Wasserstein gradient flow (3.7) if it satisfies
From Definition 4.5, the stationary point of Wasserstein gradient flow (3.7) is a distribution pair , at which the associated vector field is a zero function on the parameter space . Moreover, for the Wasserstein gradient flow following vector field and initial condition , the solution to its associated continuity equation is a constant flow such that for all , . Now, we have the following important supporting lemma that characterizes the relation between stationary points of Wasserstein gradient flow (3.7) and saddle points of (2.9).
Lemma 4.6.
Lemma 4.6 demonstrates that the stationary point of the Wasserstein gradient flow in (3.7) achieves global optimality as a solution to the minimax objective (2.9). Lemma 4.6 allows us to bypass the hardness of solving the nonconvex-nonconcave optimization problem (2.9) of finding saddle points in the space of neural network parameters by searching for a stationary point of the Wasserstein gradient flow instead. Moreover, there exist good pairs of stationary points that are close to the Gaussian initialization , with Wasserstein distance upper bounded by order .
Proof.
See §A.1 for a detailed proof. ∎
We are now ready to show our main results. The following theorem characterizes the global optimality and convergence of the Wasserstein gradient flow .
Theorem 4.7 (Global Convergence to Saddle Point).
Proof.
See §A.2 for a detailed proof. ∎
Theorem 4.7 says that the optimality gap between and , quantified by the -induced distance and distance respectively, decays to zero at a sublinear rate in terms of time up to the error of , where is the scaling parameter in (3.1) and (3.5). In order to prove the convergence, we construct a potential , with if and only if . Such a potential characterizes the saddle point of the minimax objective. We show that the Wasserstein gradient flow decreases the potential at a sublinear rate, thus suggesting the convergence of the gradient flow to the saddle point. Moreover, varying allows a trade-off between the error of order in the optimality gap and the maximum deviation between and the Gaussian initialization for all . In the proof of item (ii) of Lemma 4.6, we proved that the deviation of from quantified by the Wasserstein distance is of order . Regarding representation learning, this suggests that SGDA induces a data-dependent representation that is significantly different from the initialization. Choosing a small of order will correspond to the mean-field regime (Mei et al., 2018, 2019) that allows to move further away from the initialization, with the potential drawback of yielding a large error of order . On the other hand, choosing a large of order will correspond to the NTK regime (Jacot et al., 2018), and this causes the Wasserstein flow to stay close to the initial distribution along the trajectory, inducing a data-independent representation.
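The effect of the scaling parameter can already be seen at initialization. The sketch below is an illustration under assumptions not taken from the text: the conventional choices of a scaling equal to 1 for the mean-field regime and equal to the square root of the width for the NTK regime, a scalar input, a toy neuron `tanh(b) * tanh(w * x)`, and standard Gaussian parameters. The names `init_output` and `rms_output` are hypothetical. Because the neuron outputs have zero mean, the network output has root-mean-square size proportional to `alpha / sqrt(N)`.

```python
import math
import random

def init_output(width, alpha, x, rng):
    """(alpha / N) * sum_i tanh(b_i) * tanh(w_i * x) with b_i, w_i ~ N(0, 1)."""
    total = 0.0
    for _ in range(width):
        b, w = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        total += math.tanh(b) * math.tanh(w * x)
    return alpha * total / width

def rms_output(width, alpha, trials=50, x=1.0, seed=0):
    """Root-mean-square network output over independent initializations."""
    rng = random.Random(seed)
    squares = [init_output(width, alpha, x, rng) ** 2 for _ in range(trials)]
    return math.sqrt(sum(squares) / trials)

for n in (64, 4096):
    # Columns: width, RMS output with alpha = 1, RMS output with alpha = sqrt(N).
    print(n, rms_output(n, 1.0), rms_output(n, math.sqrt(n)))
```

With a scaling of 1 the initial output shrinks as the width grows, so the parameters must travel a non-negligible distance to fit the data. With a scaling of order the square root of the width the output is already of constant size, and small parameter movements suffice, which matches the picture of a representation that stays close to its initialization.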
As noted earlier, an important class of regularizers is the regularizer. In this scenario, the left-hand side of (4.3) should be understood as a weighted distance between the gradient flow iterate at time and the optimal solution . As and go to infinity, this distance shrinks to , so the gradient flow converges globally to the optimal solution in the minimal distance sense. Motivated by this observation, in the sequel we discuss several additional results for the case where the regularizer is strongly convex, in the sense that it is bounded below by a quadratic function. We formalize this additional condition with the following assumption,
Assumption 4.8 (Strong Convexity).
We assume that the regularizer is -strongly convex, in the sense that there exists a constant such that for any ,
where is part of the data tuple .
Assumption 4.8 implies that the regularizer is equivalent to a quadratic regularizer because is simultaneously bounded above and below by quadratic functionals. In this case, we have the following strengthened version of Theorem 4.7,
(4.4) |
Equation (4.4) shows that the iterates converge to the saddle point solution in the sense that a weighted distance decays to zero at a sublinear rate up to an error of . Under Assumption 4.2, the saddle point is the global optimizer of the primal functional defined in (2.3). Therefore, as a direct consequence of Theorem 4.7, when the regularizer is strongly convex, converges globally to at a sublinear rate in terms of up to an error of .
Under Assumption 4.8, we can also quantify the optimality gap between and , in terms of the minimal distance . The following theorem characterizes the global convergence of to ,
Theorem 4.9 (Global Convergence to Primal Solution).
Proof.
See §A.3 for a detailed proof. ∎
Theorem 4.9 shows that, under the additional strong convexity assumption on the regularizer , the primal objective defined in (2.3) decays to zero at a rate of in terms of the time horizon , up to an error of . Here we use to denote the global minimizer instead of the saddle point. This does not create any confusion since, for each global minimizer of the primal objective (2.3), we can find such that is a saddle point of (2.9).
5 Applications
In this section, we apply Theorem 4.7 and Theorem 4.9 to several special cases of functional conditional moment equations: policy evaluation, nonparametric instrumental variables regression, asset pricing, and adversarial Riesz representer estimation. Section 2.2 already explains why these problems are special cases of functional conditional moment equations, so Theorem 4.7 and Theorem 4.9 apply once the technical assumptions are verified. For each case, we recall the problem setting and examine these assumptions.
5.1 Application 1: Policy Evaluation
Let denote the joint distribution of the state-action tuple under policy . In this scenario, the endogenous variable is the next state while the exogenous variable is the current state. Therefore, , and . We attempt to estimate the value function , which is defined on . The functional and regularizer adopted in this case are,
Here, the regularizer we adopt is a regularizer that penalizes the squared value of the estimator evaluated at the next state . With these specific choices of functional and regularizer , the SGDA algorithm coincides with the Gradient Temporal Difference Learning (GTD) algorithm (Wai et al., 2020). Therefore, the application of our general framework to the problem of policy evaluation contributes to the reinforcement learning literature by providing an analysis of the neural GTD algorithm in the mean-field regime. Before presenting the theoretical results, we first verify that Assumption 4.3 and Assumption 4.8 hold.
Verify item (i) of Assumption 4.3. For item (i) of Assumption 4.3, it is reasonable to assume that since we can always rescale the state space without changing the nature of the problem. The compactness assumption is therefore satisfied.
Verify item (ii) of Assumption 4.3. For item (ii) of Assumption 4.3, we first compute the variation of the functional and ,
Therefore, the desired integrability conditions hold since
(5.1) |
Verify item (iii) of Assumption 4.3. For item (iii) of Assumption 4.3, we choose , . The desired condition holds due to (5.1).
Verify item (iv) of Assumption 4.3. For item (iv) of Assumption 4.3, we first compute the variations of in explicit forms,
where , denote the densities of the marginal distributions of with respect to the current state and the next state respectively. By item (i) of Assumption 4.3, the variations of with respect to and are both continuous, since the densities of the conditional transitions and are both smooth and the functions are also continuous by construction. Therefore, item (iv) is satisfied.
Verify Assumption 4.8. For Assumption 4.8, we choose and . The desired condition holds by definition of our choice of regularizer .
We have checked that the technical Assumption 4.3 and Assumption 4.8 hold for the case of policy evaluation. Assumption 4.3 allows us to apply Theorem 4.7. This implies the global convergence of the estimated value function to the minimizer of the primal objective (2.3) applied in this case. The convergence is quantified in a weighted distance. Additionally, Assumption 4.8 enables us to apply Theorem 4.9 and further characterize such convergence using the optimality gap between the value of primal objectives. We summarize the conclusions in the following corollary.
Corollary 5.1 (Global Convergence of Mean-field Neural Nets in Policy Evaluation).
Proof.
Corollary 5.1 shows that in the setting of policy evaluation, the distance between the mean-field neural network at time and the global minimizer decays to zero at a sublinear rate, up to an error of order . Moreover, the optimality gap in terms of primal objective values decays to zero at the rate of , up to an error caused by overparameterization. Corollary 5.1 allows us to efficiently and globally solve the policy evaluation problem using overparameterized two-layer neural networks. We also remark that in such a scenario, the primal objective is known as the regularized mean-squared Bellman error (MSBE) in the reinforcement learning literature. As noted earlier, in the setting of policy evaluation, applying the SGDA algorithm within neural network function classes is equivalent to applying the neural GTD algorithm. Therefore, Corollary 5.1 states that, in the mean-field regime, the neural GTD algorithm converges globally to the minimizer at a sublinear rate up to an additional overparameterization error . The neural GTD algorithm also reduces the regularized MSBE at the rate of up to an additional overparameterization error . Moreover, the global convergence of mean-field neural networks also implies the global convergence of the discrete dynamics in (3) due to the proximity between the discrete dynamics and continuous dynamics, which is proved in Proposition 4.4.
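For readers less familiar with the mean-squared Bellman error, the following sketch computes it on a two-state Markov reward process. Everything here is an illustrative assumption: the chain `P`, rewards `r`, discount `gamma`, and state weights `state_dist` are made up, the regularizer is omitted, and the function names are hypothetical. The point is only that the conditional Bellman residual vanishes exactly at the true value function, so the unregularized MSBE is minimized there.

```python
def bellman_residual(V, P, r, gamma):
    """Conditional residual V(s) - r(s) - gamma * E[V(s') | s] for each state s."""
    n = len(V)
    return [V[s] - r[s] - gamma * sum(P[s][t] * V[t] for t in range(n))
            for s in range(n)]

def msbe(V, P, r, gamma, state_dist):
    """Unregularized mean-squared Bellman error under the weights state_dist."""
    residual = bellman_residual(V, P, r, gamma)
    return sum(d * e * e for d, e in zip(state_dist, residual))

def solve_two_state(P, r, gamma):
    """Exact value function of a two-state chain from (I - gamma * P) V = r."""
    a, b = 1.0 - gamma * P[0][0], -gamma * P[0][1]
    c, d = -gamma * P[1][0], 1.0 - gamma * P[1][1]
    det = a * d - b * c
    return [(d * r[0] - b * r[1]) / det, (a * r[1] - c * r[0]) / det]

# Hypothetical two-state example.
P = [[0.9, 0.1], [0.2, 0.8]]   # transition probabilities under the policy
r = [1.0, 0.0]                 # expected one-step rewards
gamma = 0.9                    # discount factor
state_dist = [0.5, 0.5]        # weights on the current state

V_true = solve_two_state(P, r, gamma)
error_at_truth = msbe(V_true, P, r, gamma, state_dist)
error_at_zero = msbe([0.0, 0.0], P, r, gamma, state_dist)
```

In this tabular example the conditional expectation inside the residual is available in closed form. With sampled transitions it is not, which is the usual motivation for introducing a dual function and solving a minimax problem instead.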
5.2 Application 2: Nonparametric Instrumental Variables Regression
Let denote the joint distribution of the endogenous variable , the exogenous variable , and the observed outcome . In this scenario, the endogenous variable is defined in space , the exogenous variable is defined in space , and . We attempt to estimate the model function , which is defined on . The functional and regularizer adopted in this case are,
Here, the regularizer we adopt is a regularizer that penalizes the squared value of the estimator of the model function evaluated at the endogenous variable . We examine Assumption 4.3 and Assumption 4.8 in order to apply results from Section 4.3.
Verify item (i) of Assumption 4.3. For item (i) of Assumption 4.3, the NPIV problem with a compact data space captures a large class of important applications, so the scenarios considered remain general under this assumption.
Verify item (ii) of Assumption 4.3. For item (ii) of Assumption 4.3, we first compute the variation of the functional and ,
Therefore, the desired integrability conditions hold since
(5.2) |
Verify item (iii) of Assumption 4.3. For item (iii) of Assumption 4.3, we choose , . The desired condition holds due to (5.2).
Verify item (iv) of Assumption 4.3. For item (iv) of Assumption 4.3, we first compute the variations of in explicit forms,
where , denote the densities of the marginal distributions of with respect to the endogenous variable and the exogenous variable respectively. By item (i) of Assumption 4.3, the variations of with respect to and are both continuous, since the densities of the conditional transitions and are both smooth and the functions are also continuous by construction. Therefore, item (iv) is satisfied.
Verify Assumption 4.8. For Assumption 4.8, we choose and . The desired condition holds by definition of our choice of regularizer .
We have checked that Assumption 4.3 and Assumption 4.8 hold for nonparametric instrumental variables regression. Since Assumption 4.3 holds, Theorem 4.7 applies, which implies the global convergence of the estimated model function to the minimizer of the primal objective, quantified in a weighted distance. The choice of a quadratic regularizer ensures Assumption 4.8, which further enables us to apply Theorem 4.9 and characterize the convergence in terms of the primal objective value. We summarize the conclusions in the following corollary.
Corollary 5.2 (Global Convergence of Mean-field Neural Nets in NPIV).
Proof.
Corollary 5.2 shows that in the setting of NPIV, the distance between the mean-field neural network at time and the global minimizer decays to zero at a sublinear rate, up to an error of order . Moreover, the optimality gap decays to zero at the rate of , up to an error . Corollary 5.2 allows us to solve the NPIV problem globally using overparameterized two-layer neural networks. We also remark that when the true model function is linear in the input, we recover the setting of instrumental variables regression as an important special instance of NPIV. Therefore, Corollary 5.2 also implies that IV regression can be solved globally and efficiently using overparameterized two-layer neural networks.
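The link between the minimax formulation and a primal objective can be made explicit in the NPIV case with a short derivation. The objective below is an assumed, standard adversarial formulation of the moment condition $\mathbb{E}[Y - f(X) \mid Z] = 0$ with a quadratic regularizer; the exact normalization and regularizer in (2.9) and (2.3) may differ. The symbols $\mathcal{L}$, $\lambda$, and $\delta_f$ are notation chosen here, and the dual function is assumed to range over all square-integrable functions of $Z$.

```latex
\begin{align*}
\mathcal{L}(f, g)
  &= \mathbb{E}\Big[ \big(f(X) - Y\big)\, g(Z) - \tfrac{1}{2}\, g(Z)^2 \Big]
     + \tfrac{\lambda}{2}\, \mathbb{E}\big[ f(X)^2 \big], \\
g_f^\ast(z)
  &= \mathbb{E}\big[ f(X) - Y \,\big|\, Z = z \big] =: \delta_f(z)
  \qquad \text{(pointwise maximizer of a concave quadratic in } g(z)\text{)}, \\
\max_{g}\, \mathcal{L}(f, g)
  &= \tfrac{1}{2}\, \mathbb{E}\big[ \delta_f(Z)^2 \big]
     + \tfrac{\lambda}{2}\, \mathbb{E}\big[ f(X)^2 \big].
\end{align*}
```

Under these assumptions the inner maximization recovers a regularized squared conditional-moment criterion, so the dual function plays the role of an estimate of the conditional residual $\delta_f$.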
5.3 Application 3: Asset Pricing
Let denote the joint distribution of the growth-return tuple . In this scenario, the exogenous variable is the consumption growth at the current time , and the endogenous variable is the consumption growth at the next time . Therefore, , where is the space of consumption growth and is also a compact subset of . Here, we consider the scenario where the modified return is also bounded for all , i.e., for some . We attempt to estimate the function , which is defined on . The functional and regularizer adopted in this case are,
Here, the regularizer we adopt is a regularizer that penalizes the squared value of the estimator evaluated at the consumption growth of the next time . Before presenting the theoretical results, we first verify that Assumption 4.3 and Assumption 4.8 hold.
Verify item (i) of Assumption 4.3. For item (i) of Assumption 4.3, since we assume that the space of consumption growth is a compact subset of , there exists such that for all , . Moreover, it is reasonable to assume that consumption growth is bounded, since the data often fluctuate within certain regimes in practice.
Verify item (ii) of Assumption 4.3. For item (ii) of Assumption 4.3, we first compute the variation of the functional and ,
Therefore, the desired integrability condition holds since,
(5.3) |
Verify item (iii) of Assumption 4.3. For item (iii) of Assumption 4.3, we choose , . The desired property holds due to (5.3).
Verify item (iv) of Assumption 4.3. For item (iv) of Assumption 4.3, we first compute the variations of in explicit forms,
where denote the densities of the marginal distributions of with respect to the current-time consumption growth and the next-time consumption growth respectively. The variations of with respect to and are both continuous, since the densities of the conditional transitions and are both smooth and the functions are also continuous by construction. Therefore, item (iv) is satisfied.
Verify Assumption 4.8. For Assumption 4.8, we choose and . The desired condition holds by definition of our choice of regularizer .
We have checked that Assumption 4.3 and Assumption 4.8 hold for asset pricing with the CCAPM model. Since Assumption 4.3 holds, Theorem 4.7 applies, which implies the global convergence of the estimated function to the minimizer of the primal objective, quantified in a weighted distance. Since Assumption 4.8 also holds, we can apply Theorem 4.9 and characterize the convergence in terms of the primal objective value. We summarize the conclusions in the following corollary.
Corollary 5.3 (Global Convergence of Mean-field Neural Nets in Asset Pricing).
Proof.
Corollary 5.3 shows that in the setting of asset pricing, the distance between the mean-field neural network at time and the global minimizer decays to zero at a sublinear rate, up to an error of order . Moreover, the optimality gap decays to zero at the rate of , up to an error . Corollary 5.3 allows us to solve the CCAPM model globally by estimating the nonparametric structural demand function with overparameterized two-layer neural networks. Since the return on investment is linked to the marginal utility of consumption through the CCAPM equation, we can price assets fairly by accounting for consumption risk and using the marginal utility information.
5.4 Application 4: Adversarial Riesz Representer Estimation
Let denote the joint distribution of the endogenous variable and the random vector . In this scenario, the exogenous variable coincides with the endogenous variable , therefore the problem is essentially unconditional. The endogenous variable is defined in space , the exogenous variable is defined on , and . We attempt to estimate the Riesz representer , which is defined on . The functional and regularizer adopted in this case are,
Here, the regularizer we adopt is a regularizer that penalizes the squared value of the estimator of the Riesz representer evaluated at the variable . We examine Assumption 4.3 and Assumption 4.8 in order to apply results from Section 4.3.
Verify item (i) of Assumption 4.3. For item (i) of Assumption 4.3, we restrict our attention to estimating Riesz representers of functionals defined on a compact space. In practice, this assumption is mild, since a data distribution on an unbounded space with exponentially decaying tails is often treated as a distribution defined on a compact space.
Verify item (ii) of Assumption 4.3. For item (ii) of Assumption 4.3, we first compute the variation of the functional and ,
Therefore, the desired integrability conditions hold since
(5.4) |
Verify item (iii) of Assumption 4.3. For item (iii) of Assumption 4.3, we choose , . The desired condition holds due to (5.4).
Verify item (iv) of Assumption 4.3. For item (iv) of Assumption 4.3, we first compute the variations of in explicit forms,
where , denote the densities of the marginal distributions of with respect to the endogenous variable and the exogenous variable respectively. By item (i) of Assumption 4.3, the variations of with respect to and are both continuous, since the densities of the conditional transitions and are both smooth and the functions are also continuous by construction. Therefore, item (iv) is satisfied.
Verify Assumption 4.8. For Assumption 4.8, we choose and . The desired condition holds by definition of our choice of regularizer .
We have checked that Assumption 4.3 and Assumption 4.8 hold for adversarial Riesz representer estimation. Since Assumption 4.3 holds, Theorem 4.7 applies, which implies the global convergence of the estimated Riesz representer to the minimizer of the primal objective, quantified in a weighted distance. The choice of a quadratic regularizer ensures Assumption 4.8, which further enables us to apply Theorem 4.9 and characterize the convergence in terms of the primal objective value. We summarize the conclusions in the following corollary.
Corollary 5.4 (Global Convergence of Mean-field Neural Nets in Adversarial Riesz Representer Estimation).
Corollary 5.4 proves that in the setting of adversarial Riesz representer estimation, the distance between the mean-field neural network at time and the global minimizer decays to zero at a sub-linear rate, up to an error of order . Moreover, the optimality gap decays to zero at the rate of , up to an error . Corollary 5.4 allows us to estimate the Riesz representer of a given functional using overparameterized two-layer neural networks.
6 Conclusion
In this paper, we study the minimax optimization problem derived from solving functional conditional moment equations using overparameterized two-layer neural networks. For this problem, we first prove that the stochastic gradient descent-ascent algorithm converges to a mean-field limit as the stepsize goes to zero and the network width goes to infinity. In this mean-field limit, the optimization dynamics are characterized by a Wasserstein gradient flow in the space of probability distributions. We further establish the global convergence of the Wasserstein gradient flow and prove that the feature representation induced by the neural networks is allowed to move by a considerable distance from its initial value. We apply our general results to policy evaluation with high-dimensional state space, nonparametric instrumental variables regression with high-dimensional endogenous and exogenous variables, asset pricing with a nonparametric structural demand function, and general Riesz representer estimation. Our analysis opens avenues for studying functional minimax optimization problems with more complicated objectives, such as nonlinear functional conditional moment equations, which include nonparametric quantile instrumental variables regression as a leading application. We leave the study of the convergence properties of the algorithm in such a general setting to future research.
References
- Ai and Chen (2003) Ai, C. and Chen, X. (2003). Efficient estimation of models with conditional moment restrictions containing unknown functions. Econometrica, 71 1795–1843.
- Alkousa et al. (2019) Alkousa, M., Dvinskikh, D., Stonyakin, F., Gasnikov, A. and Kovalev, D. (2019). Accelerated methods for composite non-bilinear saddle point problem. arXiv preprint arXiv:1906.03620.
- Allen-Zhu et al. (2019a) Allen-Zhu, Z., Li, Y. and Liang, Y. (2019a). Learning and generalization in overparameterized neural networks, going beyond two layers. Advances in neural information processing systems, 32.
- Allen-Zhu et al. (2019b) Allen-Zhu, Z., Li, Y. and Song, Z. (2019b). A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning. PMLR.
- Ambrosio and Gigli (2013) Ambrosio, L. and Gigli, N. (2013). A user’s guide to optimal transport. In Modelling and Optimisation of Flows on Networks. Springer, 1–155.
- Ambrosio et al. (2008) Ambrosio, L., Gigli, N. and Savaré, G. (2008). Gradient flows: In metric spaces and in the space of probability measures. Springer.
- Araújo et al. (2019) Araújo, D., Oliveira, R. I. and Yukimura, D. (2019). A mean-field limit for certain deep neural networks. arXiv preprint arXiv:1906.00193.
- Arora et al. (2019a) Arora, S., Du, S. S., Hu, W., Li, Z., Salakhutdinov, R. R. and Wang, R. (2019a). On exact computation with an infinitely wide neural net. In Advances in Neural Information Processing Systems.
- Arora et al. (2019b) Arora, S., Du, S. S., Hu, W., Li, Z. and Wang, R. (2019b). Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. arXiv preprint arXiv:1901.08584.
- Barron (1993) Barron, A. R. (1993). Universal approximation bounds for superpositions of a sigmoidal function. IEEE Transactions on Information Theory, 39 930–945.
- Ben-Tal et al. (2009) Ben-Tal, A., El Ghaoui, L. and Nemirovski, A. (2009). Robust optimization, vol. 28. Princeton university press.
- Bennett et al. (2019) Bennett, A., Kallus, N. and Schnabel, T. (2019). Deep generalized method of moments for instrumental variable analysis. Advances in neural information processing systems, 32.
- Blundell et al. (2007) Blundell, R., Chen, X. and Kristensen, D. (2007). Semi-nonparametric iv estimation of shape-invariant engel curves. Econometrica, 75 1613–1669.
- Cai et al. (2019) Cai, Q., Yang, Z., Lee, J. D. and Wang, Z. (2019). Neural temporal-difference learning converges to global optima. In Advances in Neural Information Processing Systems.
- Chen et al. (2024) Chen, L., Pelger, M. and Zhu, J. (2024). Deep learning in asset pricing. Management Science, 70 714–750.
- Chen et al. (2014) Chen, X., Chernozhukov, V., Lee, S. and Newey, W. K. (2014). Local identification of nonparametric and semiparametric models. Econometrica, 82 785–809.
- Chen and Christensen (2018) Chen, X. and Christensen, T. M. (2018). Optimal sup-norm rates and uniform inference on nonlinear functionals of nonparametric iv regression. Quantitative Economics, 9 39–84.
- Chen and Ludvigson (2009) Chen, X. and Ludvigson, S. C. (2009). Land of addicts? an empirical investigation of habit-based asset pricing models. Journal of Applied Econometrics, 24 1057–1093.
- Chen and Pouzo (2012) Chen, X. and Pouzo, D. (2012). Estimation of nonparametric conditional moment models with possibly nonsmooth generalized residuals. Econometrica, 80 277–321.
- Chen and Qi (2022) Chen, X. and Qi, Z. (2022). On well-posedness and minimax optimal rates of nonparametric q-function estimation in off-policy evaluation. In International Conference on Machine Learning. PMLR.
- Chen et al. (2020a) Chen, Z., Cao, Y., Gu, Q. and Zhang, T. (2020a). A generalized neural tangent kernel analysis for two-layer neural networks. Advances in Neural Information Processing Systems, 33 13363–13373.
- Chen et al. (2020b) Chen, Z., Cao, Y., Gu, Q. and Zhang, T. (2020b). Mean-field analysis of two-layer neural networks: Non-asymptotic rates and generalization bounds. arXiv preprint arXiv:2002.04026.
- Chen et al. (2019) Chen, Z., Cao, Y., Zou, D. and Gu, Q. (2019). How much over-parameterization is sufficient to learn deep relu networks? arXiv preprint arXiv:1911.12360.
- Chernozhukov et al. (2020) Chernozhukov, V., Newey, W., Singh, R. and Syrgkanis, V. (2020). Adversarial estimation of riesz representers. arXiv preprint arXiv:2101.00009.
- Chizat (2022) Chizat, L. (2022). Mean-field langevin dynamics: Exponential convergence and annealing. arXiv preprint arXiv:2202.01009.
- Chizat and Bach (2018) Chizat, L. and Bach, F. (2018). On the global convergence of gradient descent for over-parameterized models using optimal transport. In Advances in Neural Information Processing Systems.
- Diakonikolas et al. (2021) Diakonikolas, J., Daskalakis, C. and Jordan, M. I. (2021). Efficient methods for structured nonconvex-nonconcave min-max optimization. In International Conference on Artificial Intelligence and Statistics. PMLR.
- Dikkala et al. (2020) Dikkala, N., Lewis, G., Mackey, L. and Syrgkanis, V. (2020). Minimax estimation of conditional moment models. Advances in Neural Information Processing Systems, 33 12248–12262.
- Du et al. (2019) Du, S., Lee, J., Li, H., Wang, L. and Zhai, X. (2019). Gradient descent finds global minima of deep neural networks. In International conference on machine learning. PMLR.
- Du et al. (2018) Du, S. S., Zhai, X., Poczos, B. and Singh, A. (2018). Gradient descent provably optimizes over-parameterized neural networks. arXiv preprint arXiv:1810.02054.
- Duan et al. (2020) Duan, Y., Jia, Z. and Wang, M. (2020). Minimax-optimal off-policy evaluation with linear function approximation. In International Conference on Machine Learning. PMLR.
- Duan et al. (2021) Duan, Y., Jin, C. and Li, Z. (2021). Risk bounds and rademacher complexity in batch reinforcement learning. In International Conference on Machine Learning. PMLR.
- Fang et al. (2019) Fang, C., Dong, H. and Zhang, T. (2019). Over parameterized two-level neural networks can learn near optimal feature representations. arXiv preprint arXiv:1910.11508.
- Fang et al. (2021a) Fang, C., Dong, H. and Zhang, T. (2021a). Mathematical models of overparameterized neural networks. Proceedings of the IEEE, 109 683–703.
- Fang et al. (2021b) Fang, C., Lee, J., Yang, P. and Zhang, T. (2021b). Modeling from features: a mean-field framework for over-parameterized deep neural networks. In Conference on learning theory. PMLR.
- Frei and Gu (2021) Frei, S. and Gu, Q. (2021). Proxy convexity: A unified framework for the analysis of neural networks trained by gradient descent. Advances in Neural Information Processing Systems, 34 7937–7949.
- Ganin et al. (2016) Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M. and Lempitsky, V. (2016). Domain-adversarial training of neural networks. The journal of machine learning research, 17 2096–2030.
- Goodfellow et al. (2020) Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A. and Bengio, Y. (2020). Generative adversarial networks. Communications of the ACM, 63 139–144.
- Grimmer et al. (2022) Grimmer, B., Lu, H., Worah, P. and Mirrokni, V. (2022). Limiting behaviors of nonconvex-nonconcave minimax optimization via continuous-time systems. In International Conference on Algorithmic Learning Theory. PMLR.
- Grimmer et al. (2023) Grimmer, B., Lu, H., Worah, P. and Mirrokni, V. (2023). The landscape of the proximal point method for nonconvex–nonconcave minimax optimization. Mathematical Programming, 201 373–407.
- Hajizadeh et al. (2024) Hajizadeh, S., Lu, H. and Grimmer, B. (2024). On the linear convergence of extragradient methods for nonconvex–nonconcave minimax problems. INFORMS Journal on Optimization, 6 19–31.
- Han et al. (2024) Han, Y., Xie, G. and Zhang, Z. (2024). Lower complexity bounds of finite-sum optimization problems: The results and construction. Journal of Machine Learning Research, 25 1–86.
- Holte (2009) Holte, J. M. (2009). Discrete Gronwall lemma and applications. In MAA-NCS meeting at the University of North Dakota, vol. 24.
- Hu et al. (2021) Hu, K., Ren, Z., Šiška, D. and Szpruch, Ł. (2021). Mean-field langevin dynamics and energy landscape of neural networks. In Annales de l’Institut Henri Poincare (B) Probabilites et statistiques, vol. 57. Institut Henri Poincaré.
- Huang and Yau (2020) Huang, J. and Yau, H.-T. (2020). Dynamics of deep neural networks and neural tangent hierarchy. In International conference on machine learning. PMLR.
- Huang et al. (2022) Huang, M., Chen, X., Ji, K., Ma, S. and Lai, L. (2022). Efficiently escaping saddle points in bilevel optimization. arXiv preprint arXiv:2202.03684.
- Ibrahim et al. (2019) Ibrahim, A., Azizian, W., Gidel, G. and Mitliagkas, I. (2019). Lower bounds and conditioning of differentiable games. arXiv preprint arXiv:1906.07300.
- Jacot et al. (2018) Jacot, A., Gabriel, F. and Hongler, C. (2018). Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, vol. 31.
- Jin et al. (2019) Jin, C., Netrapalli, P. and Jordan, M. I. (2019). Minmax optimization: Stable limit points of gradient descent ascent are locally optimal. arXiv preprint arXiv:1902.00618.
- Jin et al. (2022) Jin, Y., Sidford, A. and Tian, K. (2022). Sharper rates for separable minimax and finite sum optimization via primal-dual extragradient methods. In Conference on Learning Theory. PMLR.
- Jin et al. (2021) Jin, Y., Yang, Z. and Wang, Z. (2021). Is pessimism provably efficient for offline rl? In International Conference on Machine Learning. PMLR.
- Levy et al. (2020) Levy, D., Carmon, Y., Duchi, J. C. and Sidford, A. (2020). Large-scale methods for distributionally robust optimization. Advances in Neural Information Processing Systems, 33 8847–8860.
- Li et al. (2023) Li, C. J., Yuan, H., Gidel, G., Gu, Q. and Jordan, M. (2023). Nesterov meets optimism: rate-optimal separable minimax optimization. In International Conference on Machine Learning. PMLR.
- Li et al. (2022) Li, J., Zhu, L. and So, A. M.-C. (2022). Nonsmooth nonconvex-nonconcave minimax optimization: Primal-dual balancing and iteration complexity analysis. arXiv preprint arXiv:2209.10825.
- Liao et al. (2020) Liao, L., Chen, Y.-L., Yang, Z., Dai, B., Kolar, M. and Wang, Z. (2020). Provably efficient neural estimation of structural equation models: An adversarial approach. Advances in Neural Information Processing Systems, 33 8947–8958.
- Lin et al. (2020a) Lin, T., Jin, C. and Jordan, M. (2020a). On gradient descent ascent for nonconvex-concave minimax problems. In International Conference on Machine Learning. PMLR.
- Lin et al. (2020b) Lin, T., Jin, C. and Jordan, M. I. (2020b). Near-optimal algorithms for minimax optimization. In Conference on Learning Theory. PMLR.
- Lu et al. (2020a) Lu, S., Tsaknakis, I., Hong, M. and Chen, Y. (2020a). Hybrid block successive approximation for one-sided non-convex min-max problems: algorithms and applications. IEEE Transactions on Signal Processing, 68 3676–3691.
- Lu et al. (2020b) Lu, Y., Ma, C., Lu, Y., Lu, J. and Ying, L. (2020b). A mean-field analysis of deep resnet and beyond: Towards provable optimization via overparameterization from depth.
- Luo et al. (2021) Luo, L., Xie, G., Zhang, T. and Zhang, Z. (2021). Near optimal stochastic algorithms for finite-sum unbalanced convex-concave minimax optimization. arXiv preprint arXiv:2106.01761.
- Luo et al. (2020) Luo, L., Ye, H., Huang, Z. and Zhang, T. (2020). Stochastic recursive gradient descent ascent for stochastic nonconvex-strongly-concave minimax problems. Advances in Neural Information Processing Systems, 33 20566–20577.
- Madry et al. (2017) Madry, A., Makelov, A., Schmidt, L., Tsipras, D. and Vladu, A. (2017). Towards deep learning models resistant to adversarial attacks. arXiv preprint arXiv:1706.06083.
- Mei et al. (2019) Mei, S., Misiakiewicz, T. and Montanari, A. (2019). Mean-field theory of two-layers neural networks: Dimension-free bounds and kernel limit. arXiv preprint arXiv:1902.06015.
- Mei et al. (2018) Mei, S., Montanari, A. and Nguyen, P.-M. (2018). A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115 E7665–E7671.
- Nitanda et al. (2022) Nitanda, A., Wu, D. and Suzuki, T. (2022). Convex analysis of the mean field langevin dynamics. In International Conference on Artificial Intelligence and Statistics. PMLR.
- Nouiehed et al. (2019) Nouiehed, M., Sanjabi, M., Huang, T., Lee, J. D. and Razaviyayn, M. (2019). Solving a class of non-convex min-max games using iterative first order methods. Advances in Neural Information Processing Systems, 32.
- Ostrovskii et al. (2021a) Ostrovskii, D. M., Barazandeh, B. and Razaviyayn, M. (2021a). Nonconvex-nonconcave min-max optimization with a small maximization domain. arXiv preprint arXiv:2110.03950.
- Ostrovskii et al. (2021b) Ostrovskii, D. M., Lowy, A. and Razaviyayn, M. (2021b). Efficient search of first-order nash equilibria in nonconvex-concave smooth min-max problems. SIAM Journal on Optimization, 31 2508–2538.
- Otto and Villani (2000) Otto, F. and Villani, C. (2000). Generalization of an inequality by Talagrand and links with the logarithmic Sobolev inequality. Journal of Functional Analysis, 173 361–400.
- Ouyang and Xu (2021) Ouyang, Y. and Xu, Y. (2021). Lower complexity bounds of first-order methods for convex-concave bilinear saddle-point problems. Mathematical Programming, 185 1–35.
- Pinkus (1999) Pinkus, A. (1999). Approximation theory of the MLP model in neural networks. Acta Numerica, 8 143–195.
- Ramprasad et al. (2022) Ramprasad, P., Li, Y., Yang, Z., Wang, Z., Sun, W. W. and Cheng, G. (2022). Online bootstrap inference for policy evaluation in reinforcement learning. Journal of the American Statistical Association 1–14.
- Salimans et al. (2016) Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A. and Chen, X. (2016). Improved techniques for training gans. Advances in neural information processing systems, 29.
- Sirignano and Spiliopoulos (2020a) Sirignano, J. and Spiliopoulos, K. (2020a). Mean field analysis of neural networks: A central limit theorem. Stochastic Processes and their Applications, 130 1820–1852.
- Sirignano and Spiliopoulos (2020b) Sirignano, J. and Spiliopoulos, K. (2020b). Mean field analysis of neural networks: A law of large numbers. SIAM Journal on Applied Mathematics, 80 725–752.
- Sirignano and Spiliopoulos (2022) Sirignano, J. and Spiliopoulos, K. (2022). Mean field analysis of deep neural networks. Mathematics of Operations Research, 47 120–152.
- Sutton and Barto (2018) Sutton, R. S. and Barto, A. G. (2018). Reinforcement learning: An introduction. MIT press.
- Sznitman (1991) Sznitman, A.-S. (1991). Topics in propagation of chaos. In Ecole d’Été de Probabilités de Saint-Flour XIX—1989. Springer, 165–251.
- Thekumparampil et al. (2019) Thekumparampil, K. K., Jain, P., Netrapalli, P. and Oh, S. (2019). Efficient algorithms for smooth minimax optimization. Advances in Neural Information Processing Systems, 32.
- Uehara et al. (2020) Uehara, M., Huang, J. and Jiang, N. (2020). Minimax weight and q-function learning for off-policy evaluation. In International Conference on Machine Learning. PMLR.
- Villani (2003) Villani, C. (2003). Topics in optimal transportation. American Mathematical Society.
- Villani (2008) Villani, C. (2008). Optimal transport: Old and new. Springer.
- Wai et al. (2020) Wai, H.-T., Yang, Z., Wang, Z. and Hong, M. (2020). Provably efficient neural GTD for off-policy learning. Advances in Neural Information Processing Systems, 33.
- Wainwright (2019) Wainwright, M. J. (2019). High-dimensional statistics: A non-asymptotic viewpoint. Cambridge University Press.
- Wang et al. (2022) Wang, S., Yu, X. and Perdikaris, P. (2022). When and why pinns fail to train: A neural tangent kernel perspective. Journal of Computational Physics, 449 110768.
- Xie et al. (2020a) Xie, G., Luo, L., Lian, Y. and Zhang, Z. (2020a). Lower complexity bounds for finite-sum convex-concave minimax optimization problems. In International Conference on Machine Learning. PMLR.
- Xie et al. (2020b) Xie, Q., Chen, Y., Wang, Z. and Yang, Z. (2020b). Learning zero-sum simultaneous-move markov games using function approximation and correlated equilibrium. In Conference on learning theory. PMLR.
- Xu et al. (2020) Xu, L., Chen, Y., Srinivasan, S., de Freitas, N., Doucet, A. and Gretton, A. (2020). Learning deep features in instrumental variable regression. arXiv preprint arXiv:2010.07154.
- Xu et al. (2021) Xu, L., Kanagawa, H. and Gretton, A. (2021). Deep proxy causal learning and its application to confounded bandit policy evaluation. Advances in Neural Information Processing Systems, 34 26264–26275.
- Xu and Gu (2020) Xu, P. and Gu, Q. (2020). A finite-time analysis of q-learning with neural network function approximation. In International Conference on Machine Learning. PMLR.
- Yang et al. (2020) Yang, J., Kiyavash, N. and He, N. (2020). Global convergence and variance reduction for a class of nonconvex-nonconcave minimax problems. Advances in Neural Information Processing Systems, 33 1153–1165.
- Yang et al. (2022) Yang, J., Orvieto, A., Lucchi, A. and He, N. (2022). Faster single-loop algorithms for minimax optimization without strong concavity. In International Conference on Artificial Intelligence and Statistics. PMLR.
- Zhang et al. (2021a) Zhang, S., Yang, J., Guzmán, C., Kiyavash, N. and He, N. (2021a). The complexity of nonconvex-strongly-concave minimax optimization. In Uncertainty in Artificial Intelligence. PMLR.
- Zhang et al. (2020) Zhang, Y., Cai, Q., Yang, Z., Chen, Y. and Wang, Z. (2020). Can temporal-difference and q-learning learn representation? A mean-field theory. arXiv preprint arXiv:2006.04761.
- Zhang et al. (2021b) Zhang, Y., Chen, S., Yang, Z., Jordan, M. and Wang, Z. (2021b). Wasserstein flow meets replicator dynamics: A mean-field analysis of representation learning in actor-critic. Advances in Neural Information Processing Systems, 34 15993–16006.
- Zhao (2023) Zhao, R. (2023). A primal-dual smoothing framework for max-structured non-convex optimization. Mathematics of operations research.
- Zhao et al. (2022) Zhao, Y., Tian, Y., Lee, J. and Du, S. (2022). Provably efficient policy optimization for two-player zero-sum markov games. In International Conference on Artificial Intelligence and Statistics. PMLR.
- Zou and Gu (2019) Zou, D. and Gu, Q. (2019). An improved analysis of training over-parameterized deep neural networks. In Advances in Neural Information Processing Systems.
Appendix A Proof of Main Results
In this section, we provide proofs for the main theorems and technical lemmas in our work.
A.1 Proof of Lemma 4.6
Proof of (i). The proof of Claim (i) proceeds in two stages. First, we show that if a function pair is a stationary point for with respect to , then it is also a saddle point for the same objective. Then we show that the distribution pair being a stationary point of implies that the corresponding is a stationary point for , which concludes the claim. We start with the first stage. We define the following functionals and ,
We see that the minimax objective in (2.9) is indeed the sum of such two functionals,
For any function pair , we can verify that the following chain of equalities holds,
(A.1) |
We consider the function spaces and equipped with the inner product , which are Hilbert spaces. Since are compact, continuous functions and parameterized in the form of (3.5) are square-integrable and thus naturally belong to and .
For a fixed , is a continuous linear functional in defined on . Thus, there exists a function in such that . Similarly, for a fixed , is a continuous linear functional in , thus there exists a function in such that . In fact, and match the variation of with respect to and .
Since is a concave functional with respect to , we apply Jensen’s inequality and it holds that,
(A.2) |
Following similar reasoning, and using the fact that is a linear functional with respect to and is a convex functional with respect to , it holds that
(A.3) |
Plugging (A.2) and (A.3) into (A.1), we rewrite the minimax expression in (A.1) using the variation of and obtain the following inequality,
(A.4) |
Thus, if is the stationary point, i.e.,
(A.5) |
then (A.4) suggests that for such stationary point , for any function pair , the following inequality holds,
(A.6) |
Equation (A.6) proves that is a saddle point for the minimax objective . Therefore, the stationarity of implies that it is a saddle point for the objective .
We now proceed to the second stage of the proof. We show that if is a stationary point of , i.e., , then the corresponding function pair is a stationary point of with respect to . Recall that the correspondence between and is through (3.5). Let be a stationary point of (2.9), that is
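The first stage of the argument (stationarity implies the saddle-point property for a convex-concave objective) can be checked numerically on a scalar toy problem. The sketch below is an illustration only: the quadratic `objective` and the coefficients `A`, `B`, `C` are hypothetical stand-ins for the paper's functional, chosen so that the objective is convex in its first argument and concave in its second.

```python
# Toy convex-concave objective (hypothetical, not the paper's functional):
#   L(f, g) = 0.5*A*f**2 + B*f*g - 0.5*C*g**2,  with A, C > 0.
A, B, C = 2.0, 1.5, 3.0


def objective(f, g):
    return 0.5 * A * f * f + B * f * g - 0.5 * C * g * g


def gradient(f, g):
    """Partial derivatives (dL/df, dL/dg) of the toy objective."""
    return A * f + B * g, B * f - C * g


def is_saddle(f_star, g_star, grid, tol=1e-12):
    """Check L(f*, g) <= L(f*, g*) <= L(f, g*) for all f, g on the grid."""
    value = objective(f_star, g_star)
    max_over_g = max(objective(f_star, g) for g in grid)
    min_over_f = min(objective(f, g_star) for f in grid)
    return max_over_g <= value + tol and value <= min_over_f + tol


# The unique stationary point of this quadratic is the origin.
f_star, g_star = 0.0, 0.0
grid = [i / 10.0 for i in range(-50, 51)]
stationary = gradient(f_star, g_star) == (0.0, 0.0)
saddle = is_saddle(f_star, g_star, grid)
```

Here both `stationary` and `saddle` evaluate to `True`, whereas a non-stationary point such as `(1.0, 0.0)` fails the saddle check on the same grid.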
(A.7) |
We can also compute the variation of explicitly.
By the oddness of in Assumption 4.1, we have that . This implies that the variations of with respect to and are when , i.e.,
Combined with (A.7), we deduce that
Note that we can expand the variation of with respect to ,
(A.8) |
By the universal function approximation theorem (Lemma D.1), since is in as is assumed in item (iv) of Assumption 4.3, there exists such that uniformly. Here, denotes the space of functions that are linearly spanned by . By (A.8), it holds that
(A.9) |
Following a similar strategy, we can show that there exists such that , where for each , it holds that
(A.10) |
We take the limit of (A.9) and (A.10) by passing and conclude,
(A.11) |
Equation (A.11) proves that if is a stationary point of the Wasserstein gradient flow, then the associated function pair is a stationary point of the minimax objective , which matches the conditions we conclude in (A.5). Therefore, we prove that is a saddle point of the minimax objective . We complete the proof of item (i).
Proof of (ii). We now show that there exists a good solution pair that is both optimal and close to the initialization in Wasserstein distance. By Assumption 4.2, there exists a distribution such that the optimal solution to the optimization problem (2.9) satisfies the following,
Recall that is the scaling parameter in neural network parameterization. We can construct using a convex combination of and the initialization ,
(A.12) |
We claim that constructed in (A.12) satisfies all the desired requirements. Since are standard Gaussian distributions, the integrals of with respect to and of with respect to are identically due to the oddness of the neuron functions,
Thus, the expressions for are simplified to
By Talagrand’s inequality (Lemma D.5), the following chain of inequalities holds,
(A.13) |
A similar bound on also applies,
(A.14) |
Letting , we conclude the proof of item (ii).
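As a sanity check on the type of bound used above, Talagrand's inequality can be verified in closed form for one-dimensional Gaussians. The sketch below assumes the standard form of the inequality, namely that the squared 2-Wasserstein distance to a standard Gaussian is at most twice the KL divergence; the exact statement of Lemma D.5 may be phrased differently. The function names are illustrative.

```python
import math


def w2_squared_to_std_normal(mean, std):
    """Squared 2-Wasserstein distance between N(mean, std^2) and N(0, 1)."""
    return mean ** 2 + (std - 1.0) ** 2


def kl_to_std_normal(mean, std):
    """KL divergence KL( N(mean, std^2) || N(0, 1) )."""
    return 0.5 * (std ** 2 + mean ** 2 - 1.0) - math.log(std)


def talagrand_holds(mean, std, tol=1e-12):
    """Check W2^2 <= 2 * KL (assumed form of Talagrand's inequality)."""
    return w2_squared_to_std_normal(mean, std) <= 2.0 * kl_to_std_normal(mean, std) + tol


# Scan a grid of Gaussians; the inequality holds on all of them.
checks = [talagrand_holds(m / 2.0, s / 4.0)
          for m in range(-6, 7) for s in range(1, 13)]
```

For Gaussians the slack is 2(s - 1 - log s) with s the standard deviation, so the inequality is tight exactly for pure translations (s = 1).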
A.2 Proof of Theorem 4.7
By Lemma 4.6, there exists a distribution that is a stationary point of the Wasserstein gradient flow (3.7) and simultaneously satisfies the distance bound in item (ii) of Lemma 4.6. For such , we denote as their product measure. Moreover, for any distribution pair , we use as their product measure for simplicity. To rewrite the Wasserstein gradient flow for into the flow for , we define the stacked vector field as,
(A.15) |
Following from Lemma D.2, (A.1), and (A.14), it holds that , where is defined in Lemma 4.6. Note that
Thus, we overload the notation to write and for . By writing , the update in (3.7) takes the following form
Before we prove Theorem 4.7, we first show the following important technical lemma.
Lemma A.1.
Proof.
Let be the geodesic connecting and with and . Let be the corresponding velocity field such that . By the first variation formula of Wasserstein distance in Lemma D.3, it holds that
(A.17) | ||||
where the notation for any distribution and functions . We will provide bounds for term (i) and (ii) separately in the sequel.
Upper bounding term (i). For term (i) of (A.17), by the definitions of , , and in (A.15) and (3), we have that
where the second inequality holds since a constant, independent function, is linear in , and satisfies
A similar computation for gives
We recall that is the linear component in . We note that the variation of is the same as the variation of with respect to ,
We define the potential as
Then, the vector field is the gradient of such potential
where the gradient operator . Then, by Stokes' formula and integration by parts, we have
Integrating the potential with respect to simplifies the expression to
(A.18) |
By convexity of and for all , it holds that
(A.19) |
Integrating (A.2) with respect to , we have that
(A.20) |
where the first inequality holds due to (A.19), and the second holds by Jensen’s inequality.
Upper bounding term (ii). By Lemma D.6, for term (ii) in (A.17), it holds that
(A.21) |
where denotes the Frobenius norm. Since is the velocity field corresponding to the geodesic connecting , by assumptions, it holds that
(A.22) |
On the other hand, by the definition of in (A.15), we have that
(A.23) |
By the definition of in (3), we have that
(A.24) |
where the first inequality follows from Assumption 4.1, and the second inequality comes from the integrability conditions in Assumption 4.3. Thus, it suffices to upper bound and for all . For , we have that
(A.25) |
Moreover, it holds that
(A.26) |
where the second inequality follows from the fact that is the geodesic connecting and and the last inequality follows from (ii) in Lemma 4.6. Plugging (A.26) into (A.2), we have that
(A.27) |
Through a similar argument, such an upper bound can also be established for for all ,
(A.28) |
Plugging (A.27) and (A.28) into (A.2), we establish an upper bound for ,
(A.29) |
Similarly, by the definition of in (3) we have that
(A.30) |
Combining the bounds from (A.29) and (A.2) and plugging them into (A.23), it holds that
(A.31) |
Equations (A.22) and (A.31) provide upper bounds on the two terms involved in (A.2). Plugging in these upper bounds, it holds that
(A.32) |
Now combining (A.2) and (A.32), we have that
where is a constant. This completes the proof of Lemma A.1. ∎
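The proof above relies on a constant-speed geodesic between two distributions in Wasserstein space. In one dimension this geodesic can be written down explicitly for empirical measures, which gives a quick way to see the linear-in-time distance property. The sketch below is a one-dimensional illustration under that simplification; the particle clouds `start` and `target` are hypothetical and unrelated to the paper's parameter distributions.

```python
import math
import random


def w2_empirical_1d(xs, ys):
    """W2 between two equal-size 1D empirical measures (sorted coupling is optimal)."""
    xs, ys = sorted(xs), sorted(ys)
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs))


def displacement_interpolation(xs, ys, s):
    """Particles of the geodesic at time s in [0, 1], from xs (s=0) to ys (s=1)."""
    xs, ys = sorted(xs), sorted(ys)
    return [(1.0 - s) * x + s * y for x, y in zip(xs, ys)]


rng = random.Random(0)
start = [rng.gauss(0.0, 1.0) for _ in range(200)]
target = [rng.gauss(1.0, 0.5) for _ in range(200)]
total = w2_empirical_1d(start, target)

# Along the geodesic, the distance from the start grows linearly in s.
distances = {
    s: w2_empirical_1d(start, displacement_interpolation(start, target, s))
    for s in (0.0, 0.25, 0.5, 1.0)
}
```

Each entry of `distances` equals `s * total` up to floating-point error, which is the property that lets a bound on the endpoint distance control the distance at every intermediate time.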
Proof.
We define
(A.33) |
Also, we define
(A.34) |
In other words, (A.16) of Lemma A.1 holds for , and for , we have
We now show that by contradiction. By the continuity of with respect to (Ambrosio et al., 2008), since , it holds that . Assume that ; then . Thus, by (A.16), (A.33), (A.34), it holds that for that
which further implies that . This contradicts the definition of in (A.34). Thus, it holds that , which implies that (A.16) of Lemma A.1 holds for any . We now discuss two different situations.
Scenario (ii) If , then (A.16) in Lemma A.1 holds for . Re-arranging the terms, we have the following inequality for all ,
(A.36) |
This further suggests the following upper bound,
(A.37) |
where the second inequality comes from integrating (A.36) in for , the third inequality comes from (ii) in Lemma 4.6, and the last equality comes from setting to . Therefore, (A.2) implies Theorem 4.7 in this scenario.
Based on the discussion of scenarios (i) and (ii) above, we finish the proof of Theorem 4.7. ∎
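The step of integrating a differential inequality over a time interval to obtain a sublinear rate can be illustrated on a scalar toy model. The sketch below assumes a hypothetical inequality of the form D'(t) = -g(t) + eps with the toy choice g(t) = D(t); it is not the paper's inequality (A.36), whose constants are not reproduced here. Telescoping gives a time-averaged gap of at most D(0)/T + eps.

```python
def integrate_gap(d0, eps, horizon, dt=1e-3):
    """Euler-integrate D'(t) = -g(t) + eps with the toy choice g(t) = D(t).

    Returns (min gap over [0, horizon], time-averaged gap, final D).
    """
    steps = int(round(horizon / dt))
    d, gap_sum, gap_min = d0, 0.0, float("inf")
    for _ in range(steps):
        gap = d                      # toy optimality gap at the current time
        gap_sum += gap * dt
        gap_min = min(gap_min, gap)
        d += dt * (-gap + eps)       # discretized D'(t) = -g(t) + eps
    return gap_min, gap_sum / (steps * dt), d


# The best gap over [0, T] is bounded by D(0)/T + eps for each horizon T.
d0, eps = 4.0, 0.01
results = {T: integrate_gap(d0, eps, T) for T in (1.0, 10.0, 100.0)}
```

The bound D(0)/T + eps decays like 1/T up to the additive error eps, which mirrors the structure of a sublinear rate plus an error term.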
A.3 Proof of Theorem 4.9
Proof.
We now prove Theorem 4.9. For notational simplicity, we denote as the estimator at time . Recall the definition of from (2.3) and from (2.2).
Plugging the definition of , it holds that
(A.38) |
Similar to the proof of Theorem 4.7, we define as,
We will upper-bound the term in (A.3) separately in two different scenarios, depending on the value of compared with .
Scenario (i) If , then we have that
(A.39) |
In order to upper-bound the right-hand side of (A.39), we need to uniformly upper-bound and for all . For , we have that
(A.40) |
where the first inequality follows from Lemma D.7 and the second inequality follows from Lemma D.2. The last inequality follows from (ii) in Lemma 4.6 and the definition of . For , a similar chain of inequalities applies,
(A.41) |
With uniform bounds on and , we are now ready to upper-bound through upper-bounding ,
(A.42) |
where is a constant whose value changes from line to line. The second inequality follows from (A.3) and (A.3). The last inequality follows from (A.2) in the proof of Theorem 4.7. Therefore, in this scenario, we have that
(A.43) |
Equation (A.43) concludes the proof of Theorem 4.9 in the scenario of .
Scenario (ii) If , by definition of , we have that
Following the same arguments as in (A.3) and (A.3), we have a uniform upper bound for for all and that reads,
Following the same derivation of (A.3), we have that
(A.44) |
where the last inequality follows from (A.36) and (A.2) in the proof of Theorem 4.7. Equation (A.3) concludes the proof of Theorem 4.9 in the scenario of .
Based on the discussion of scenarios (i) and (ii) above, we finish the proof of Theorem 4.9. ∎
Appendix B Mean Field Limit of Neural Networks
In this section, we prove Proposition 4.4. The formal version is presented as follows. Let , where is the PDE solution in (3.7) and is the empirical distribution of . Here we omit the dependence of the empirical distribution on and stepsize scale for notational simplicity.
Proposition B.1 (Formal Version of Proposition 4.4).
The proof of Proposition B.1 is based heavily on Mei et al. (2018, 2019); Araújo et al. (2019); Zhang et al. (2020), which make use of the propagation of chaos arguments in Sznitman (1991). Recall that is a vector field defined as,
(B.1) |
From now on, we equivalently write , to emphasize the dependence on iterations. For brevity, we denote and . We recall that the finite-width representations of and are,
Correspondingly, we define the finite-width counterparts of and as follows,
(B.2) |
We also define the stochastic counterparts,
(B.3) |
where . Following from Mei et al. (2019); Araújo et al. (2019), we consider the following four dynamics.
- Stochastic Gradient Descent Ascent (SGDA). We consider the following SGDA dynamics for and , where , with as its initialization,
(B.4)
Note that this dynamics is equivalent to (3).
- Population Gradient Descent Ascent (PGDA). We consider the following population gradient descent ascent dynamics for and , where , with , as its initialization,
(B.5)
- Continuous-time Population Gradient Descent Ascent (CTPGDA). We consider the following continuous time population gradient descent ascent dynamics for and , where , with , as initialization,
(B.6)
- Ideal particle (IP). We consider the following ideal particle dynamics for and , where , with , as initialization,
(B.7)
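The relationship between the first two dynamics in the list above (stochastic versus population gradient descent-ascent) can be sketched on a two-dimensional toy problem. The example below is a minimal illustration, not the paper's neural network dynamics: the objective L(x, y) = 0.5x² + xy - 0.5y², the noise level, and the helper names `pgda`, `sgda`, and `sup_gap` are all assumptions made for the sketch. It measures the supremum distance between the two trajectories over a fixed time horizon.

```python
import math
import random


def grad(x, y):
    """Gradient of the toy objective L(x, y) = 0.5*x**2 + x*y - 0.5*y**2."""
    return x + y, x - y


def pgda(x, y, step, n_steps):
    """Population GDA: exact gradients, descent in x and ascent in y."""
    path = [(x, y)]
    for _ in range(n_steps):
        gx, gy = grad(x, y)
        x, y = x - step * gx, y + step * gy
        path.append((x, y))
    return path


def sgda(x, y, step, n_steps, noise_std, rng):
    """Stochastic GDA: the same update with noisy, unbiased gradients."""
    path = [(x, y)]
    for _ in range(n_steps):
        gx, gy = grad(x, y)
        gx += rng.gauss(0.0, noise_std)
        gy += rng.gauss(0.0, noise_std)
        x, y = x - step * gx, y + step * gy
        path.append((x, y))
    return path


def sup_gap(path_a, path_b):
    """Supremum over iterations of the Euclidean distance between two paths."""
    return max(math.hypot(a[0] - b[0], a[1] - b[1]) for a, b in zip(path_a, path_b))


horizon, step = 5.0, 1e-3
n_steps = int(round(horizon / step))
exact = pgda(1.0, -1.0, step, n_steps)
noisy = sgda(1.0, -1.0, step, n_steps, noise_std=1.0, rng=random.Random(0))
gap = sup_gap(exact, noisy)
```

With a small stepsize the accumulated gradient noise averages out, so `gap` stays small relative to the initial distance from the saddle point at the origin. The continuous-time and ideal-particle dynamics are not simulated in this sketch.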
We aim to prove that weakly converges to . For any continuous function that satisfies the assumptions of Proposition B.1, using the IP, CTPGDA, and PGDA dynamics as interpolating dynamics, we have,
(B.8) |
The last inequality follows from the fact that . Here the norm denotes the supremum norm over the sequence of vectors ,
(B.9) |
In what follows, we define as a constant with its value varying from line to line. We establish the following lemmas as upper bounds on the four terms on the right-hand side of (B).
Lemma B.2 (Upper Bound of ).
Lemma B.3 (Upper Bound of ).
Lemma B.5 (Upper Bound of ).
With these lemmas, we are now ready to present the proof of Proposition B.1.
Proof.
B.1 Proofs of Lemmas B.2-B.5
In this section, we present the proofs of Lemmas B.2-B.5, which are based heavily on Mei et al. (2018, 2019); Araújo et al. (2019); Zhang et al. (2020). The required supporting technical lemmas are in §C. For notational simplicity, the constant presented in the proof is a positive constant whose value varies from line to line.
B.1.1 Proof of Lemma B.2
Proof.
We first consider the ideal particle dynamics in (B.7). It holds that (Proposition 8.1.8 in Ambrosio et al. (2008)). Since the randomness of and comes from and respectively while and are independent, , . Due to independence of and , we also have . This implies the following,
For notational simplicity, we denote ; similar notation applies to . Let and be two sets of variables that only differ in the -th element. Then, by the assumption that , we have the following bounded difference property,
Applying McDiarmid’s inequality (Wainwright, 2019), we have for a fixed that
(B.14)
Moreover, we have for any that,
where the second inequality follows from the fact that and Lemma D.7. The last inequality follows from the definition of , (B.9) and Lemma D.2. Applying (C.12), (C.14) of Lemma C.2, we have for any that
Applying the union bound to (B.14) for , we have that
Setting and , we have that
with probability at least . Thus, we complete the proof of Lemma B.2. ∎
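The concentration step in this proof can be illustrated numerically. The sketch below is not the paper's setting: it assumes an empirical mean of n independent terms bounded in [-bound, bound], for which changing one term moves the mean by at most 2*bound/n, so McDiarmid's inequality gives P(|deviation| >= eps) <= 2 exp(-n eps^2 / (2 bound^2)). The test function tanh, the sample sizes, and the helper names are hypothetical choices.

```python
import math
import random


def mcdiarmid_radius(n, bound, delta):
    """Deviation eps such that the empirical mean of n independent terms in
    [-bound, bound] is within eps of its expectation with probability >= 1 - delta.
    Obtained by solving 2 * exp(-n * eps**2 / (2 * bound**2)) = delta for eps."""
    return bound * math.sqrt(2.0 * math.log(2.0 / delta) / n)


def empirical_deviation(n, rng):
    """|mean of tanh(g_i)| for g_i ~ N(0, 1); the true mean is 0 by symmetry."""
    return abs(sum(math.tanh(rng.gauss(0.0, 1.0)) for _ in range(n)) / n)


rng = random.Random(0)
n, delta, trials = 1000, 0.01, 200
eps = mcdiarmid_radius(n, bound=1.0, delta=delta)
violations = sum(empirical_deviation(n, rng) > eps for _ in range(trials))
print("radius:", eps, "violations out of", trials, "trials:", violations)
```

The radius scales as 1/sqrt(n), which is the source of the width-dependent rate once the union bound over time points is applied.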
B.1.2 Proof of Lemma B.3
Following from the definition of , and , in (B.6) and (B.7), we have for any and that
(B.15)
where the last inequality follows from (C.8) of Lemma C.1. Similarly, we have that
(B.16)
where the inequality follows from (C.9). We now upper-bound the second term of (B.15) and (B.16). We start with (B.15). Following from the definition of and in (B) and (B), we have for any and that
(B.17)
where is given by,
Following from Assumptions 4.1 and 4.3, we have that . When , since , it holds that . Following from Lemma C.3, we have for fixed and that
(B.18)
From Lemma D.7 and (C.14) of Lemma C.2, we have that
Following from Assumptions 4.1 and 4.3 and Lemma C.2, we have for any that,
Applying the union bound to (B.18) for and , we have that
Setting and , we have that
(B.19)
with probability at least . Following from Assumption 4.1, when , in (B.17). Plugging (B.19) into (B.17), with probability at least , we have that
(B.20)
Through similar arguments, with probability at least , the second term of (B.16) satisfies
(B.21)
Now, conditioning on the intersection of the event in (B.1.2) and the event in (B.21), the following inequalities hold simultaneously for any
(B.22)
(B.23)
Summing (B.22) and (B.23) and applying Gronwall’s Lemma (Holte, 2009), with probability at least , for any , it holds that
(B.24)
The last inequality holds since we allow the value of to vary from line to line. Therefore, (B.24) implies (B.11). Thus, we complete the proof of Lemma B.3.
B.1.3 Proof of Lemma B.4
By the definitions of in (B), in (B.5), and in (B.6), the distances and satisfy
(B.25)
(B.26)
where (B.25) follows from (C.8) of Lemma C.1 and (C.13) of Lemma C.2, (B.26) follows from (C.9) of Lemma C.1 and (C.13) of Lemma C.2. Combining the inequalities in (B.25) and (B.26), it holds for any that
(B.27)
Applying the discrete Gronwall’s lemma (Holte, 2009) to (B.27), we have that
where the inequalities hold since we allow the value of to vary from line to line. Thus, we complete the proof of Lemma B.4.
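The discrete Gronwall step can be checked on a generic recursion. The sketch below assumes the standard form a_{k+1} <= (1 + eta*c) a_k + eta*b with a_0, b >= 0 and c > 0, which yields a_k <= exp(c*eta*k) (a_0 + b/c); the constants are illustrative and are not the ones appearing in (B.27).

```python
import math


def iterate_recursion(a0, c, b, eta, steps):
    """Iterate a_{k+1} = (1 + eta*c) * a_k + eta*b, i.e. the worst case of the
    inequality a_{k+1} <= (1 + eta*c) * a_k + eta*b."""
    a = a0
    values = [a]
    for _ in range(steps):
        a = (1.0 + eta * c) * a + eta * b
        values.append(a)
    return values


def gronwall_bound(a0, c, b, eta, k):
    """Discrete Gronwall bound a_k <= exp(c*eta*k) * (a0 + b/c), valid for c > 0.
    It follows from a_k = (1+eta*c)**k * a0 + b*((1+eta*c)**k - 1)/c
    together with (1 + eta*c)**k <= exp(c*eta*k)."""
    return math.exp(c * eta * k) * (a0 + b / c)


values = iterate_recursion(a0=0.0, c=2.0, b=0.1, eta=0.01, steps=500)
worst_ratio = max(
    v / gronwall_bound(0.0, 2.0, 0.1, 0.01, k) for k, v in enumerate(values)
)
print("largest ratio of iterate to Gronwall bound:", worst_ratio)
```

With a_0 = 0, as when two dynamics share an initialization, the bound is proportional to b, so a small per-step discrepancy stays small over a fixed time horizon eta*k.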
B.1.4 Proof of Lemma B.5
Proof.
Let be the σ-algebra generated by and . Following from the definition of and in (B) and (B), we have for any and that
Recall that and are the SGDA and PGDA dynamics defined in (B.4) and (B.5). We have for any , that
(B.28)
where the last inequality follows from (C.8) of Lemma C.1. Here, and are defined as,
Following from (C.7) of Lemma C.1, it holds that , thus the stochastic process is a martingale with . Applying the Azuma-Hoeffding bound in Lemma C.4, we have that
(B.29)
Applying the union bound to (B.29) for , we have that
Setting , with probability at least , it holds that
(B.30)
Plugging (B.30) into (B.28) and taking the supremum norm over , we have that
(B.31)
Through similar arguments, for and , with probability at least ,
(B.32)
Conditioning on the intersection of the event in (B.31) and the event in (B.32), summing (B.31) and (B.32), and applying the discrete Gronwall’s lemma (Holte, 2009), for any , the following inequality holds with probability at least ,
Here the last inequality holds since we allow the value of to vary from line to line. Thus, we complete the proof of Lemma B.5. ∎
Appendix C Supporting Lemmas
C.1 Supporting Lemmas for §B
In what follows, we present the technical lemmas used heavily in §B. We recall the definitions of , and in (B), (B), and (B), respectively. Let be a constant depending on , whose value varies from line to line. Recall that and are the finite-width representations with parameters , whose definitions are given by
Lemma C.1.
Under Assumptions 4.1 and 4.3, for any , , , , the functions and are uniformly bounded and Lipschitz in , respectively; that is,
(C.1)
(C.2)
(C.3)
Recall the definitions of and in (B) and (B). The finite-width representations of the velocity field and its stochastic counterpart, when evaluated at arbitrary , are also uniformly bounded and Lipschitz in , respectively. That is, for , the following inequalities hold,
(C.4)
(C.5)
(C.6)
A similar series of inequalities also holds for ,
(C.7)
(C.8)
(C.9)
As a corollary of the inequalities stated above, the uniform bounds in fact hold for any ; that is,
(C.10)
Similarly, the uniform bounds also hold for the velocity field , such that for any , it holds that
(C.11)
Proof.
We will prove these results separately.
For (C.1) of Lemma C.1, since , are bounded as assumed in Assumption 4.1, we have for any , any and that
For (C.2) and (C.3) of Lemma C.1, note that for any , , has a bounded gradient in and has a bounded gradient in . A uniform upper bound on the gradient controls the Lipschitz constant of the function; thus it holds for any , any and that
For (C.4) of Lemma C.1, recall the definition of , in (B), for any and ,
For notational simplicity, we further define
For (C.5) of Lemma C.1, following from Assumption 4.3 and the definition of in (B), we have for any and that
Moreover, is also Lipschitz in since
where the second inequality is achieved by applying (C.2), (C.3). Therefore, the fact that is Lipschitz in is due to and is uniformly bounded.
For (C.6) of Lemma C.1, following from Assumption 4.3 and the definition of in (B), by an argument similar to the proof of (C.5), we have for any and that
Again, is Lipschitz in since
Therefore, the Lipschitzness of in comes from and is uniformly bounded.
Equations (C.7), (C.8), (C.9) of Lemma C.1 for and follow from the fact that
Therefore, (C.7) follows from (C.4) and the triangle inequality,
Equations (C.8) and (C.9) follow from (C.5), (C.6), and the triangle inequality,
Equation (C.10) follows from the definition of in (4.2) and the uniform bounds of neuron functions and . For any , there exist probability measures over the parameter space such that
Applying the triangle inequality, we obtain
Equation (C.11) follows from the definition of in (B) and the proofs of (C.4) and (C.7). The proof of (C.11) is the same as that of (C.4) and (C.7), except that a uniform bound is needed for the infinite-width representations of and , which is established in (C.10).
Combining the arguments above, we complete the proof of Lemma C.1. ∎
Now, recall that is the PDE solution to (3.7), is the IP dynamics defined in (B.7), and is the CTPGDA dynamics defined in (B.6). We have the following lemma, which also bounds the difference of the iterates of the IP and CTPGDA dynamics between time and .
Proof.
Lemma C.3.
Let be i.i.d. random variables with and . Then, for any , there exists an absolute constant such that
Proof.
See Lemma 30 in Mei et al. (2019). ∎
Lemma C.4 (Azuma-Hoeffding bound).
Let be a martingale with respect to the filtration with . We assume for and any that,
Then it holds that, with being an absolute constant.
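As a numerical sanity check of how an Azuma-Hoeffding bound combines with a union bound over time steps, the following sketch uses the classical bounded-increment version, which is an assumption and may differ in its constants from the form stated above: for a martingale with M_0 = 0 and increments bounded by c in absolute value, P(|M_k| >= t) <= 2 exp(-t^2 / (2 k c^2)). A union bound over k <= K then gives max_k |M_k| <= c sqrt(2 K log(2K / delta)) with probability at least 1 - delta. The Rademacher walk and the helper names are illustrative.

```python
import math
import random


def azuma_sup_bound(num_steps, increment_bound, delta):
    """Radius r with P(max_{k <= K} |M_k| >= r) <= delta, obtained from the
    bounded-increment Azuma-Hoeffding inequality and a union bound over k."""
    return increment_bound * math.sqrt(
        2.0 * num_steps * math.log(2.0 * num_steps / delta)
    )


def max_abs_walk(num_steps, rng):
    """Running maximum of |M_k| for a +/-1 random walk (a martingale with
    increments bounded by 1)."""
    position, running_max = 0, 0
    for _ in range(num_steps):
        position += rng.choice((-1, 1))
        running_max = max(running_max, abs(position))
    return running_max


rng = random.Random(0)
K, delta, trials = 500, 0.01, 200
radius = azuma_sup_bound(K, increment_bound=1.0, delta=delta)
violations = sum(max_abs_walk(K, rng) > radius for _ in range(trials))
print("radius:", radius, "violations out of", trials, "trials:", violations)
```

The union bound costs only a logarithmic factor in the number of steps, which is why the supremum over iterations can be controlled at essentially the same rate as a single iterate.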
Appendix D Technical Results
D.1 Universal Function Approximation Theorem
In what follows, we introduce the universal function approximation theorem (Pinkus, 1999). For any given activation function , we consider the following function class,
We denote by the class of continuous functions over . Then, the following theorem holds.
Lemma D.1 (Universal Function Approximation Theorem, Theorem 3.1 in Pinkus (1999)).
If the activation function is not a polynomial, the function class is dense in in the topology of uniform convergence on compact sets.
D.2 Wasserstein Space
We use the definition of absolutely continuous curves in in Ambrosio et al. (2008) and introduce the following lemmas.
Lemma D.2.
For any probability measures , it holds that
Lemma D.3 (First Variation Formula, Theorem 8.4.7 in Ambrosio et al. (2008)).
Given and an absolutely continuous curve , let be the geodesic connecting and . It holds that
where , .
Lemma D.4 (Benamou-Brenier formula, Proposition 2.30 in Ambrosio and Gigli (2013)).
Let . Then, it holds that
Lemma D.5 (Talagrand’s Inequality, Corollary 2.1 in Otto and Villani (2000)).
Let be . It holds for any that
Lemma D.6 (Eulerian Representation of Geodesics, Proposition 5.38 in Villani (2003)).
Let be a geodesic and be the corresponding vector field such that . It holds that
where is the outer product of two vectors.
Lemma D.7 (Dual Representation of the First-Order Wasserstein Distance, Villani (2008)).
The first-order Wasserstein distance has the following dual representation
for any two probability measures .
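The dual representation can be illustrated in one dimension, where the first-order Wasserstein distance between two empirical measures with equally many atoms equals the average distance between sorted samples. The sketch below is a toy check under that assumption; the sample sizes, the shift of 0.5, and the helper names are hypothetical. It verifies that the gap in means of several 1-Lipschitz test functions never exceeds W1, and that the identity function attains W1 when one sample is a shift of the other.

```python
import math
import random


def wasserstein1_1d(xs, ys):
    """W1 between two empirical measures on the real line with equally many
    atoms, computed by matching sorted samples."""
    if len(xs) != len(ys):
        raise ValueError("both samples must have the same number of atoms")
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)


def mean_gap(f, xs, ys):
    """|E_x f - E_y f| for the two empirical measures."""
    return abs(sum(map(f, xs)) / len(xs) - sum(map(f, ys)) / len(ys))


rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(500)]
ys = [x + 0.5 for x in xs]  # a shifted copy, so the optimal coupling is the shift

w1 = wasserstein1_1d(xs, ys)
lipschitz_tests = {
    "identity": lambda t: t,
    "abs": abs,
    "tanh": math.tanh,
    "relu": lambda t: max(t, 0.0),
}
for name, f in lipschitz_tests.items():
    print(name, "gap:", mean_gap(f, xs, ys), "W1:", w1)
```

This is how the lemma is used in the proofs above: a bound on the Wasserstein distance between two measures transfers directly to a bound on the difference of expectations of any Lipschitz function, after rescaling by its Lipschitz constant.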