原文地址: https://www.zhouwenzhen.top/archives/48/
使用Python生成LaTeX 數(shù)學(xué)公式
在閱讀算法文獻(xiàn)或者數(shù)學(xué)相關(guān)的文章中經(jīng)常會(huì)看到一些簡(jiǎn)單或復(fù)雜的數(shù)學(xué)公式,最近在分享此類文章時(shí),想使用LaTex鍵入數(shù)學(xué)公式以美化閱讀,發(fā)現(xiàn)需要反復(fù)去查詢LaTex相關(guān)的語(yǔ)法,效率較低且容易出錯(cuò)。
最近 GitHub 上出現(xiàn)了一個(gè)開(kāi)源項(xiàng)目 latexify_py,它使用 Python 就能生成 LaTeX 數(shù)學(xué)公式。打開(kāi)Google Colaboratory示例列舉了幾個(gè)案例:

先試試看
在本地安裝相應(yīng)的Python包,Python版本 >= 3.6
pip install latexify-py
參考官方示例進(jìn)行測(cè)試:
import math
import latexify
@latexify.with_latex
def solve(a, b, c):
return (-b + math.sqrt(b ** 2 - 4 * a * c)) / (2 * a)
if __name__ == '__main__':
print(solve)
終端打印結(jié)果為:
\mathrm{solve}(a, b, c)\triangleq \frac{-b + \sqrt{b^{2} - 4ac}}{2a}
將打印結(jié)果輸入到支持LaTeX的編輯器中,以Typora為例。選擇插入公式塊:
于是,把最近閱讀的facebook開(kāi)源的prophet時(shí)間序列預(yù)測(cè)算法提到的飽和增長(zhǎng)模型公式進(jìn)行測(cè)試,原文中為

開(kāi)始在python中鍵入代碼:
@latexify.with_latex
def g(t):
return C(t) / (1 + exp(1-(k + alpha(t) ** T * delta) * (t -(m + alpha(t) ** T * gamma))))
終端打印結(jié)果并輸入Typora為:
\mathrm{g}(t)\triangleq \frac{\mathrm{C}\left(t\right)}{1 + \mathrm{exp}\left(1 - (k + \mathrm{{\alpha}}\left(t\right)^{t}{\delta})(t - m + \mathrm{{\alpha}}\left(t\right)^{T}{\gamma})\right)}
對(duì)比發(fā)現(xiàn)python輸出的公式中有一個(gè)錯(cuò)誤:刪除了一個(gè)括號(hào),而python代碼中是包含的,由
變成了:
為了進(jìn)一步驗(yàn)證上面出現(xiàn)的問(wèn)題,輸入一段很簡(jiǎn)單的代碼:
@latexify.with_latex
def test(a, b):
return - (a + b)
輸出的公式和預(yù)想的一致:
這時(shí),小小的修改一下代碼:
@latexify.with_latex
def test(a, b):
return 1 - (a + b)
預(yù)想的公式應(yīng)該為:
而實(shí)際卻是:
猜想,這可能是一個(gè)bug或者是輸入的方式不對(duì),雖然這個(gè)問(wèn)題很好解決,但是一直很疑惑。。。。。
latexify_py做了什么?
為了一探究竟,嘗試去閱讀其源碼,看看它都做了哪些事情?
首先入口是@latexify.with_latex這個(gè)注解。latexify提供with_latex和get_latex兩個(gè)注解,with_latex只是先做一些初始化,實(shí)際也是調(diào)用get_latex。重點(diǎn)看一下get_latex,其源碼:
def get_latex(fn, math_symbol=True):
try:
source = inspect.getsource(fn)##獲取整個(gè)模塊的源代碼
except Exception:
# Maybe running on console.
source = dill.source.getsource(fn)
return LatexifyVisitor(math_symbol=math_symbol).visit(ast.parse(source)) ##ast.parse把源碼解析為AST節(jié)點(diǎn),AST是抽象語(yǔ)法樹(shù),不依賴于具體的文法,不依賴于語(yǔ)言的細(xì)節(jié),我們將源代碼轉(zhuǎn)化為AST后,可以對(duì)AST做很多的操作
LatexifyVisitor繼承ast的NodeVisitor,ast.NodeVisitor是一個(gè)專門用來(lái)遍歷語(yǔ)法樹(shù)的工具,可以通過(guò)繼承這個(gè)類來(lái)完成對(duì)語(yǔ)法樹(shù)的遍歷以及遍歷過(guò)程中的處理。
LatexifyVisitor首先從根節(jié)點(diǎn)root進(jìn)行遍歷,在遍歷的過(guò)程中,每個(gè)節(jié)點(diǎn)類型都有專用的類型處理函數(shù),以"visit_" + "Node類型"為名稱,如果不存在,則調(diào)用通用的的處理函數(shù)generic_visit。
在latexify的core.py直接引入astunparse,將生成的ast打印出來(lái):
def get_latex(fn, math_symbol=True):
try:
source = inspect.getsource(fn)
print(astunparse.dump(ast.parse(source)))
except Exception:
# Maybe running on console.
source = dill.source.getsource(fn)
return LatexifyVisitor(math_symbol=math_symbol).visit(ast.parse(source))
下面是test對(duì)應(yīng)的ast結(jié)構(gòu):
Module(
body=[FunctionDef(
name='test',
args=arguments(
posonlyargs=[],
args=[
arg(
arg='a',
annotation=None,
type_comment=None),
arg(
arg='b',
annotation=None,
type_comment=None)],
vararg=None,
kwonlyargs=[],
kw_defaults=[],
kwarg=None,
defaults=[]),
body=[Return(value=BinOp(
left=Constant(
value=1,
kind=None),
op=Sub(),
right=BinOp(
left=Name(
id='a',
ctx=Load()),
op=Add(),
right=Name(
id='b',
ctx=Load()))))],
decorator_list=[Attribute(
value=Name(
id='latexify',
ctx=Load()),
attr='with_latex',
ctx=Load())],
returns=None,
type_comment=None)],
type_ignores=[])
首先訪問(wèn)根節(jié)點(diǎn)root,root為Moudle類型,會(huì)調(diào)用visit_Moudle函數(shù),以此始遍歷子節(jié)點(diǎn)FunctionDef、Return和BinOp,調(diào)用對(duì)應(yīng)的visit_FunctionDef、visit_Return和vist_BinOp。
參照打印出來(lái)的python公式代碼和ast結(jié)構(gòu),來(lái)分析一下整體邏輯:
vist_FunctionDef
def visit_FunctionDef(self, node):
name_str = r'\mathrm{' + str(node.name) + '}'
arg_strs = [self._parse_math_symbols(str(arg.arg)) for arg in node.args.args]
body_str = self.visit(node.body[0])
return name_str + '(' + ', '.join(arg_strs) + r')\triangleq ' + body_str
遍歷FunctionDef節(jié)點(diǎn)后,輸出為:
\mathrm{test}(a,b)\triangleq
visit_Return
def visit_Return(self, node):
return self.visit(node.value)
Return節(jié)點(diǎn)的值為子節(jié)點(diǎn),類型為BinOp。ast將輸入的代碼分為left和right,test例子中,left為常數(shù)1,right是下一個(gè)子節(jié)點(diǎn),類型為BinOp,op為運(yùn)算符,這里為Sub減法??纯磛isit_BinOp:
visit_BinOp
def visit_BinOp(self, node):
priority = {
ast.Add: 10,
ast.Sub: 10,
ast.Mult: 20,
ast.MatMult: 20,
ast.Div: 20,
ast.FloorDiv: 20,
ast.Mod: 20,
ast.Pow: 30,
}
def _unwrap(child):
return self.visit(child)
def _wrap(child):
latex = _unwrap(child)
if isinstance(child, ast.BinOp):
cp = priority[type(child.op)] if type(child.op) in priority else 100
pp = priority[type(node.op)] if type(node.op) in priority else 100
if cp < pp:
return '(' + latex + ')'
return latex
l = node.left
r = node.right
reprs = {
ast.Add: (lambda: _wrap(l) + ' + ' + _wrap(r)),
ast.Sub: (lambda: _wrap(l) + ' - ' + _wrap(r)),
ast.Mult: (lambda: _wrap(l) + _wrap(r)),
ast.MatMult: (lambda: _wrap(l) + _wrap(r)),
ast.Div: (lambda: r'\frac{' + _unwrap(l) + '}{' + _unwrap(r) + '}'),
ast.FloorDiv: (lambda: r'\left\lfloor\frac{' + _unwrap(l) + '}{' + _unwrap(r) + r'}\right\rfloor'),
ast.Mod: (lambda: _wrap(l) + r' \bmod ' + _wrap(r)),
ast.Pow: (lambda: _wrap(l) + '^{' + _unwrap(r) + '}'),
}
if type(node.op) in reprs:
return reprs[type(node.op)]()
else:
return r'\mathrm{unknown\_binop}(' + _unwrap(l) + ', ' + _unwrap(r) + ')'
ast.Add和ast.Sub設(shè)置的優(yōu)先級(jí)都為10,_wrap方法通過(guò)優(yōu)先級(jí)來(lái)判斷是否添加括號(hào),即:
cp = priority[type(child.op)] if type(child.op) in priority else 100
pp = priority[type(node.op)] if type(node.op) in priority else 100
if cp < pp:
return '(' + latex + ')'
test例子中child.op為Sub,node.op是right中的op為Add,優(yōu)先級(jí)相同不添加括號(hào),所以輸出:
1 - a + b
遍歷結(jié)束后輸出:
\mathrm{test}(a, b)\triangleq 1 - a + b
這和公式實(shí)際上表達(dá)的意思南轅北轍,解決方法就是將小于改為小于等于,即
if cp <= pp:
return '(' + latex + ')'