這里對(duì)使用python求解常微分方程提供兩種思路,一種是自己編程實(shí)現(xiàn)歐拉法,改進(jìn)歐拉法或者四階龍格庫(kù)塔,這樣有助于理解上述三種數(shù)值計(jì)算方法的原理;一種是調(diào)用python已有的庫(kù),不再重復(fù)造輪子。
本文對(duì)上述兩種思路都給出代碼示例,并進(jìn)行比較;同時(shí)針對(duì)單個(gè)微分方程和含有多個(gè)微分方程的微分方程組給出代碼示例
代碼地址:https://github.com/DeqianBai/Python-solves-ordinary-differential-equations/tree/master
1. 常微分方程定義
凡含有參數(shù),未知函數(shù)和未知函數(shù)導(dǎo)數(shù) (或微分) 的方程,稱為微分方程。
- 未知函數(shù)是一元函數(shù)的微分方程稱作常微分方程
- 未知數(shù)是多元函數(shù)的微分方程稱作偏微分方程。
微分方程中出現(xiàn)的未知函數(shù)最高階導(dǎo)數(shù)的階數(shù),稱為微分方程的階數(shù)。
2. 調(diào)用現(xiàn)有的庫(kù)
scipy中提供了用于解常微分方程的函數(shù)odeint(),完整的調(diào)用形式如下:
scipy.integrate.odeint(func, y0, t, args=(), Dfun=None, col_deriv=0, full_output=0, ml=None, mu=None, \
rtol=None, atol=None, tcrit=None, h0=0.0, hmax=0.0,hmin=0.0, ixpr=0, mxstep=0, mxhnil=0, mxordn=12, mxords=5, printmessg=0)
實(shí)際使用中,還是主要使用前三個(gè)參數(shù),即微分方程的描寫函數(shù)、初值和需要求解函數(shù)值對(duì)應(yīng)的的時(shí)間點(diǎn)。接收數(shù)組形式。這個(gè)函數(shù),要求微分方程必須化為標(biāo)準(zhǔn)形式,即,實(shí)際操作中會(huì)發(fā)現(xiàn),高階方程的標(biāo)準(zhǔn)化工作,其實(shí)是解微分方程最主要的工作。
示例1:?jiǎn)蝹€(gè)微分方程
import math
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
def func(y, t):
return t * math.sqrt(y)
YS=odeint(func,y0=1,t=np.arange(0,10.1,0.1))
t=np.arange(0,10.1,0.1)
plt.plot(t, YS, label='odeint')
plt.legend()
plt.show()
結(jié)果如下:

示例2:微分方程組
與單個(gè)微分方程不同的是,此時(shí)的函數(shù)變成了向量函數(shù)
from scipy.integrate import odeint
import numpy as np
def lorenz(w, t, p, r, b):
# 給出位置矢量w,和三個(gè)參數(shù)p, r, b計(jì)算出
# dx/dt, dy/dt, dz/dt的值
x, y, z = w
# 直接與lorenz的計(jì)算公式對(duì)應(yīng)
return np.array([p*(y-x), x*(r-z)-y, x*y-b*z])
t = np.arange(0, 30, 0.01) # 創(chuàng)建時(shí)間點(diǎn)
# 調(diào)用ode對(duì)lorenz進(jìn)行求解, 用兩個(gè)不同的初始值
track1 = odeint(lorenz, (0.0, 1.00, 0.0), t, args=(10.0, 28.0, 3.0))
track2 = odeint(lorenz, (0.0, 1.01, 0.0), t, args=(10.0, 28.0, 3.0))
# 繪圖
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
fig = plt.figure()
ax = Axes3D(fig)
ax.plot(track1[:,0], track1[:,1], track1[:,2])
ax.plot(track2[:,0], track2[:,1], track2[:,2])
plt.show()
結(jié)果如下

3. 自己編程實(shí)現(xiàn)歐拉法/改進(jìn)歐拉法/四階龍格庫(kù)塔
示例 3:?jiǎn)蝹€(gè)函數(shù)使用四階龍格庫(kù)塔
import math
import numpy as np
import matplotlib.pyplot as plt
def runge_kutta(y, x, dx, f):
""" y is the initial value for y
x is the initial value for x
dx is the time step in x
f is derivative of function y(t)
"""
k1 = dx * f(y, x)
k2 = dx * f(y + 0.5 * k1, x + 0.5 * dx)
k3 = dx * f(y + 0.5 * k2, x + 0.5 * dx)
k4 = dx * f(y + k3, x + dx)
return y + (k1 + 2 * k2 + 2 * k3 + k4) / 6.
if __name__=='__main__':
t = 0.
y = 1.
dt = .1
ys, ts = [], []
def func(y, t):
return t * math.sqrt(y)
while t <= 10:
y = runge_kutta(y, t, dt, func)
t += dt
ys.append(y)
ts.append(t)
exact = [(t ** 2 + 4) ** 2 / 16. for t in ts]
plt.plot(ts, ys, label='runge_kutta')
plt.plot(ts, exact, label='exact')
plt.legend()
#plt.show()
結(jié)果如下:

示例4:示例1和示例3放在一起進(jìn)行對(duì)比
import math
import numpy as np
import matplotlib.pyplot as plt
def runge_kutta(y, x, dx, f):
""" y is the initial value for y
x is the initial value for x
dx is the time step in x
f is derivative of function y(t)
"""
k1 = dx * f(y, x)
k2 = dx * f(y + 0.5 * k1, x + 0.5 * dx)
k3 = dx * f(y + 0.5 * k2, x + 0.5 * dx)
k4 = dx * f(y + k3, x + dx)
return y + (k1 + 2 * k2 + 2 * k3 + k4) / 6.
if __name__=='__main__':
t = 0.
y = 1.
dt = .1
ys, ts = [], []
def func(y, t):
return t * math.sqrt(y)
while t <= 10:
y = runge_kutta(y, t, dt, func)
t += dt
ys.append(y)
ts.append(t)
from scipy.integrate import odeint
YS=odeint(func,y0=1, t=np.arange(0,10.1,0.1))
plt.plot(ts, ys, label='runge_kutta')
plt.plot(ts, YS, label='odeint')
plt.legend()
plt.show()
結(jié)果如下

示例5:多個(gè)微分方程(歐拉法)
import numpy as np
"""
移動(dòng)方程:
t時(shí)刻的位置P(x,y,z)
steps:dt的大小
sets:相關(guān)參數(shù)
"""
def move(P, steps, sets):
x, y, z = P
sgima, rho, beta = sets
# 各方向的速度近似
dx = sgima * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return [x + dx * steps, y + dy * steps, z + dz * steps]
# 設(shè)置sets參數(shù)
sets = [10., 28., 3.]
t = np.arange(0, 30, 0.01)
# 位置1:
P0 = [0., 1., 0.]
P = P0
d = []
for v in t:
P = move(P, 0.01, sets)
d.append(P)
dnp = np.array(d)
# 位置2:
P02 = [0., 1.01, 0.]
P = P02
d = []
for v in t:
P = move(P, 0.01, sets)
d.append(P)
dnp2 = np.array(d)
"""
畫圖
"""
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
fig = plt.figure()
ax = Axes3D(fig)
ax.plot(dnp[:, 0], dnp[:, 1], dnp[:, 2])
ax.plot(dnp2[:, 0], dnp2[:, 1], dnp2[:, 2])
plt.show()
結(jié)果如下:

參考: