python的多維向量操作沒(méi)那么神 —— 異Shape多維向量點(diǎn)加的實(shí)現(xiàn)分析

1 神奇的Python多維向量操作

用慣 python 的同學(xué)都知道,python 對(duì)于多維向量操作的靈活性可謂是 有!如!神!跡!
來(lái)來(lái),眼見(jiàn)為實(shí)!

python多維向量的操作示意

即便 2個(gè)多維向量輸入的形狀不同,它們依然可以進(jìn)行計(jì)算并得到一個(gè)我們 相對(duì)期望的結(jié)果。
相關(guān)的 python 代碼如下,大家也可以自己再體驗(yàn)下。

import numpy as np

aa = np.full((3,4,8), 1);

bb = np.full((3,4,8), 2);
cc = aa + bb;
print(cc)

bb = np.full((1,1,1), 2);
cc = aa + bb;
print(cc)

bb = np.full((1,4,8), 2);
cc = aa + bb;
print(cc)

bb = np.full((3,1,1), 2);
cc = aa + bb;
print(cc)

然而,這樣的靈活很難實(shí)現(xiàn)么?其實(shí)都是說(shuō)穿不如一文錢的啦~走著!

2 思路整理

代碼要寫(xiě),但不要著急,咱們先把需求搞明白,思路想清楚。

要實(shí)現(xiàn) 異Shape多維向量點(diǎn)加,我們要引入一個(gè) 維度步幅(Stride) 的概念,太難解釋了。先記住它是要點(diǎn)就好了,跟著我繼續(xù)實(shí)踐中去理解。

我們先從 一維的 異Shape操作 看起。

2.1 一維點(diǎn)加

一維點(diǎn)加

如圖,一個(gè) 長(zhǎng)度為5的數(shù)組(數(shù)組A) 和一個(gè) 長(zhǎng)度為1的數(shù)組(數(shù)組B) 相加,我們期望 數(shù)組A 的每個(gè)值和 數(shù)組B 的唯一一個(gè)值分別相加。

for (int i = 0; i < 5; i++) {
        arrayC[i * strideC] = arrayA[i * strideA] + arrayB[i * strideB];
}

如上面的偽代碼,一個(gè)循環(huán)就可以解決
大家著重關(guān)注下圖中的 strideB = 0,然后體會(huì)一下~

2.2 二維點(diǎn)加

二維點(diǎn)加

如圖,當(dāng)維度擴(kuò)展為二維時(shí),我們需要 為每個(gè)維度 提供一個(gè)數(shù)組來(lái)管理 stride
此時(shí)我們的需求 可以用兩層循環(huán)實(shí)現(xiàn),偽代碼如下:

    for (int i = 0; i < 5; i++) {
        pic = arrayC + i * strideC[1];
        pia = arrayA + i * strideA[1];
        pib = arrayB + i * strideB[1];
        
        for (int j = 0; j < 3; j++) {
            pjc = pic + j * strideC[0];
            pja = pia + j * strideA[0];
            pjb = pib + j * strideB[0];
            
            *pjc = *pja + *pjb;
        }
    }

Tips: 一定要逐行代碼結(jié)合圖示分析一下,這種多維的操作太難不實(shí)踐直接講清楚了!

那么我們?cè)贀Q一個(gè)場(chǎng)景(如下圖),大家可以自己套用上面?zhèn)未a的邏輯自己推導(dǎo)下~

二維練習(xí)圖

2.3 三維點(diǎn)加

三維點(diǎn)加

大家用數(shù)學(xué)歸納自己分析一下,我懶懶地畫(huà)出圖示,列出來(lái)stride信息,就不具體分析咯~
偽代碼如下,一個(gè)三層循環(huán):

for (int i = 0; i < 5; i++) {
        pic = arrayC + i * strideC[2];
        pia = arrayA + i * strideA[2];
        pib = arrayB + i * strideB[2];
        
        for (int j = 0; j < 16; j++) {
            pjc = pic + j * strideC[1];
            pja = pia + j * strideA[1];
            pjb = pib + j * strideB[1];
            
            for (int k = 0; k < 8; k++) {
                pkc = pjc + k * strideC[0];
                pka = pja + k * strideA[0];
                pkb = pjb + k * strideB[0];
                
                *pkc = *pka + *pkb;
            }
        }
    }

外層 i循環(huán) 定位面,中層 j循環(huán) 定位行,內(nèi)層 k循環(huán) 定位列。

2.4 多維

再一維一維地講就是浪費(fèi)篇幅咯一樣,感興趣的話用數(shù)學(xué)歸納自己推導(dǎo)下并不復(fù)雜~

3 參考 MNN BinaryOp 的實(shí)現(xiàn)

MNNBinaryOp 操作中,即用到了這種算法思路,不過(guò):

  1. 它固定了 最高支持計(jì)算維度為6(很OK了,我們一般也沒(méi)有超過(guò)6維計(jì)算的需求)
  2. 它的 外層循環(huán)是低維度的變化,內(nèi)層循環(huán)是高緯度的變化(這和我上面介紹的思路 正好相反,所以如果要分析源碼時(shí)候要注意不要因?yàn)樗惴▽?duì)不上兒讓自己暈掉~(yú))

關(guān)鍵源碼函數(shù)如下:

// MNN binaryOp 關(guān)鍵計(jì)算操作源碼

template <typename Tin, typename Tout, typename Func>
static ErrorCode _binaryOp(Tensor* input0, Tensor* input1, Tensor* output) {
    Func f;

    const int input0DataCount = input0->elementSize();
    const int input1DataCount = input1->elementSize();

    const Tin* input0Data = input0->host<Tin>();
    const Tin* input1Data = input1->host<Tin>();
    Tout* outputData      = output->host<Tout>();

    if (input0DataCount == 1) { // data count == 1, not only mean scalar input, maybe of shape (1, 1, 1, ...,1)
        for (int i = 0; i < input1DataCount; i++) {
            outputData[i] = static_cast<Tout>(f(input0Data[0], input1Data[i]));
        }
    } else if (input1DataCount == 1) {
        for (int i = 0; i < input0DataCount; i++) {
            outputData[i] = static_cast<Tout>(f(input0Data[i], input1Data[0]));
        }
    } else { // both input contains more than one element,which means no scalar input
        bool sameShape = input0->elementSize() == input1->elementSize();
        if (sameShape) { // two inputs have the same shape, apply element-wise operation
            for (int i = 0; i < input0DataCount; i++) {
                outputData[i] = static_cast<Tout>(f(input0Data[i], input1Data[i]));
            }
        } else { // not the same shape, use broadcast
#define MAX_DIM 6
            MNN_ASSERT(output->dimensions() <= MAX_DIM);
            int dims[MAX_DIM];
            int stride[MAX_DIM];
            int iStride0[MAX_DIM];
            int iStride1[MAX_DIM];
            for (int i = MAX_DIM - 1; i >= 0; --i) {
                dims[i]     = 1;
                stride[i]   = 0;
                iStride0[i] = 0;
                iStride1[i] = 0;
                int input0I = i - (output->dimensions() - input0->dimensions());
                int input1I = i - (output->dimensions() - input1->dimensions());
                if (i < output->dimensions()) {
                    dims[i]   = output->length(i);
                    stride[i] = output->stride(i);
                }
                if (input0I >= 0 && input0->length(input0I) != 1) {
                    iStride0[i] = input0->stride(input0I);
                }
                if (input1I >= 0 && input1->length(input1I) != 1) {
                    iStride1[i] = input1->stride(input1I);
                }
            }
            for (int w = 0; w < dims[5]; ++w) {
                auto ow  = outputData + w * stride[5];
                auto i0w = input0Data + w * iStride0[5];
                auto i1w = input1Data + w * iStride1[5];
#define PTR(x, y, i)                      \
    auto o##x  = o##y + x * stride[i];    \
    auto i0##x = i0##y + x * iStride0[i]; \
    auto i1##x = i1##y + x * iStride1[I]

                for (int v = 0; v < dims[4]; ++v) {
                    PTR(v, w, 4);
                    for (int u = 0; u < dims[3]; ++u) {
                        PTR(u, v, 3);
                        for (int z = 0; z < dims[2]; ++z) {
                            PTR(z, u, 2);
                            for (int y = 0; y < dims[1]; ++y) {
                                PTR(y, z, 1);
                                for (int x = 0; x < dims[0]; ++x) {
                                    PTR(x, y, 0);
                                    *ox = static_cast<Tout>(f(*i0x, *i1x));
                                }
                            }
                        }
                    }
                }
            }
#undef MAX_DIM
#undef PTR
        }
        // broadcast-capable check is done in compute size
    }

    return NO_ERROR;
}

4 可調(diào)式的Demo

MNN 的源碼確實(shí)不太容易閱讀,畢竟它在實(shí)現(xiàn)算法的同時(shí):

  1. 考慮的不同操作的兼容(不僅僅支持加法)
  2. 考慮了不同數(shù)據(jù)類型的兼容
  3. 基于 MNN 的數(shù)據(jù)結(jié)構(gòu)

但是不用擔(dān)心,像往常一樣,我的技術(shù)文章一般都會(huì)為大家配套一份簡(jiǎn)單的參考代碼,這份代碼相比 MNN 源碼:

  1. 只支持加法
  2. 只支持 float 操作類型
  3. 數(shù)據(jù)結(jié)構(gòu)即 float *std::vector<int> 的組合
  4. 不限維度的操作(即你可以進(jìn)行 20維的 float 相加操作)
  5. 左加數(shù)(A)Shape 可以和 輸出(C)Shape 不同(詳見(jiàn) 【5 思考 】
// cymv_add 點(diǎn)加主函數(shù)
// __add 為支持不限維度操作而實(shí)現(xiàn)的遞歸子函數(shù)

static void __add(int dimTag,
                  float *pC, std::vector<int> &rev_stepCs,
                  float *pA, std::vector<int> &rev_stepAs, std::vector<int> &dimAs,
                  float *pB, std::vector<int> &rev_stepBs, std::vector<int> &dimBs) {
    
    int dimNum = (int)dimAs.size();
    
    int curDimA = dimAs[dimTag];
    int curDimB = dimBs[dimTag];
    int curDimC = curDimA > curDimB ? curDimA : curDimB;
    
    int curStepA = rev_stepAs[dimNum - 1 - dimTag];
    int curStepB = rev_stepBs[dimNum - 1 - dimTag];
    int curStepC = rev_stepCs[dimNum - 1 - dimTag];
    
    float *tmppa = pA;
    float *tmppb = pB;
    float *tmppc = pC;
    for (int i = 0; i < curDimC; i++) {
        
        if (dimTag == dimNum - 1) {
            *tmppc = *tmppa + *tmppb;
        } else {
            __add(dimTag + 1,
                  tmppc, rev_stepCs,
                  tmppa, rev_stepAs, dimAs,
                  tmppb, rev_stepBs, dimBs);
        }
        
        tmppc += curStepC;
        tmppa += curStepA;
        tmppb += curStepB;
    }
}

void cymv_add(float *dst,
              float *src0,
              std::vector<int> shape0,
              float *src1,
              std::vector<int> shape1) {
    
    if (shape0.size() != shape1.size()) {
        printf("維度不等,無(wú)法計(jì)算");
        return;
    }
    
    /* 小維度在前,高維度在后 */
    std::vector<int> step0;
    std::vector<int> step1;
    std::vector<int> stepOut;
    
    int tmpStep0 = 1;
    int tmpStep1 = 1;
    int tmpStepOut = 1;
    for (int i = (int)(shape0.size()) - 1; i >= 0 ; i--) {
        
        if (1 == shape0[i]) {
            step0.push_back(0);
        } else {
            step0.push_back(tmpStep0);
        }
        if (1 == shape1[i]) {
            step1.push_back(0);
        } else {
            step1.push_back(tmpStep1);
        }
        stepOut.push_back(tmpStepOut);
        
        tmpStep0 *= shape0[i];
        tmpStep1 *= shape1[i];
        int maxVal = shape0[i] > shape1[i] ? shape0[i] : shape1[i];
        tmpStepOut *= maxVal;
    }
    
    __add(0,
          dst, stepOut,
          src0, step0, shape0,
          src1, step1, shape1);
}

GitHub可調(diào)式源碼鏈接

當(dāng)然,我說(shuō)好讀也是相對(duì)的,一樣要費(fèi)點(diǎn)心思哦!但好在有源碼能調(diào)試嘛!

5 思考

最后,這樣的場(chǎng)景我們的算法能支持么?

思考

下載 Demo 試試看咯!

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容