自动微分的原理,基于接口的泛型实现,以及最小二乘迭代法求解优化问题
柏舟 新冠5年 03-24
最近用自动微分+牛顿迭代法写了一个求解器,关于求导一般有两种方法:
- 手算解析形式,或者使用Mathmatica等软件算解析形式,然后使用代码实现;
- 差分求导,存在误差。
但是自动微分可以自动求解导数而且不损失精度,而且借助编程语言的模板可以不重复编写代码,并且适用于多种类型。
自动微分介绍
自动微分的实现有两种方式,一种是神经网络中常用的动态图模式,通过重载NdArray(Tensor)的算子,在算子内部形成抽象语法树,然后就可以选择前向微分和后向微分自动求导,导出导数的抽象语法树。实现起来不算复杂,并且适用于stack,concat操作,缺点就是执行的时候是计算抽象语法树,就像解释器一样效率会打折,但是对于NdArray可以接受。
另一种是Ceres Solver的实现模式。我们都知道,复数是实数的一个扩域,我们不仅可以定义:\(i^2 = -1\) ,我们还可以定义 \(\varepsilon^2=0\),而定义的这个运算刚好是微分的运算。我们定义a为常值,b为一阶微分的系数值,它有以下性质:
- \(\varepsilon^2=0\);
- 加法定义:\((a+b\varepsilon)+(c+d\varepsilon)=(a+c)+(b+d)\varepsilon\),单位元$0+0\varepsilon$;
- 加法逆元:\(-(a+b\varepsilon)=(-a)+(-b)\varepsilon\);
- 乘法定义:\((a+b\varepsilon)+(c+d\varepsilon)=(ac)+(ad+bc)\varepsilon\),单位元$1+0\varepsilon$;
- 乘法逆元$a\ne 0$:\((a+b\varepsilon)^{-1}=\frac{1}{a}-\frac{b}{a^2}\varepsilon\);
- 满足加法交换律和乘法交换律。
我们对一个函数Taylor展开:
由于$\varepsilon^2=0$,所以对于任何一个能够Taylor展开的函数满足:
比如:
如果你难以理解,可以简单的把记号ε替换为dx,因为本质上就是微分运算。 下面举个求具体函数值和导数值的例子:
则
以上计算的是在变量x等于1时,z的值和一阶导数值,你会发现计算顺序和链式积分是一模一样的。需要特别注意的是变量和常量对于这个数的映射关系是不同的:
需要注意的是这个方法同样可以用来求偏导。因为对于不同的自变量
使用这套计算方式类似于前向的自动微分,但是不用保存抽象语法树,而且使用SIMD加速也很方便。
它为什么有这么好的运算性质?
这主要是因为满足抽象代数关于域的定义,它不仅满足加法和乘法,同时它有逆元和单位元,也就是加法和乘法都存在逆运算。
并不是所有计算都有逆运算,对于误差的计算包含常值、误差上界和下界,即$x\in [x-\varepsilon_1, x+\varepsilon_2]$它的加法和减法比较简单,但是乘法和除法很困难,因为涉及正负和0的处理。最重要的是$x\cdot x \div x \ne x$,也就是说,虽然它的乘法有实际含义,但是它没有逆元,运算不可逆。
使用泛型实现函数,使用接口约束类型
自动微分最大的优势就是在CPP中编写代码非常简单,只要函数参数的类型是泛型,就可以自动替换。比如:
<typename T>
T f(T x){
return x*x + sin(x) + 1;
}
只需要简单地重载自动微分的sin、cos这些初等函数,CPP会自动替换成正确的类型。有的时候发现不得不使用CPP,因为只有CPP的模板是字面替换而且支持重载函数。
我实际编写代码是C#写的,因为C#在.NET 7以后对数值计算的支持很好,它除了支持重载运算符以外还支持重载初等函数。比如ITrigonometricFunctions<TSelf>重载三角函数。唯一的缺点就是在编写中很难和double这些混用。
当然也存在硬写的方法,就是使用Mixins和高阶函数的思路,将函数注入进去,比如:
abstract class Fn<T> where T : INumberBase<T>, ITrigonometricFunctions<T> {
public T f(T x) => x * x + T.Sin(x) + TFromDouble(1);
public abstract T TFromDouble(double v);
}
class DoubleFn : Fn<double> {
public override double TFromDouble(double v) => v;
}
或者对于每一个数乘和加法写一个IMultiplyOperators,IAdditionOperators等等。
class Fn<T>
where T : INumberBase<T>,
ITrigonometricFunctions<T>,
IAdditionOperators<T, double, T> {
public T f(T x) => x * x + T.Sin(x) + 1;
}
就非常鸡肋,当然也可以用dynamic硬搞。C#的数值计算方便程度算得上仅次于CPP的了,而且有类型约束的保证,只要通过检查就基本能运行,就像写Rust一样。CPP也有了concept,约束可写可不写,很灵活。
用最小二乘迭代法求解优化问题
我觉得这个方法最大的用处就是用最小二乘迭代法求解优化问题,因为最小二乘迭代法不需要求二阶导,使用一阶导就可以计算。
比如你要拟合$y=h(x,\theta)$,theta是待定的参数,求在样本D{x,y}下使MSE最小。
对每一处$x_i$,对theta进行展开,相当于用一阶导近似原函数,用最小二乘法拟合
这一步使用自动微分就可以一次性把常数项和梯度项全部算出来,不用自己手动编写微分的计算方式。
记Jacobi矩阵为
用最小二乘法计算新的theta
更新theta
直到$\Delta\theta$收敛。
我甚至看到有人用这个方法求解单层的神经网络,非常方便。
高阶自动微分
Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX