24 表達(dá)式模板

  • 表達(dá)式模板是為了支持一種數(shù)值數(shù)組的類引入的技術(shù)。比如希望像內(nèi)置類型一樣對數(shù)組進(jìn)行下列操作,要在支持這種緊湊寫法的同時獲得高效率,就需要通過表達(dá)式模板來完成。表達(dá)式模板與元編程互補(bǔ),元編程主要用于小的、大小固定的數(shù)組,表達(dá)式模板則用于能在運(yùn)行期確定大小、中等大小的數(shù)組
Array<double> x(1000), y(1000);
x = 1.2 * x + x * y;

臨時變量與分割循環(huán)的問題

  • 先看一種簡單的支持?jǐn)?shù)值數(shù)組操作的模板實(shí)現(xiàn)
template <typename T>
class A {
 public:
  explicit A(std::size_t i) : v(new T[i]), n(i) { init(); }
  A(const A<T>& rhs) : v(new T[rhs.size()]), n(rhs.size()) { copy(rhs); }
  ~A() { delete[] v; }
  A<T>& operator=(const A<T>& rhs) {
    if (&rhs != this) copy(rhs);
    return *this;
  }
  std::size_t size() const { return n; }
  T& operator[](std::size_t i) { return v[i]; }
  const T& operator[](std::size_t i) const { return v[i]; }

 protected:
  void init() {
    for (std::size_t i = 0; i < size(); ++i) {
      v[i] = T {}
    };
  }
  void copy(const A<T>& rhs) {
    assert(size() == rhs.size());
    for (std::size_t i = 0; i < size(); ++i) {
      v[i] = rhs.v[i];
    }
  }

 private:
  T* v;
  std::size_t n;
};

template <typename T>
A<T> operator+(const A<T>& a, const A<T>& b) {
  assert(a.size() == b.size());
  A<T> res(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    res[i] = a[i] + b[i];
  }
  return res;
}

template <typename T>
A<T> operator*(const A<T>& a, const A<T>& b) {
  assert(a.size() == b.size());
  A<T> res(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    res[i] = a[i] * b[i];
  }
  return res;
}

template <typename T>
A<T> operator*(const T& s, const A<T>& a) {
  A<T> res(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    res[i] = s * a[i];
  }
  return res;
}
  • 使用上面這些運(yùn)算符即可完成表達(dá)式計算
A<double> x(1000), y(1000);
... x = 1.2 * x + x * y;

// 計算過程相當(dāng)于
tmp1 = 1.2 * x;      // 循環(huán) 1000 次元素操作,并創(chuàng)建和刪除 tmp1
tmp2 = x * y;        // 循環(huán) 1000 次元素操作,并創(chuàng)建和刪除 tmp2
tmp3 = tmp1 + tmp2;  // 循環(huán) 1000 次讀寫操作,并創(chuàng)建和刪除 tmp3
x = tmp3;            // 1000 次讀操作和寫操作
  • 但這個實(shí)現(xiàn)非常低效,有兩方面原因
    • 每個運(yùn)算符操作(除了賦值運(yùn)算符)至少要生成一個臨時數(shù)組,至少要生成 3 個大小為 1000 的臨時數(shù)組
    • 每次使用運(yùn)算符都要求對實(shí)參和結(jié)果數(shù)組進(jìn)行額外遍歷,生成一個 A 對象,需要讀取 6000 次 double 值,寫入 4000 次 double 值
  • 每個數(shù)值數(shù)組庫的實(shí)現(xiàn)都會面臨的一個問題是,對于元素很多的數(shù)組,沒有足夠的內(nèi)存容納這些臨時對象,因此通常使用 computed assignments(如 +=、*=)來代替前面的賦值運(yùn)算符,這樣就不需要創(chuàng)建任何臨時對象
template <typename T>
class A {
 public:
  A<T>& operator+=(const A<T>& b);
  A<T>& operator*=(const A<T>& b);
  A<T>& operator*=(const T& s);
};

template <class T>
A<T>& A<T>::operator+=(const A<T>& b) {
  assert(size() == rhs.size());
  for (std::size_t i = 0; i < size(); ++i) {
    (*this)[i] += b[i];
  }
  return *this;
}

template <class T>
A<T>& A<T>::operator*=(const A<T>& b) {
  assert(size() == rhs.size());
  for (std::size_t i = 0; i < size(); ++i) {
    (*this)[i] *= b[i];
  }
  return *this;
}

template <class T>
A<T>& A<T>::operator*=(const T& s) {
  for (std::size_t i = 0; i < size(); ++i) {
    (*this)[i] *= s;
  }
  return *this;
}
  • 之前的表達(dá)式現(xiàn)在可以改寫為
A<double> x(1000), y(1000);
A<double> tmp{x};
tmp *= y;
x *= 1.2;
x += tmp;
  • 但這樣的缺點(diǎn)顯而易見,符號變得不雅觀,并且仍要創(chuàng)建一個非必要的局部變量 tmp,此外循環(huán)被分割成多個操作,且仍要對 double 進(jìn)行 6000 次讀和 4000 次寫操作。理想操作是針對數(shù)組每個下標(biāo),只對表達(dá)式循環(huán)一次
A<double> x(1000), y(1000);

for (int i = 0; i < x.size(); ++i) {
  x[i] = 1.2 * x[i] + x[i] * y[i];
}
  • 這樣不需要臨時數(shù)組,每次迭代只需要兩次讀(x[i] 和 y[i])和一次寫(x[i]),在循環(huán)中共需 2000 次讀和 1000 次寫,但需要手動循環(huán)。為了兼顧高性能和避免循環(huán),讓代碼更優(yōu)雅且減少錯誤產(chǎn)生,就需要用到下面的技術(shù)

在模板實(shí)參中編碼表達(dá)式

  • 一個很好的解決思路是,直到看到整個表達(dá)式才對表達(dá)式各部分求值,在求值前只記錄每個對象和該對象上的每個操作。這些操作在編譯期已經(jīng)確定,因此可以編碼為模板實(shí)參
1.2 * x + x * y;
  • 上面的表達(dá)式中,1.2 * x不是一個新的數(shù)組,而是一個用于表示 x 的每個值都乘 1.2 的對象,并沒有真正進(jìn)行計算,x * y同理。這樣就能把上面的表達(dá)式轉(zhuǎn)為一個對象,其類型如下
A_Add<A_Mult<A_Scalar<double>, Array<double>>,
      A_Mult<Array<double>, Array<double>>>
  • 表達(dá)式1.2 * x + x * y的前序語法樹的表示如下
     +
   /   \
  *     *
 / \   / \
1.2 x x   y

Operand

  • 為了完整表示整個表達(dá)式,在每個 A_Add 和 A_Mult 對象中必須存儲實(shí)參的引用,另外在 A_Scalar 對象中需要記錄這個表示放大倍數(shù)的值或引用,下面是對這些操作數(shù)的可行定義
template <typename T>
class A_Scalar {
 public:
  constexpr A_Scalar(const T& v) : val(v) {}
  // 對于任意下標(biāo)都只返回scalar值
  constexpr const T& operator[](std::size_t) const { return val; }
  // scalar視為size為0的Array操作數(shù)
  constexpr std::size_t size() const { return 0; };

 private:
  const T& val;  // value of the scalar
};

template <typename T>
class A_Traits {
 public:
  using ExprRef = const T&;
};

template <typename T>
class A_Traits<A_Scalar<T>> {
 public:
  using ExprRef = A_Scalar<T>;
};

template <typename T, typename OP1, typename OP2>
class A_Add {
 public:
  A_Add(const OP1& a, const OP2& b) : op1(a), op2(b) {}

  T operator[](std::size_t i) const { return op1[i] + op2[i]; }

  // size為兩者中較大者
  std::size_t size() const {
    assert(op1.size() == 0 || op2.size() == 0 || op1.size() == op2.size());
    return op1.size() != 0 ? op1.size() : op2.size();
  }

 private:
  typename A_Traits<OP1>::ExprRef op1;
  typename A_Traits<OP2>::ExprRef op2;
};

template <typename T, typename OP1, typename OP2>
class A_Mult {
 public:
  A_Mult(const OP1& a, const OP2& b) : op1(a), op2(b) {}

  T operator[](std::size_t i) const { return op1[i] * op2[i]; }

  std::size_t size() const {
    assert(op1.size() == 0 || op2.size() == 0 || op1.size() == op2.size());
    return op1.size() != 0 ? op1.size() : op2.size();
  }

 private:
  typename A_Traits<OP1>::ExprRef op1;
  typename A_Traits<OP2>::ExprRef op2;
};
  • 使用 A_Traits 來定義操作數(shù)是必要的,A_Scalar 在運(yùn)算符函數(shù)內(nèi)部綁定,不能一直存在到完整表達(dá)式的求值,因此需要傳值拷貝而非傳引用
// 對操作數(shù)一般是 const&
const OP1& op1;  // refer to first operand by reference
const OP2& op2;  // refer to second operand by reference
// 但對 scalar 則是原始值
OP1 op1;  // refer to first operand by value
OP2 op2;  // refer to second operand by value
  • 如果 A_Scalar 對象引用的是頂層定義的 scalar,對這些 scalar 也可以用引用類型

Array

  • 下面創(chuàng)建一個 Array 類型,它能同時適用占用實(shí)際內(nèi)存的數(shù)組和表達(dá)式模板,接口設(shè)計上應(yīng)該與占用存儲空間的真實(shí)數(shù)組相似,同時要與基于數(shù)組的表達(dá)式具有相同表示
template <typename T, typename Rep = A<T>>
class Array {
 private:
  Rep expr_rep;  // Rep是A(占內(nèi)存的數(shù)組)或A_Add、A_Mult等表達(dá)式模板
 public:
  explicit Array(std::size_t i) : expr_rep(i) {}

  // 從可能的表達(dá)式創(chuàng)建
  Array(const Rep& rb) : expr_rep(rb) {}

  // 相同類型的賦值運(yùn)算符
  Array& operator=(const Array& b) {
    assert(size() == b.size());
    for (std::size_t i = 0; i < b.size(); ++i) {
      expr_rep[i] = b[i];
    }
    return *this;
  }

  // 不同類型的賦值運(yùn)算符
  template <typename T2, typename Rep2>
  Array& operator=(const Array<T2, Rep2>& b) {
    assert(size() == b.size());
    for (std::size_t i = 0; i < b.size(); ++i) {
      expr_rep[i] = b[i];
    }
    return *this;
  }

  std::size_t size() const { return expr_rep.size(); }

  T& operator[](std::size_t i) {
    assert(i < size());
    return expr_rep[i];
  }

  decltype(auto) operator[](std::size_t i) const {
    assert(i < size());
    return expr_rep[i];
  }

  Rep& rep() { return expr_rep; }
  const Rep& rep() const { return expr_rep; }
};

Operator

  • 目前只是實(shí)現(xiàn)了用于代表運(yùn)算符的、針對數(shù)值 Array 模板的運(yùn)算符操作(如 A_Add),但沒實(shí)現(xiàn)運(yùn)算符本身(如 operator+)。正如前面所說,這些運(yùn)算符只是用于代表表達(dá)式模板對象,實(shí)際上并不對結(jié)果數(shù)組求值。顯然對每個普通的二元運(yùn)算符,必須實(shí)現(xiàn)三個版本,即 array-array、array-scalar、scalar-array,如為了計算前面的表達(dá)式初始值需要用到下面的運(yùn)算符
// addition of two Arrays
template <typename T, typename R1, typename R2>
Array<T, A_Add<T, R1, R2>> operator+(const Array<T, R1>& a,
                                     const Array<T, R2>& b) {
  return Array<T, A_Add<T, R1, R2>>(A_Add<T, R1, R2>(a.rep(), b.rep()));
}

// multiplication of two Arrays
template <typename T, typename R1, typename R2>
Array<T, A_Mult<T, R1, R2>> operator*(const Array<T, R1>& a,
                                      const Array<T, R2>& b) {
  return Array<T, A_Mult<T, R1, R2>>(A_Mult<T, R1, R2>(a.rep(), b.rep()));
}

// multiplication of scalar and Array
template <typename T, typename R2>
Array<T, A_Mult<T, A_Scalar<T>, R2>> operator*(const T& s,
                                               const Array<T, R2>& b) {
  return Array<T, A_Mult<T, A_Scalar<T>, R2>>(
      A_Mult<T, A_Scalar<T>, R2>(A_Scalar<T>(s), b.rep()));
}
  • 這些運(yùn)算符聲明看起來復(fù)雜,實(shí)際上函數(shù)做的工作不多。如對兩個Array的加法運(yùn)算符,首先生成一個用于A_Add對象,用于表示運(yùn)算符和操作數(shù)
A_Add<T, R1, R2>(a.rep(), b.rep())
  • 接著把這個對象封裝到一個Array中,從而可以借助Array來操作這個運(yùn)算結(jié)果
return Array<T, A_Add<T, R1, R2>>(...);
  • scalar和Array的乘法運(yùn)算符先使用了A_Scalar模板創(chuàng)建A_Mult對象
A_Mult<T, A_Scalar<T>, R2>(A_Scalar<T>(s), b.rep())
  • 接著同樣封裝到Array中
return Array<T, A_Mult<T, A_Scalar<T>, R2>>(...);
  • 其他二元運(yùn)算符實(shí)現(xiàn)類似,也可以用宏來聲明這些運(yùn)算符,從而只需要使用較少的代碼

Review

  • 現(xiàn)在進(jìn)行一個自頂向下的回顧。下面是要分析的代碼
Array<double> x(1000), y(1000);
x = 1.2 * x + x * y;
  • 由于 x 和 y 的定義中省略了 Rep 實(shí)參,所以該參數(shù)使用默認(rèn)值 A<double>,因此 x 和 y 是占用真實(shí)內(nèi)存的數(shù)組,也就是說它們不只是用于記錄操作。解析表達(dá)式時,編譯器首先作用于左邊的乘號,它是一個 scalar-array 運(yùn)算符,于是重載解析規(guī)則選擇 scalar-array 形式的乘法運(yùn)算符
template <typename T, typename R2>
Array<T, A_Mult<T, A_Scalar<T>, R2>> operator*(const T& s,
                                               const Array<T, R2>& b) {
  return Array<T, A_Mult<T, A_Scalar<T>, R2>>(
      A_Mult<T, A_Scalar<T>, R2>(A_Scalar<T>(s), b.rep()));
}
  • 操作數(shù)類型是 doubleArray<double, A<double>>,因此實(shí)際的結(jié)果類型
Array<double, A_Mult<double, A_Scalar<double>, A<double>>>
  • 結(jié)果的值是一個構(gòu)造自 double 值 1.2 的 A_Scalar<double> 對象和一個表示 x 的 A<double> 對象的乘積
  • 接著對第二個乘法求值,它是一個 array-array 操作,使用相應(yīng)的運(yùn)算符
template <typename T, typename R1, typename R2>
Array<T, A_Mult<T, R1, R2>> operator*(const Array<T, R1>& a,
                                      const Array<T, R2>& b) {
  return Array<T, A_Mult<T, R1, R2>>(A_Mult<T, R1, R2>(a.rep(), b.rep()));
}
  • 兩個操作數(shù)類型都是 Array<double, A<double>>,因此結(jié)果類型為
Array<double, A_Mult<double, A<double>, A<double>>>
  • 這次 A_Mult 所封裝的兩個參數(shù)對象都引用了一個 A<double>,一個用于表示 x 對象,另一個用于表示 y 對象
  • 最后對加法運(yùn)算符求值,依然是 array-array 操作 ,操作數(shù)類型是上面兩個結(jié)果的類型,調(diào)用 array-array 的加法運(yùn)算符
template <typename T, typename R1, typename R2>
Array<T, A_Add<T, R1, R2>> operator+(const Array<T, R1>& a,
                                     const Array<T, R2>& b) {
  return Array<T, A_Add<T, R1, R2>>(A_Add<T, R1, R2>(a.rep(), b.rep()));
}
  • R1替換為
A_Mult<double, A_Scalar<double>, A<double>>
  • R2替換為
A_Mult<double, A<double>, A<double>>
  • 最終賦值運(yùn)算符右邊的表達(dá)式類型為
Array<double,
  A_Add<double,
      A_Mult<double, A_Scalar<double>, A<double>>,
      A_Mult<double, A<double>, A<double>>
    >
  >
  • 這個類型與Array模板的賦值運(yùn)算符模板進(jìn)行匹配
template <typename T, typename Rep = A<T>>
class Array {
 public:
  // assignment operator for arrays of different type
  template <typename T2, typename Rep2>
  Array& operator=(const Array<T2, Rep2>& b) {
    assert(size() == b.size());
    for (std::size_t i = 0; i < b.size(); ++i) {
      expr_rep[i] = b[i];
    }
    return *this;
  }
};
  • 賦值運(yùn)算符將使用參數(shù)b的下標(biāo)運(yùn)算符來計算目標(biāo)數(shù)組x的每個元素,參數(shù)的實(shí)際類型為
A_Add<double,
    A_Mult<double, A_Scalar<double>, A<double>>,
    A_Mult<double, A<double>, A<double>>
  >
  • 對下標(biāo)i,得出的x為
(1.2 * x[i]) + (x[i] * y[i])
  • 這正是所期望計算的表達(dá)式

Assignment

  • 對于一個 Rep 實(shí)參基于 A_Mult 或 A_Add 表達(dá)式模板的數(shù)組,是不能為該數(shù)組實(shí)例化寫操作的(即編寫 a + b = c 的式子毫無意義),但可以編寫其他的表達(dá)式模板,從而能對這些表達(dá)式模板的結(jié)果賦值,如以具有整數(shù)值數(shù)組為下標(biāo)的索引操作通常會涉及到子集的選擇,即 x[y] = 2 * x[y] 等價于
for (std::size_t i = 0; i < y.size(); ++i) {
  x[y[i]] = 2 * x[y[i]];
}
  • 為了使上面這種寫法可行,必須令這種基于表達(dá)式模板的數(shù)組的行為能像一個左值,即可寫。這個表達(dá)式模板與 A_Mult 等類似,唯一區(qū)別在于它提供了下標(biāo)運(yùn)算符的 const 版本和 non-const 版本,并返回一個左值引用
template <typename T, typename A1, typename A2>
class A_Subscript {
 public:
  A_Subscript(const A1& a, const A2& b) : a1(a), a2(b) {}

  T& operator[](std::size_t i) { return a1[a2[i]]; }

  decltype(auto) operator[](std::size_t i) const { return a1[a2[i]]; }

  std::size_t size() const { return a2.size(); }

 private:
  const A1& a1;  // reference to first operand
  const A2& a2;  // reference to second operand
};
  • 接著在Array中添加下標(biāo)運(yùn)算符模板即可
template <typename T, typename Rep = A<T>>
class Array {
 public:
  template <class T2, class Rep2>
  Array<T, A_Subscript<T, Rep, Rep2>> operator[](const Array<T2, Rep2>& b) {
    return Array<T, A_Subscript<T, Rep, Rep2>>(
        A_Subscript<T, Rep, Rep2>(*this, b));
  }

  template <class T2, class Rep2>
  decltype(auto) operator[](const Array<T2, Rep2>& b) const {
    return Array<T, A_Subscript<T, Rep, Rep2>>(
        A_Subscript<T, Rep, Rep2>(*this, b));
  }
};

性能與約束

  • 表達(dá)式模板可以提高數(shù)組操作性能,跟蹤其行為可以發(fā)現(xiàn)許多很小的內(nèi)聯(lián)函數(shù)互相調(diào)用,在調(diào)用堆棧還分配了許多小的表達(dá)式模板對象,因此編譯器必須執(zhí)行完整的內(nèi)聯(lián)小對象和去除小對象操作,以產(chǎn)生性能上和手寫循環(huán)媲美的代碼
  • 表達(dá)式模板并沒有解決所有涉及數(shù)組數(shù)值操作的問題,如對 x = A * x 這種矩陣 - vector 乘法,x 是一個大小為 n 的 vector,A 是一個n * n 矩陣。問題在于臨時變量的使用不可避免,因?yàn)樽罱K結(jié)果的每個元素都依賴于最初 x 的每個元素,而表達(dá)式模板會在一次計算上馬上更新x的首個元素,計算下一個元素時用到這個已更新的元素就改變了原來的數(shù)組
  • 但針對 x = A * y,如果 x 和 y 不互為別名,就不需要一個臨時對象,這意味著必須在運(yùn)行期知道操作數(shù)是否為別名關(guān)系,反過來又說明必須生成一個用于表示表達(dá)式樹的運(yùn)行期結(jié)構(gòu),而不是在表達(dá)式模板的類型中編碼這棵樹
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容