- 表達(dá)式模板是為了支持一種數(shù)值數(shù)組的類引入的技術(shù)。比如希望像內(nèi)置類型一樣對數(shù)組進(jìn)行下列操作,要在支持這種緊湊寫法的同時獲得高效率,就需要通過表達(dá)式模板來完成。表達(dá)式模板與元編程互補(bǔ),元編程主要用于小的、大小固定的數(shù)組,表達(dá)式模板則用于能在運(yùn)行期確定大小、中等大小的數(shù)組
Array<double> x(1000), y(1000);
x = 1.2 * x + x * y;
臨時變量與分割循環(huán)的問題
- 先看一種簡單的支持?jǐn)?shù)值數(shù)組操作的模板實(shí)現(xiàn)
template <typename T>
class A {
public:
explicit A(std::size_t i) : v(new T[i]), n(i) { init(); }
A(const A<T>& rhs) : v(new T[rhs.size()]), n(rhs.size()) { copy(rhs); }
~A() { delete[] v; }
A<T>& operator=(const A<T>& rhs) {
if (&rhs != this) copy(rhs);
return *this;
}
std::size_t size() const { return n; }
T& operator[](std::size_t i) { return v[i]; }
const T& operator[](std::size_t i) const { return v[i]; }
protected:
void init() {
for (std::size_t i = 0; i < size(); ++i) {
v[i] = T {}
};
}
void copy(const A<T>& rhs) {
assert(size() == rhs.size());
for (std::size_t i = 0; i < size(); ++i) {
v[i] = rhs.v[i];
}
}
private:
T* v;
std::size_t n;
};
template <typename T>
A<T> operator+(const A<T>& a, const A<T>& b) {
assert(a.size() == b.size());
A<T> res(a.size());
for (std::size_t i = 0; i < a.size(); ++i) {
res[i] = a[i] + b[i];
}
return res;
}
template <typename T>
A<T> operator*(const A<T>& a, const A<T>& b) {
assert(a.size() == b.size());
A<T> res(a.size());
for (std::size_t i = 0; i < a.size(); ++i) {
res[i] = a[i] * b[i];
}
return res;
}
template <typename T>
A<T> operator*(const T& s, const A<T>& a) {
A<T> res(a.size());
for (std::size_t i = 0; i < a.size(); ++i) {
res[i] = s * a[i];
}
return res;
}
- 使用上面這些運(yùn)算符即可完成表達(dá)式計算
A<double> x(1000), y(1000);
... x = 1.2 * x + x * y;
// 計算過程相當(dāng)于
tmp1 = 1.2 * x; // 循環(huán) 1000 次元素操作,并創(chuàng)建和刪除 tmp1
tmp2 = x * y; // 循環(huán) 1000 次元素操作,并創(chuàng)建和刪除 tmp2
tmp3 = tmp1 + tmp2; // 循環(huán) 1000 次讀寫操作,并創(chuàng)建和刪除 tmp3
x = tmp3; // 1000 次讀操作和寫操作
- 但這個實(shí)現(xiàn)非常低效,有兩方面原因
- 每個運(yùn)算符操作(除了賦值運(yùn)算符)至少要生成一個臨時數(shù)組,至少要生成 3 個大小為 1000 的臨時數(shù)組
- 每次使用運(yùn)算符都要求對實(shí)參和結(jié)果數(shù)組進(jìn)行額外遍歷,生成一個 A 對象,需要讀取 6000 次 double 值,寫入 4000 次 double 值
- 每個數(shù)值數(shù)組庫的實(shí)現(xiàn)都會面臨的一個問題是,對于元素很多的數(shù)組,沒有足夠的內(nèi)存容納這些臨時對象,因此通常使用 computed assignments(如
+=、*=)來代替前面的賦值運(yùn)算符,這樣就不需要創(chuàng)建任何臨時對象
template <typename T>
class A {
public:
A<T>& operator+=(const A<T>& b);
A<T>& operator*=(const A<T>& b);
A<T>& operator*=(const T& s);
};
template <class T>
A<T>& A<T>::operator+=(const A<T>& b) {
assert(size() == rhs.size());
for (std::size_t i = 0; i < size(); ++i) {
(*this)[i] += b[i];
}
return *this;
}
template <class T>
A<T>& A<T>::operator*=(const A<T>& b) {
assert(size() == rhs.size());
for (std::size_t i = 0; i < size(); ++i) {
(*this)[i] *= b[i];
}
return *this;
}
template <class T>
A<T>& A<T>::operator*=(const T& s) {
for (std::size_t i = 0; i < size(); ++i) {
(*this)[i] *= s;
}
return *this;
}
A<double> x(1000), y(1000);
A<double> tmp{x};
tmp *= y;
x *= 1.2;
x += tmp;
- 但這樣的缺點(diǎn)顯而易見,符號變得不雅觀,并且仍要創(chuàng)建一個非必要的局部變量 tmp,此外循環(huán)被分割成多個操作,且仍要對 double 進(jìn)行 6000 次讀和 4000 次寫操作。理想操作是針對數(shù)組每個下標(biāo),只對表達(dá)式循環(huán)一次
A<double> x(1000), y(1000);
for (int i = 0; i < x.size(); ++i) {
x[i] = 1.2 * x[i] + x[i] * y[i];
}
- 這樣不需要臨時數(shù)組,每次迭代只需要兩次讀(x[i] 和 y[i])和一次寫(x[i]),在循環(huán)中共需 2000 次讀和 1000 次寫,但需要手動循環(huán)。為了兼顧高性能和避免循環(huán),讓代碼更優(yōu)雅且減少錯誤產(chǎn)生,就需要用到下面的技術(shù)
在模板實(shí)參中編碼表達(dá)式
- 一個很好的解決思路是,直到看到整個表達(dá)式才對表達(dá)式各部分求值,在求值前只記錄每個對象和該對象上的每個操作。這些操作在編譯期已經(jīng)確定,因此可以編碼為模板實(shí)參
1.2 * x + x * y;
- 上面的表達(dá)式中,
1.2 * x不是一個新的數(shù)組,而是一個用于表示 x 的每個值都乘 1.2 的對象,并沒有真正進(jìn)行計算,x * y同理。這樣就能把上面的表達(dá)式轉(zhuǎn)為一個對象,其類型如下
A_Add<A_Mult<A_Scalar<double>, Array<double>>,
A_Mult<Array<double>, Array<double>>>
- 表達(dá)式
1.2 * x + x * y的前序語法樹的表示如下
+
/ \
* *
/ \ / \
1.2 x x y
Operand
- 為了完整表示整個表達(dá)式,在每個 A_Add 和 A_Mult 對象中必須存儲實(shí)參的引用,另外在 A_Scalar 對象中需要記錄這個表示放大倍數(shù)的值或引用,下面是對這些操作數(shù)的可行定義
template <typename T>
class A_Scalar {
public:
constexpr A_Scalar(const T& v) : val(v) {}
// 對于任意下標(biāo)都只返回scalar值
constexpr const T& operator[](std::size_t) const { return val; }
// scalar視為size為0的Array操作數(shù)
constexpr std::size_t size() const { return 0; };
private:
const T& val; // value of the scalar
};
template <typename T>
class A_Traits {
public:
using ExprRef = const T&;
};
template <typename T>
class A_Traits<A_Scalar<T>> {
public:
using ExprRef = A_Scalar<T>;
};
template <typename T, typename OP1, typename OP2>
class A_Add {
public:
A_Add(const OP1& a, const OP2& b) : op1(a), op2(b) {}
T operator[](std::size_t i) const { return op1[i] + op2[i]; }
// size為兩者中較大者
std::size_t size() const {
assert(op1.size() == 0 || op2.size() == 0 || op1.size() == op2.size());
return op1.size() != 0 ? op1.size() : op2.size();
}
private:
typename A_Traits<OP1>::ExprRef op1;
typename A_Traits<OP2>::ExprRef op2;
};
template <typename T, typename OP1, typename OP2>
class A_Mult {
public:
A_Mult(const OP1& a, const OP2& b) : op1(a), op2(b) {}
T operator[](std::size_t i) const { return op1[i] * op2[i]; }
std::size_t size() const {
assert(op1.size() == 0 || op2.size() == 0 || op1.size() == op2.size());
return op1.size() != 0 ? op1.size() : op2.size();
}
private:
typename A_Traits<OP1>::ExprRef op1;
typename A_Traits<OP2>::ExprRef op2;
};
- 使用 A_Traits 來定義操作數(shù)是必要的,A_Scalar 在運(yùn)算符函數(shù)內(nèi)部綁定,不能一直存在到完整表達(dá)式的求值,因此需要傳值拷貝而非傳引用
// 對操作數(shù)一般是 const&
const OP1& op1; // refer to first operand by reference
const OP2& op2; // refer to second operand by reference
// 但對 scalar 則是原始值
OP1 op1; // refer to first operand by value
OP2 op2; // refer to second operand by value
- 如果 A_Scalar 對象引用的是頂層定義的 scalar,對這些 scalar 也可以用引用類型
Array
- 下面創(chuàng)建一個 Array 類型,它能同時適用占用實(shí)際內(nèi)存的數(shù)組和表達(dá)式模板,接口設(shè)計上應(yīng)該與占用存儲空間的真實(shí)數(shù)組相似,同時要與基于數(shù)組的表達(dá)式具有相同表示
template <typename T, typename Rep = A<T>>
class Array {
private:
Rep expr_rep; // Rep是A(占內(nèi)存的數(shù)組)或A_Add、A_Mult等表達(dá)式模板
public:
explicit Array(std::size_t i) : expr_rep(i) {}
// 從可能的表達(dá)式創(chuàng)建
Array(const Rep& rb) : expr_rep(rb) {}
// 相同類型的賦值運(yùn)算符
Array& operator=(const Array& b) {
assert(size() == b.size());
for (std::size_t i = 0; i < b.size(); ++i) {
expr_rep[i] = b[i];
}
return *this;
}
// 不同類型的賦值運(yùn)算符
template <typename T2, typename Rep2>
Array& operator=(const Array<T2, Rep2>& b) {
assert(size() == b.size());
for (std::size_t i = 0; i < b.size(); ++i) {
expr_rep[i] = b[i];
}
return *this;
}
std::size_t size() const { return expr_rep.size(); }
T& operator[](std::size_t i) {
assert(i < size());
return expr_rep[i];
}
decltype(auto) operator[](std::size_t i) const {
assert(i < size());
return expr_rep[i];
}
Rep& rep() { return expr_rep; }
const Rep& rep() const { return expr_rep; }
};
Operator
- 目前只是實(shí)現(xiàn)了用于代表運(yùn)算符的、針對數(shù)值 Array 模板的運(yùn)算符操作(如 A_Add),但沒實(shí)現(xiàn)運(yùn)算符本身(如
operator+)。正如前面所說,這些運(yùn)算符只是用于代表表達(dá)式模板對象,實(shí)際上并不對結(jié)果數(shù)組求值。顯然對每個普通的二元運(yùn)算符,必須實(shí)現(xiàn)三個版本,即 array-array、array-scalar、scalar-array,如為了計算前面的表達(dá)式初始值需要用到下面的運(yùn)算符
// addition of two Arrays
template <typename T, typename R1, typename R2>
Array<T, A_Add<T, R1, R2>> operator+(const Array<T, R1>& a,
const Array<T, R2>& b) {
return Array<T, A_Add<T, R1, R2>>(A_Add<T, R1, R2>(a.rep(), b.rep()));
}
// multiplication of two Arrays
template <typename T, typename R1, typename R2>
Array<T, A_Mult<T, R1, R2>> operator*(const Array<T, R1>& a,
const Array<T, R2>& b) {
return Array<T, A_Mult<T, R1, R2>>(A_Mult<T, R1, R2>(a.rep(), b.rep()));
}
// multiplication of scalar and Array
template <typename T, typename R2>
Array<T, A_Mult<T, A_Scalar<T>, R2>> operator*(const T& s,
const Array<T, R2>& b) {
return Array<T, A_Mult<T, A_Scalar<T>, R2>>(
A_Mult<T, A_Scalar<T>, R2>(A_Scalar<T>(s), b.rep()));
}
- 這些運(yùn)算符聲明看起來復(fù)雜,實(shí)際上函數(shù)做的工作不多。如對兩個Array的加法運(yùn)算符,首先生成一個用于A_Add對象,用于表示運(yùn)算符和操作數(shù)
A_Add<T, R1, R2>(a.rep(), b.rep())
- 接著把這個對象封裝到一個Array中,從而可以借助Array來操作這個運(yùn)算結(jié)果
return Array<T, A_Add<T, R1, R2>>(...);
- scalar和Array的乘法運(yùn)算符先使用了A_Scalar模板創(chuàng)建A_Mult對象
A_Mult<T, A_Scalar<T>, R2>(A_Scalar<T>(s), b.rep())
return Array<T, A_Mult<T, A_Scalar<T>, R2>>(...);
- 其他二元運(yùn)算符實(shí)現(xiàn)類似,也可以用宏來聲明這些運(yùn)算符,從而只需要使用較少的代碼
Review
- 現(xiàn)在進(jìn)行一個自頂向下的回顧。下面是要分析的代碼
Array<double> x(1000), y(1000);
x = 1.2 * x + x * y;
- 由于 x 和 y 的定義中省略了 Rep 實(shí)參,所以該參數(shù)使用默認(rèn)值
A<double>,因此 x 和 y 是占用真實(shí)內(nèi)存的數(shù)組,也就是說它們不只是用于記錄操作。解析表達(dá)式時,編譯器首先作用于左邊的乘號,它是一個 scalar-array 運(yùn)算符,于是重載解析規(guī)則選擇 scalar-array 形式的乘法運(yùn)算符
template <typename T, typename R2>
Array<T, A_Mult<T, A_Scalar<T>, R2>> operator*(const T& s,
const Array<T, R2>& b) {
return Array<T, A_Mult<T, A_Scalar<T>, R2>>(
A_Mult<T, A_Scalar<T>, R2>(A_Scalar<T>(s), b.rep()));
}
- 操作數(shù)類型是
double 和 Array<double, A<double>>,因此實(shí)際的結(jié)果類型
Array<double, A_Mult<double, A_Scalar<double>, A<double>>>
- 結(jié)果的值是一個構(gòu)造自 double 值 1.2 的
A_Scalar<double> 對象和一個表示 x 的 A<double> 對象的乘積
- 接著對第二個乘法求值,它是一個 array-array 操作,使用相應(yīng)的運(yùn)算符
template <typename T, typename R1, typename R2>
Array<T, A_Mult<T, R1, R2>> operator*(const Array<T, R1>& a,
const Array<T, R2>& b) {
return Array<T, A_Mult<T, R1, R2>>(A_Mult<T, R1, R2>(a.rep(), b.rep()));
}
- 兩個操作數(shù)類型都是
Array<double, A<double>>,因此結(jié)果類型為
Array<double, A_Mult<double, A<double>, A<double>>>
- 這次 A_Mult 所封裝的兩個參數(shù)對象都引用了一個
A<double>,一個用于表示 x 對象,另一個用于表示 y 對象
- 最后對加法運(yùn)算符求值,依然是 array-array 操作 ,操作數(shù)類型是上面兩個結(jié)果的類型,調(diào)用 array-array 的加法運(yùn)算符
template <typename T, typename R1, typename R2>
Array<T, A_Add<T, R1, R2>> operator+(const Array<T, R1>& a,
const Array<T, R2>& b) {
return Array<T, A_Add<T, R1, R2>>(A_Add<T, R1, R2>(a.rep(), b.rep()));
}
A_Mult<double, A_Scalar<double>, A<double>>
A_Mult<double, A<double>, A<double>>
- 最終賦值運(yùn)算符右邊的表達(dá)式類型為
Array<double,
A_Add<double,
A_Mult<double, A_Scalar<double>, A<double>>,
A_Mult<double, A<double>, A<double>>
>
>
- 這個類型與Array模板的賦值運(yùn)算符模板進(jìn)行匹配
template <typename T, typename Rep = A<T>>
class Array {
public:
// assignment operator for arrays of different type
template <typename T2, typename Rep2>
Array& operator=(const Array<T2, Rep2>& b) {
assert(size() == b.size());
for (std::size_t i = 0; i < b.size(); ++i) {
expr_rep[i] = b[i];
}
return *this;
}
};
- 賦值運(yùn)算符將使用參數(shù)b的下標(biāo)運(yùn)算符來計算目標(biāo)數(shù)組x的每個元素,參數(shù)的實(shí)際類型為
A_Add<double,
A_Mult<double, A_Scalar<double>, A<double>>,
A_Mult<double, A<double>, A<double>>
>
(1.2 * x[i]) + (x[i] * y[i])
Assignment
- 對于一個 Rep 實(shí)參基于 A_Mult 或 A_Add 表達(dá)式模板的數(shù)組,是不能為該數(shù)組實(shí)例化寫操作的(即編寫
a + b = c 的式子毫無意義),但可以編寫其他的表達(dá)式模板,從而能對這些表達(dá)式模板的結(jié)果賦值,如以具有整數(shù)值數(shù)組為下標(biāo)的索引操作通常會涉及到子集的選擇,即 x[y] = 2 * x[y] 等價于
for (std::size_t i = 0; i < y.size(); ++i) {
x[y[i]] = 2 * x[y[i]];
}
- 為了使上面這種寫法可行,必須令這種基于表達(dá)式模板的數(shù)組的行為能像一個左值,即可寫。這個表達(dá)式模板與 A_Mult 等類似,唯一區(qū)別在于它提供了下標(biāo)運(yùn)算符的 const 版本和 non-const 版本,并返回一個左值引用
template <typename T, typename A1, typename A2>
class A_Subscript {
public:
A_Subscript(const A1& a, const A2& b) : a1(a), a2(b) {}
T& operator[](std::size_t i) { return a1[a2[i]]; }
decltype(auto) operator[](std::size_t i) const { return a1[a2[i]]; }
std::size_t size() const { return a2.size(); }
private:
const A1& a1; // reference to first operand
const A2& a2; // reference to second operand
};
- 接著在Array中添加下標(biāo)運(yùn)算符模板即可
template <typename T, typename Rep = A<T>>
class Array {
public:
template <class T2, class Rep2>
Array<T, A_Subscript<T, Rep, Rep2>> operator[](const Array<T2, Rep2>& b) {
return Array<T, A_Subscript<T, Rep, Rep2>>(
A_Subscript<T, Rep, Rep2>(*this, b));
}
template <class T2, class Rep2>
decltype(auto) operator[](const Array<T2, Rep2>& b) const {
return Array<T, A_Subscript<T, Rep, Rep2>>(
A_Subscript<T, Rep, Rep2>(*this, b));
}
};
性能與約束
- 表達(dá)式模板可以提高數(shù)組操作性能,跟蹤其行為可以發(fā)現(xiàn)許多很小的內(nèi)聯(lián)函數(shù)互相調(diào)用,在調(diào)用堆棧還分配了許多小的表達(dá)式模板對象,因此編譯器必須執(zhí)行完整的內(nèi)聯(lián)小對象和去除小對象操作,以產(chǎn)生性能上和手寫循環(huán)媲美的代碼
- 表達(dá)式模板并沒有解決所有涉及數(shù)組數(shù)值操作的問題,如對
x = A * x 這種矩陣 - vector 乘法,x 是一個大小為 n 的 vector,A 是一個n * n 矩陣。問題在于臨時變量的使用不可避免,因?yàn)樽罱K結(jié)果的每個元素都依賴于最初 x 的每個元素,而表達(dá)式模板會在一次計算上馬上更新x的首個元素,計算下一個元素時用到這個已更新的元素就改變了原來的數(shù)組
- 但針對
x = A * y,如果 x 和 y 不互為別名,就不需要一個臨時對象,這意味著必須在運(yùn)行期知道操作數(shù)是否為別名關(guān)系,反過來又說明必須生成一個用于表示表達(dá)式樹的運(yùn)行期結(jié)構(gòu),而不是在表達(dá)式模板的類型中編碼這棵樹