1 神奇的Python多維向量操作
用慣 python 的同學(xué)都知道,python 對(duì)于多維向量操作的靈活性可謂是 有!如!神!跡!
來(lái)來(lái),眼見(jiàn)為實(shí)!

即便 2個(gè)多維向量輸入的形狀不同,它們依然可以進(jìn)行計(jì)算并得到一個(gè)我們 相對(duì)期望的結(jié)果。
相關(guān)的 python 代碼如下,大家也可以自己再體驗(yàn)下。
import numpy as np
aa = np.full((3,4,8), 1);
bb = np.full((3,4,8), 2);
cc = aa + bb;
print(cc)
bb = np.full((1,1,1), 2);
cc = aa + bb;
print(cc)
bb = np.full((1,4,8), 2);
cc = aa + bb;
print(cc)
bb = np.full((3,1,1), 2);
cc = aa + bb;
print(cc)
然而,這樣的靈活很難實(shí)現(xiàn)么?其實(shí)都是說(shuō)穿不如一文錢的啦~走著!
2 思路整理
代碼要寫(xiě),但不要著急,咱們先把需求搞明白,思路想清楚。
要實(shí)現(xiàn) 異Shape多維向量點(diǎn)加,我們要引入一個(gè) 維度步幅(Stride) 的概念,太難解釋了。先記住它是要點(diǎn)就好了,跟著我繼續(xù)實(shí)踐中去理解。
我們先從 一維的 異Shape操作 看起。
2.1 一維點(diǎn)加

如圖,一個(gè) 長(zhǎng)度為5的數(shù)組(數(shù)組A) 和一個(gè) 長(zhǎng)度為1的數(shù)組(數(shù)組B) 相加,我們期望 數(shù)組A 的每個(gè)值和 數(shù)組B 的唯一一個(gè)值分別相加。
for (int i = 0; i < 5; i++) {
arrayC[i * strideC] = arrayA[i * strideA] + arrayB[i * strideB];
}
如上面的偽代碼,一個(gè)循環(huán)就可以解決
大家著重關(guān)注下圖中的 strideB = 0,然后體會(huì)一下~
2.2 二維點(diǎn)加

如圖,當(dāng)維度擴(kuò)展為二維時(shí),我們需要 為每個(gè)維度 提供一個(gè)數(shù)組來(lái)管理 stride
此時(shí)我們的需求 可以用兩層循環(huán)實(shí)現(xiàn),偽代碼如下:
for (int i = 0; i < 5; i++) {
pic = arrayC + i * strideC[1];
pia = arrayA + i * strideA[1];
pib = arrayB + i * strideB[1];
for (int j = 0; j < 3; j++) {
pjc = pic + j * strideC[0];
pja = pia + j * strideA[0];
pjb = pib + j * strideB[0];
*pjc = *pja + *pjb;
}
}
Tips: 一定要逐行代碼結(jié)合圖示分析一下,這種多維的操作太難不實(shí)踐直接講清楚了!
那么我們?cè)贀Q一個(gè)場(chǎng)景(如下圖),大家可以自己套用上面?zhèn)未a的邏輯自己推導(dǎo)下~

2.3 三維點(diǎn)加

大家用數(shù)學(xué)歸納自己分析一下,我懶懶地畫(huà)出圖示,列出來(lái)stride信息,就不具體分析咯~
偽代碼如下,一個(gè)三層循環(huán):
for (int i = 0; i < 5; i++) {
pic = arrayC + i * strideC[2];
pia = arrayA + i * strideA[2];
pib = arrayB + i * strideB[2];
for (int j = 0; j < 16; j++) {
pjc = pic + j * strideC[1];
pja = pia + j * strideA[1];
pjb = pib + j * strideB[1];
for (int k = 0; k < 8; k++) {
pkc = pjc + k * strideC[0];
pka = pja + k * strideA[0];
pkb = pjb + k * strideB[0];
*pkc = *pka + *pkb;
}
}
}
外層 i循環(huán) 定位面,中層 j循環(huán) 定位行,內(nèi)層 k循環(huán) 定位列。
2.4 多維
再一維一維地講就是浪費(fèi)篇幅咯一樣,感興趣的話用數(shù)學(xué)歸納自己推導(dǎo)下并不復(fù)雜~
3 參考 MNN BinaryOp 的實(shí)現(xiàn)
在 MNN 的 BinaryOp 操作中,即用到了這種算法思路,不過(guò):
- 它固定了 最高支持計(jì)算維度為6(很OK了,我們一般也沒(méi)有超過(guò)6維計(jì)算的需求)
- 它的 外層循環(huán)是低維度的變化,內(nèi)層循環(huán)是高緯度的變化(這和我上面介紹的思路 正好相反,所以如果要分析源碼時(shí)候要注意不要因?yàn)樗惴▽?duì)不上兒讓自己暈掉~(yú))
關(guān)鍵源碼函數(shù)如下:
// MNN binaryOp 關(guān)鍵計(jì)算操作源碼
template <typename Tin, typename Tout, typename Func>
static ErrorCode _binaryOp(Tensor* input0, Tensor* input1, Tensor* output) {
Func f;
const int input0DataCount = input0->elementSize();
const int input1DataCount = input1->elementSize();
const Tin* input0Data = input0->host<Tin>();
const Tin* input1Data = input1->host<Tin>();
Tout* outputData = output->host<Tout>();
if (input0DataCount == 1) { // data count == 1, not only mean scalar input, maybe of shape (1, 1, 1, ...,1)
for (int i = 0; i < input1DataCount; i++) {
outputData[i] = static_cast<Tout>(f(input0Data[0], input1Data[i]));
}
} else if (input1DataCount == 1) {
for (int i = 0; i < input0DataCount; i++) {
outputData[i] = static_cast<Tout>(f(input0Data[i], input1Data[0]));
}
} else { // both input contains more than one element,which means no scalar input
bool sameShape = input0->elementSize() == input1->elementSize();
if (sameShape) { // two inputs have the same shape, apply element-wise operation
for (int i = 0; i < input0DataCount; i++) {
outputData[i] = static_cast<Tout>(f(input0Data[i], input1Data[i]));
}
} else { // not the same shape, use broadcast
#define MAX_DIM 6
MNN_ASSERT(output->dimensions() <= MAX_DIM);
int dims[MAX_DIM];
int stride[MAX_DIM];
int iStride0[MAX_DIM];
int iStride1[MAX_DIM];
for (int i = MAX_DIM - 1; i >= 0; --i) {
dims[i] = 1;
stride[i] = 0;
iStride0[i] = 0;
iStride1[i] = 0;
int input0I = i - (output->dimensions() - input0->dimensions());
int input1I = i - (output->dimensions() - input1->dimensions());
if (i < output->dimensions()) {
dims[i] = output->length(i);
stride[i] = output->stride(i);
}
if (input0I >= 0 && input0->length(input0I) != 1) {
iStride0[i] = input0->stride(input0I);
}
if (input1I >= 0 && input1->length(input1I) != 1) {
iStride1[i] = input1->stride(input1I);
}
}
for (int w = 0; w < dims[5]; ++w) {
auto ow = outputData + w * stride[5];
auto i0w = input0Data + w * iStride0[5];
auto i1w = input1Data + w * iStride1[5];
#define PTR(x, y, i) \
auto o##x = o##y + x * stride[i]; \
auto i0##x = i0##y + x * iStride0[i]; \
auto i1##x = i1##y + x * iStride1[I]
for (int v = 0; v < dims[4]; ++v) {
PTR(v, w, 4);
for (int u = 0; u < dims[3]; ++u) {
PTR(u, v, 3);
for (int z = 0; z < dims[2]; ++z) {
PTR(z, u, 2);
for (int y = 0; y < dims[1]; ++y) {
PTR(y, z, 1);
for (int x = 0; x < dims[0]; ++x) {
PTR(x, y, 0);
*ox = static_cast<Tout>(f(*i0x, *i1x));
}
}
}
}
}
}
#undef MAX_DIM
#undef PTR
}
// broadcast-capable check is done in compute size
}
return NO_ERROR;
}
4 可調(diào)式的Demo
MNN 的源碼確實(shí)不太容易閱讀,畢竟它在實(shí)現(xiàn)算法的同時(shí):
- 考慮的不同操作的兼容(不僅僅支持加法)
- 考慮了不同數(shù)據(jù)類型的兼容
- 基于 MNN 的數(shù)據(jù)結(jié)構(gòu)
但是不用擔(dān)心,像往常一樣,我的技術(shù)文章一般都會(huì)為大家配套一份簡(jiǎn)單的參考代碼,這份代碼相比 MNN 源碼:
- 只支持加法
- 只支持 float 操作類型
- 數(shù)據(jù)結(jié)構(gòu)即 float * 與 std::vector<int> 的組合
- 不限維度的操作(即你可以進(jìn)行 20維的 float 相加操作)
- 左加數(shù)(A) 的 Shape 可以和 輸出(C) 的 Shape 不同(詳見(jiàn) 【5 思考 】)
// cymv_add 點(diǎn)加主函數(shù)
// __add 為支持不限維度操作而實(shí)現(xiàn)的遞歸子函數(shù)
static void __add(int dimTag,
float *pC, std::vector<int> &rev_stepCs,
float *pA, std::vector<int> &rev_stepAs, std::vector<int> &dimAs,
float *pB, std::vector<int> &rev_stepBs, std::vector<int> &dimBs) {
int dimNum = (int)dimAs.size();
int curDimA = dimAs[dimTag];
int curDimB = dimBs[dimTag];
int curDimC = curDimA > curDimB ? curDimA : curDimB;
int curStepA = rev_stepAs[dimNum - 1 - dimTag];
int curStepB = rev_stepBs[dimNum - 1 - dimTag];
int curStepC = rev_stepCs[dimNum - 1 - dimTag];
float *tmppa = pA;
float *tmppb = pB;
float *tmppc = pC;
for (int i = 0; i < curDimC; i++) {
if (dimTag == dimNum - 1) {
*tmppc = *tmppa + *tmppb;
} else {
__add(dimTag + 1,
tmppc, rev_stepCs,
tmppa, rev_stepAs, dimAs,
tmppb, rev_stepBs, dimBs);
}
tmppc += curStepC;
tmppa += curStepA;
tmppb += curStepB;
}
}
void cymv_add(float *dst,
float *src0,
std::vector<int> shape0,
float *src1,
std::vector<int> shape1) {
if (shape0.size() != shape1.size()) {
printf("維度不等,無(wú)法計(jì)算");
return;
}
/* 小維度在前,高維度在后 */
std::vector<int> step0;
std::vector<int> step1;
std::vector<int> stepOut;
int tmpStep0 = 1;
int tmpStep1 = 1;
int tmpStepOut = 1;
for (int i = (int)(shape0.size()) - 1; i >= 0 ; i--) {
if (1 == shape0[i]) {
step0.push_back(0);
} else {
step0.push_back(tmpStep0);
}
if (1 == shape1[i]) {
step1.push_back(0);
} else {
step1.push_back(tmpStep1);
}
stepOut.push_back(tmpStepOut);
tmpStep0 *= shape0[i];
tmpStep1 *= shape1[i];
int maxVal = shape0[i] > shape1[i] ? shape0[i] : shape1[i];
tmpStepOut *= maxVal;
}
__add(0,
dst, stepOut,
src0, step0, shape0,
src1, step1, shape1);
}
當(dāng)然,我說(shuō)好讀也是相對(duì)的,一樣要費(fèi)點(diǎn)心思哦!但好在有源碼能調(diào)試嘛!
5 思考
最后,這樣的場(chǎng)景我們的算法能支持么?

下載 Demo 試試看咯!