wasm 原理
wasm 指令的解析,其實都是 入棧,出棧的操作, 它是一個基于棧的虛擬機,比如
get_local 0, 它就是獲取函數(shù)的第一個參數(shù),并把它放到棧里.
i32.const 42 就是把一個 42(int32)放入棧中.
i32.add 就是從棧中取出兩個數(shù),相加后再放回棧里。
下面看一個具體的例子
cpp如下extern "C" { int large(int num) { if (num > 10) { num = num + 12; } else { num = num + 100; } return num; } }指定
Optimization Level-o3 優(yōu)化后 編譯后的wast如下(table 0 anyfunc) (memory $0 1) (export "memory" (memory $0)) (export "large" (func $large)) (func $large (; 0 ;) (param $0 i32) (result i32) (i32.add // 8 (select // 6 (i32.const 12) // 1 (i32.const 100) // 2 (i32.gt_s // 5 (get_local $0) // 3 (i32.const 10) // 4 ) ) (get_local $0) // 7 ) ) )指令解析 $0 為函數(shù)輸入?yún)?shù)
-
(i32.const 12)將 12push到stack;stack=>[12]
-
(i32.const 100)將 100push到stack;stack=>[100, 12]
-
(get_local $0)將 $0 (參數(shù)) 從local中讀取,并push到stack;stack=>[$0, 100, 12]
-
(i32.const 10)將 10push到stack;stack=>[10, $0, 100, 12]
-
(i32.gt_s ()())從stack中pop至v1, v2,并比較大小,v1 > v2則push1 到stack, 反之push 0;stack=>[1, 100, 12]
-
(select ()()() )從stack中pop 1 => v3,從stack中pop 100 => v2, 從 stack 中pop 12 => v1,if v3 為 1(true), 將v2 push到stack, 反之 將v1 push到stack;stack=>[100]
-
(get_local $0)將 $0 (參數(shù)) 從local中讀取,并push到stack;stack=>[$0, 100]
-
(i32.add ()())從stack 中pop 兩個數(shù),相加后push 到stack;stack=>[108]
- 返回結(jié)果 108
做一個 webassembly 的虛擬機主要分兩塊, compile 和 Interpreter. 我們先看 compile 模塊.
Compile
- 編譯主要是對 wasm 結(jié)構(gòu)進行解析, 首先看看 module 對象,這是個核心類, wasm 就是解析到這個對象
type Module struct {
Version uint32 // wasm 的版本
Sections []Section // wasm 中所有的section 數(shù)組, 一個 wasm 文件, 就是由version 和多個 section 組成
Types *SectionTypes // wasm 中所有的函數(shù)描述
Import *SectionImports // wasm 中導入的函數(shù)
Function *SectionFunctions // wasm 中聲明的函數(shù),每個函數(shù)對應(yīng)一個index 指向 Types內(nèi)的函數(shù)類型
Table *SectionTables
Memory *SectionMemories
Global *SectionGlobals
Export *SectionExports // wasm 中導出的函數(shù)描述
Start *SectionStartFunction // 需要立刻執(zhí)行的函數(shù)
Elements *SectionElements // 定義在 table 中的元素
Code *SectionCode // 該 module 的所有函數(shù)信息數(shù)據(jù)
Data *SectionData // 數(shù)據(jù)區(qū), 比如一些字符串等數(shù)據(jù), 會放在Data里, 用 offset 標記
Customs []*SectionCustom
// The function index space of the module
FunctionIndexSpace []Function // wasm 中所有的函數(shù)包括 SectionImports 和 SectionFunctions,函數(shù)中的 type 指向 Types 中的類型
GlobalIndexSpace []GlobalEntry
// function indices into the global function space
// the limit of each table is its capacity (cap)
TableIndexSpace [][]uint32
LinearMemoryIndexSpace [][]byte // 線性內(nèi)存, Data 數(shù)據(jù)會存放在這里
imports struct {
Funcs []uint32 // 導入的函數(shù)
Globals int
Tables int
Memories int
}
}
- 讀取 wasm 文件
// 從本地讀取一個 wasm 文件,并返回 module
func ReadModule(r io.Reader, resolvePath ResolveFunc) (*Module, error) {
// 通過解析 二進制 wasm 文件,將數(shù)據(jù)解析道對應(yīng)的 section 中去
m, err := DecodeModule(r)
...
if m.Import != nil && resolvePath != nil {
if m.Code == nil {
m.Code = &SectionCode{}
}
// 解析 導入 的 module
err := m.resolveImports(resolvePath)
}
for _, fn := range []func() error{
m.populateGlobals,
// 將內(nèi)部函數(shù)轉(zhuǎn)化為 Function 對象,并將 導入的函數(shù)也 一并添加到 FunctionIndexSpace 中
m.populateFunctions,
m.populateTables,
// 將 m.Data 放到線性內(nèi)存中
m.populateLinearMemory,
} {
if err := fn(); err != nil {
return nil, err
}
}
return m, nil
}
func DecodeModule(r io.Reader) (*Module, error) {
reader := &readpos.ReadPos{
R: r,
CurPos: 0,
}
m := &Module{}
...
err = newSectionsReader(m).readSections(reader)
return m, nil
}
-
DecodeModule新建sectionReader, 并調(diào)用readSections
func (s *sectionsReader) readSections(r *readpos.ReadPos) error {
for {
// 循環(huán)讀取section,知道讀完
done, err := s.readSection(r)
switch {
case err != nil:
return err
case done:
return nil
}
}
}
// 從reader 中讀取一個有效的 section. The first return value is true if and only if
// the module has been completely read.
func (sr *sectionsReader) readSection(r *readpos.ReadPos) (bool, error) {
m := sr.m
logger.Println("Reading section ID")
// 從 reader 中讀取一個字節(jié)
id, err := r.ReadByte()
...
s := RawSection{ID: SectionID(id)}
logger.Println("Reading payload length")
// 讀取實際 數(shù)據(jù)
payloadDataLen, err := leb128.ReadVarUint32(r)
if err != nil {
return false, err
}
logger.Printf("Section payload length: %d", payloadDataLen)
s.Start = r.CurPos
sectionBytes := new(bytes.Buffer)
sectionBytes.Grow(int(getInitialCap(payloadDataLen)))
sectionReader := io.LimitReader(io.TeeReader(r, sectionBytes), int64(payloadDataLen))
// 判斷section 的類型,并將該類型空的 struct 賦值給 module 對應(yīng)的屬性
var sec Section
switch s.ID {
case SectionIDCustom:
logger.Println("section custom")
cs := &SectionCustom{}
m.Customs = append(m.Customs, cs)
sec = cs
case SectionIDType:
logger.Println("section type")
m.Types = &SectionTypes{}
sec = m.Types
case SectionIDImport:
logger.Println("section import")
m.Import = &SectionImports{}
sec = m.Import
case SectionIDFunction:
logger.Println("section function")
m.Function = &SectionFunctions{}
sec = m.Function
case SectionIDTable:
logger.Println("section table")
m.Table = &SectionTables{}
sec = m.Table
case SectionIDMemory:
logger.Println("section memory")
m.Memory = &SectionMemories{}
sec = m.Memory
case SectionIDGlobal:
logger.Println("section global")
m.Global = &SectionGlobals{}
sec = m.Global
case SectionIDExport:
logger.Println("section export")
m.Export = &SectionExports{}
sec = m.Export
case SectionIDStart:
logger.Println("section start")
m.Start = &SectionStartFunction{}
sec = m.Start
case SectionIDElement:
logger.Println("section element")
m.Elements = &SectionElements{}
sec = m.Elements
case SectionIDCode:
logger.Println("section code")
m.Code = &SectionCode{}
sec = m.Code
case SectionIDData:
logger.Println("section data")
m.Data = &SectionData{}
sec = m.Data
default:
return false, InvalidSectionIDError(s.ID)
}
// 從reader 中讀取數(shù)據(jù),存入 section (對應(yīng)到 module 的某個變量中)
err = sec.ReadPayload(sectionReader)
if err != nil {
logger.Println(err)
return false, err
}
s.End = r.CurPos
s.Bytes = sectionBytes.Bytes()
// 將 raw s 保存到 對應(yīng)的 xxxSection 中
*sec.GetRawSection() = s
...
// 保存 section
m.Sections = append(m.Sections, sec)
return false, nil
}
- 將文件讀取到 module 中后,還需要加載 import 的模塊
// 解析import 的函數(shù)
func (module *Module) resolveImports(resolve ResolveFunc) error {
if module.Import == nil {
return nil
}
modules := make(map[string]*Module)
var funcs uint32
// 遍歷 module.Import 下的 ”入口“
for _, importEntry := range module.Import.Entries {
importedModule, ok := modules[importEntry.ModuleName]
if !ok {
var err error
// 如果不存在,就調(diào)用外部注入的 resolver 函數(shù)解析,并返回 module 對象
importedModule, err = resolve(importEntry.ModuleName)
if err != nil {
return err
}
// 將導入的 module 保存起來
modules[importEntry.ModuleName] = importedModule
}
if importedModule.Export == nil {
return ErrNoExportsInImportedModule
}
// 判斷 導入的module 中是否暴露了 importEntry.FieldName(本module 需要調(diào)用的方法)
exportEntry, ok := importedModule.Export.Entries[importEntry.FieldName]
if !ok {
return ExportNotFoundError{importEntry.ModuleName, importEntry.FieldName}
}
// 判斷 待導入函數(shù)類型, 與被導入模塊的函數(shù)類型 是否一致
if exportEntry.Kind != importEntry.Type.Kind() {
return KindMismatchError{
FieldName: importEntry.FieldName,
ModuleName: importEntry.ModuleName,
Import: importEntry.Type.Kind(),
Export: exportEntry.Kind,
}
}
index := exportEntry.Index
switch exportEntry.Kind {
case ExternalFunction:
// 根據(jù) exportEntry 對應(yīng)的 functionIndex ,獲取對應(yīng)的 Function 類型
fn := importedModule.GetFunction(int(index))
if fn == nil {
return InvalidFunctionIndexError(index)
}
importIndex := importEntry.Type.(FuncImport).Type
// 下面就判斷 待帶入的function 和 別導入的 function 的類型是否一致
// 比較參數(shù)以及返回值長度
if len(fn.Sig.ReturnTypes) != len(module.Types.Entries[importIndex].ReturnTypes) || len(fn.Sig.ParamTypes) != len(module.Types.Entries[importIndex].ParamTypes) {
return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
}
// 比較返回值類型
for i, typ := range fn.Sig.ReturnTypes {
if typ != module.Types.Entries[importIndex].ReturnTypes[i] {
return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
}
}
// 比較參數(shù)類型
for i, typ := range fn.Sig.ParamTypes {
if typ != module.Types.Entries[importIndex].ParamTypes[i] {
return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
}
}
// 將 Function 對象(被導入的函數(shù)),添加到 module 的 FunctionIndexSpace 數(shù)組中
module.FunctionIndexSpace = append(module.FunctionIndexSpace, *fn)
// 保存 Function 的函數(shù)體
module.Code.Bodies = append(module.Code.Bodies, *fn.Body)
// 將 Function 對象保存到 module 的 import.Funcs 數(shù)組中
module.imports.Funcs = append(module.imports.Funcs, funcs)
funcs++
case ExternalGlobal:
// todo ...
glb := importedModule.GetGlobal(int(index))
if glb == nil {
return InvalidGlobalIndexError(index)
}
if glb.Type.Mutable {
return ErrImportMutGlobal
}
module.GlobalIndexSpace = append(module.GlobalIndexSpace, *glb)
module.imports.Globals++
// In both cases below, index should be always 0 (according to the MVP)
// We check it against the length of the index space anyway.
case ExternalTable:
if int(index) >= len(importedModule.TableIndexSpace) {
return InvalidTableIndexError(index)
}
module.TableIndexSpace[0] = importedModule.TableIndexSpace[0]
module.imports.Tables++
case ExternalMemory:
if int(index) >= len(importedModule.LinearMemoryIndexSpace) {
return InvalidLinearMemoryIndexError(index)
}
module.LinearMemoryIndexSpace[0] = importedModule.LinearMemoryIndexSpace[0]
module.imports.Memories++
default:
return InvalidExternalError(exportEntry.Kind)
}
}
return nil
}
- populateFunctions
// 函數(shù)索引空間索引所有導入和內(nèi)部定義的函數(shù)定義
func (m *Module) populateFunctions() error {
...
// 給內(nèi)部定義的 func 構(gòu)造 fn
// Add the functions from the wasm itself to the function list
numImports := len(m.FunctionIndexSpace)
for codeIndex, typeIndex := range m.Function.Types {
if int(typeIndex) >= len(m.Types.Entries) {
return InvalidFunctionIndexError(typeIndex)
}
// Create the main function structure
fn := Function{
Sig: &m.Types.Entries[typeIndex],
Body: &m.Code.Bodies[codeIndex],
Name: names[uint32(codeIndex+numImports)], // Add the name string if we have it
}
m.FunctionIndexSpace = append(m.FunctionIndexSpace, fn)
}
funcs := make([]uint32, 0, len(m.Function.Types)+len(m.imports.Funcs))
funcs = append(funcs, m.imports.Funcs...)
funcs = append(funcs, m.Function.Types...)
m.Function.Types = funcs
return nil
}
- 新建一個 VM
先看 vm 類型
// VM is the execution context for executing WebAssembly bytecode.
type VM struct {
ctx context // 執(zhí)行上下文
type context struct {
stack []uint64 // 棧深度
locals []uint64 // 局部變量
code []byte // 函數(shù)的字節(jié)碼
asm []asmBlock
pc int64 // 當前的字節(jié)碼 index
curFunc int64 // 當前函數(shù)在 funcs 的index
}
module *wasm.Module
globals []uint64
memory []byte
funcs []function // 函數(shù)數(shù)組 compiledFunction or goFunction
funcTable [256]func() // 指令集,對應(yīng)的解析函數(shù)
// RecoverPanic controls whether the `ExecCode` method
// recovers from a panic and returns it as an error
// instead.
// A panic can occur either when executing an invalid VM
// or encountering an invalid instruction, e.g. `unreachable`.
RecoverPanic bool
abort bool // Flag for host functions to terminate execution
nativeBackend *nativeCompiler
}
// 通過 module 對象和 options 構(gòu)造一個 vm
func NewVM(module *wasm.Module, opts ...VMOption) (*VM, error) {
var vm VM
var options config
for _, opt := range opts {
opt(&options)
}
if module.Memory != nil && len(module.Memory.Entries) != 0 {
if len(module.Memory.Entries) > 1 {
return nil, ErrMultipleLinearMemories
}
vm.memory = make([]byte, uint(module.Memory.Entries[0].Limits.Initial)*wasmPageSize)
copy(vm.memory, module.LinearMemoryIndexSpace[0])
}
vm.funcs = make([]function, len(module.FunctionIndexSpace))
vm.globals = make([]uint64, len(module.GlobalIndexSpace))
vm.newFuncTable()
vm.module = module
nNatives := 0
for i, fn := range module.FunctionIndexSpace {
// 如果是 import 的原生 golang 方法,使用 goFunction 處理
if fn.IsHost() {
vm.funcs[i] = goFunction{
typ: fn.Host.Type(),
val: fn.Host,
}
nNatives++
continue
}
// 將function拆卸并封裝成新的結(jié)構(gòu)
disassembly, err := disasm.NewDisassembly(fn, module)
if err != nil {
return nil, err
}
totalLocalVars := 0
totalLocalVars += len(fn.Sig.ParamTypes)
for _, entry := range fn.Body.Locals {
totalLocalVars += int(entry.Count)
}
// 編譯 字節(jié)碼
code, meta := compile.Compile(disassembly.Code)
vm.funcs[i] = compiledFunction{
codeMeta: meta,
code: code,
branchTables: meta.BranchTables,
maxDepth: disassembly.MaxDepth,
totalLocalVars: totalLocalVars,
args: len(fn.Sig.ParamTypes),
returns: len(fn.Sig.ReturnTypes) != 0,
}
}
...
return &vm, nil
}
Interpreter
下面執(zhí)行代碼的過程,即是翻譯代碼的過程
// fnIndex 函數(shù)的index, args 是該函數(shù)的參數(shù)
func (vm *VM) ExecCode(fnIndex int64, args ...uint64) (rtrn interface{}, err error) {
...
if int(fnIndex) > len(vm.funcs) {
return nil, InvalidFunctionIndexError(fnIndex)
}
if len(vm.module.GetFunction(int(fnIndex)).Sig.ParamTypes) != len(args) {
return nil, ErrInvalidArgumentCount
}
compiled, ok := vm.funcs[fnIndex].(compiledFunction)
if !ok {
panic(fmt.Sprintf("exec: function at index %d is not a compiled function", fnIndex))
}
depth := compiled.maxDepth + 1
// 初始化執(zhí)行棧
if cap(vm.ctx.stack) < depth {
vm.ctx.stack = make([]uint64, 0, depth)
} else {
vm.ctx.stack = vm.ctx.stack[:0]
}
vm.ctx.locals = make([]uint64, compiled.totalLocalVars)
vm.ctx.pc = 0
vm.ctx.code = compiled.code
vm.ctx.asm = compiled.asm
vm.ctx.curFunc = fnIndex
// 給函數(shù)的參數(shù)賦值
for i, arg := range args {
vm.ctx.locals[i] = arg
}
res := vm.execCode(compiled)
if compiled.returns {
rtrnType := vm.module.GetFunction(int(fnIndex)).Sig.ReturnTypes[0]
switch rtrnType {
case wasm.ValueTypeI32:
rtrn = uint32(res)
case wasm.ValueTypeI64:
rtrn = uint64(res)
case wasm.ValueTypeF32:
rtrn = math.Float32frombits(uint32(res))
case wasm.ValueTypeF64:
rtrn = math.Float64frombits(res)
default:
return nil, InvalidReturnTypeError(rtrnType)
}
}
return rtrn, nil
}
func (vm *VM) execCode(compiled compiledFunction) uint64 {
outer:
for int(vm.ctx.pc) < len(vm.ctx.code) && !vm.abort {
op := vm.ctx.code[vm.ctx.pc]
vm.ctx.pc++
switch op {
// 解析到 return 指令的時候,退出循環(huán)
case ops.Return:
break outer
// 省略一些不常用的case
...
default:
// 大部分會走這個case
vm.funcTable[op]()
}
}
if compiled.returns {
//如果有返回值,從棧中取出返回
return vm.ctx.stack[len(vm.ctx.stack)-1]
}
return 0
}
funcTable [256]func() 的初始化 ,一個指令(Op)對應(yīng)一個解析方法, 看看 Op 的結(jié)構(gòu)
// Op describes a WASM operator.
type Op struct {
Code byte // The single-byte opcode
Name string // 該操作的名稱
// Whether this operator is polymorphic.
// A polymorphic operator has a variable arity. call, call_indirect, and
// drop are examples of polymorphic operators.
Polymorphic bool // 是否是動態(tài)的, true:比如一些邏輯控制語句, 還有 get/setlocal 等
Args []wasm.ValueType // 該指令需要的參數(shù)類型(數(shù)量)(會從棧中pop出來)
Returns wasm.ValueType // 返回的參數(shù)類型
}
func (vm *VM) newFuncTable() {
vm.funcTable[ops.I32Clz] = vm.i32Clz
vm.funcTable[ops.I32Ctz] = vm.i32Ctz
vm.funcTable[ops.I32Popcnt] = vm.i32Popcnt
vm.funcTable[ops.I32Add] = vm.i32Add
vm.funcTable[ops.I32Sub] = vm.i32Sub
vm.funcTable[ops.I32Mul] = vm.i32Mul
....
....
vm.funcTable[ops.Drop] = vm.drop
vm.funcTable[ops.Select] = vm.selectOp
vm.funcTable[ops.GetLocal] = vm.getLocal
vm.funcTable[ops.SetLocal] = vm.setLocal
vm.funcTable[ops.TeeLocal] = vm.teeLocal
vm.funcTable[ops.GetGlobal] = vm.getGlobal
vm.funcTable[ops.SetGlobal] = vm.setGlobal
vm.funcTable[ops.Unreachable] = vm.unreachable
vm.funcTable[ops.Nop] = vm.nop
vm.funcTable[ops.Call] = vm.call
vm.funcTable[ops.CallIndirect] = vm.callIndirect
}
例如
// 從棧中pop 兩個uint32 出來,相加后在push 到棧中
func (vm *VM) i32Add() {
vm.pushUint32(vm.popUint32() + vm.popUint32())
}
這里再講一下 vm.call
func (vm *VM) call() {
index := vm.fetchUint32()
// 這里會從 funcs 數(shù)組里取出 Function(or goFunction) 對象,調(diào)用call
vm.funcs[index].call(vm, int64(index))
}
// goFunction 利用反射機制,執(zhí)行函數(shù)
func (fn goFunction) call(vm *VM, index int64) {
// numIn = # of call inputs + vm, as the function expects
// an additional *VM argument
numIn := fn.typ.NumIn()
args := make([]reflect.Value, numIn)
proc := NewProcess(vm)
// 第一個參數(shù)必須是 *Process 類型
if reflect.ValueOf(proc).Kind() != fn.typ.In(0).Kind() {
panic(fmt.Sprintf("exec: the first argument of a host function was %s, expected %s", fn.typ.In(0).Kind(), reflect.ValueOf(vm).Kind()))
}
args[0] = reflect.ValueOf(proc)
// 給函數(shù)的參數(shù)賦值
for i := numIn - 1; i >= 1; i-- {
val := reflect.New(fn.typ.In(i)).Elem()
raw := vm.popUint64()
kind := fn.typ.In(i).Kind()
switch kind {
case reflect.Float64, reflect.Float32:
val.SetFloat(math.Float64frombits(raw))
case reflect.Uint32, reflect.Uint64:
val.SetUint(raw)
case reflect.Int32, reflect.Int64:
val.SetInt(int64(raw))
default:
panic(fmt.Sprintf("exec: args %d invalid kind=%v", i, kind))
}
args[i] = val
}
// 執(zhí)行函數(shù)
rtrns := fn.val.Call(args)
// 將返回值 push 到棧中
for i, out := range rtrns {
kind := out.Kind()
switch kind {
case reflect.Float64, reflect.Float32:
vm.pushFloat64(out.Float())
case reflect.Uint32, reflect.Uint64:
vm.pushUint64(out.Uint())
case reflect.Int32, reflect.Int64:
vm.pushInt64(out.Int())
default:
panic(fmt.Sprintf("exec: return value %d invalid kind=%v", i, kind))
}
}
}
// 非原生函數(shù)實現(xiàn)
func (compiled compiledFunction) call(vm *VM, index int64) {
newStack := make([]uint64, 0, compiled.maxDepth+1)
locals := make([]uint64, compiled.totalLocalVars)
// 給參數(shù)賦值
for i := compiled.args - 1; i >= 0; i-- {
locals[i] = vm.popUint64()
}
//保存執(zhí)行上下文
prevCtxt := vm.ctx
// 新建執(zhí)行上下文
vm.ctx = context{
stack: newStack,
locals: locals,
code: compiled.code,
asm: compiled.asm,
pc: 0,
curFunc: index,
}
rtrn := vm.execCode(compiled)
//被調(diào)用函數(shù)執(zhí)行完了,恢復(fù)上下文
vm.ctx = prevCtxt
if compiled.returns {
// 把返回值push到棧中
vm.pushUint64(rtrn)
}
}