Introduction
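The concept of a Blob in Caffe2 presumably comes from Caffe. It is a typed memory abstraction that essentially pairs a pointer to the stored object with that object's type (a TypeMeta), plus a deleter used to free the object. In that sense it closely resembles a Tensor; it is somewhat redundant and looks more like baggage inherited from Caffe. In the frameworks I know, such as TensorFlow, PyTorch, and MXNet, Tensor is the only typed memory abstraction offered. In Caffe2, perhaps so that Blob would not be entirely superfluous, the designers gave it the serialization role (which is how weights get stored as strings).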
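Anyway, this means that Caffe2 operators take Blobs as their inputs and outputs (just as in Caffe). Inside an operator, however, one normally retrieves the stored object through an accessor such as Get<T>() or GetMutable<T>(), which static_casts the Blob's untyped void* pointer to the appropriate type T (usually Tensor), and then performs the actual computation on that object.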
The Blob-related source files in Caffe2 are listed below. This section focuses on Blob, the BlobSerializer classes, BlobStatGetter, and the related helper functions.
core]$ ls blob*
blob_gpu_test.cc blob_serialization_gpu.cc blob_stats.cc
blob.h blob_serialization.h blob_stats.h
blob_serialization.cc blob_serializer_base.h blob_test.cc
Blob
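Below is the basic declaration of Blob. Its data members are meta_, the TypeMeta that records the type of the stored object; pointer_, an untyped void* pointing to that object; and destroy_, the function used to free it (nullptr when the Blob does not own the object).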
/**
 * @brief Blob is a general container that hosts a typed pointer.
 *
 * A Blob hosts a pointer as well as its type, and takes charge of deleting it
 * properly when the blob is deallocated or re-allocated with a new type. A blob
 * could contain anything, although the most common case is to contain a Tensor.
 */
class CAFFE2_API Blob final {
 public:
  using DestroyCall = void(void*);
  /**
   * Initializes an empty Blob.
   */
  Blob() : meta_(), pointer_(nullptr) {}
  ~Blob() { Reset(); }
  Blob(Blob&& other) noexcept
      : meta_(std::move(other.meta_)),
        pointer_(std::move(other.pointer_)),
        destroy_(std::move(other.destroy_)) {
    other.meta_ = {};
    other.pointer_ = nullptr;
    other.destroy_ = nullptr;
  }
  ........
  ........
};
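The following is a minimal, self-contained sketch of a Blob-like holder, not Caffe2 code. `MiniBlob` is a hypothetical name, and `std::type_index` stands in for TypeMeta (an empty holder is tagged with `typeid(void)`). It shows that a void*, a type tag, and a deleter function pointer are enough to own an arbitrary object: a private static template `Destroy<T>` provides a deleter with the plain `void(void*)` signature that still deletes through the correct type, and the move constructor leaves the source empty, as the Blob move constructor above does.

```cpp
#include <string>
#include <typeindex>
#include <typeinfo>
#include <utility>

// Hypothetical, simplified stand-in for caffe2::Blob.
// std::type_index replaces TypeMeta; an empty blob is tagged typeid(void).
class MiniBlob {
 public:
  using DestroyCall = void(void*);

  MiniBlob() = default;
  ~MiniBlob() { Reset(); }

  // Move: steal all three members, then leave `other` empty so that its
  // destructor does not free the object a second time.
  MiniBlob(MiniBlob&& other) noexcept
      : meta_(other.meta_), pointer_(other.pointer_), destroy_(other.destroy_) {
    other.meta_ = std::type_index(typeid(void));
    other.pointer_ = nullptr;
    other.destroy_ = nullptr;
  }
  MiniBlob(const MiniBlob&) = delete;
  MiniBlob& operator=(const MiniBlob&) = delete;

  // Take ownership of `allocated`, freeing any previously owned object.
  template <class T>
  T* Reset(T* allocated) {
    Reset();
    meta_ = std::type_index(typeid(T));
    pointer_ = allocated;
    destroy_ = &Destroy<T>;  // remembers how to delete the object as a T
    return allocated;
  }

  // Free the owned object (if any) and become empty.
  void Reset() {
    if (pointer_ && destroy_) destroy_(pointer_);
    meta_ = std::type_index(typeid(void));
    pointer_ = nullptr;
    destroy_ = nullptr;
  }

  template <class T>
  bool IsType() const { return meta_ == std::type_index(typeid(T)); }
  bool empty() const { return pointer_ == nullptr; }

 private:
  // One instantiation per T, each usable as a plain void(void*) pointer.
  template <class T>
  static void Destroy(void* p) { delete static_cast<T*>(p); }

  std::type_index meta_ = std::type_index(typeid(void));
  void* pointer_ = nullptr;
  DestroyCall* destroy_ = nullptr;
};
```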
The following two member functions check the type of the object held by the Blob, and whether it is a Tensor on a given device type.
  /**
   * Checks if the content stored in the blob is of type T.
   */
  template <class T>
  bool IsType() const {
    return meta_.Match<T>();
  }
  bool IsTensorType(DeviceType device_type) const {
    bool is_match = meta_.Match<Tensor>();
    auto* tensor = static_cast<Tensor*>(pointer_);
    if (is_match && tensor && tensor->GetDeviceType() == device_type) {
      return true;
    }
    return false;
  }
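IsType<T>() boils down to comparing type identifiers, and IsTensorType() additionally inspects the stored object once the type matches. The sketch below illustrates both steps with hypothetical names (`TypeTag`, `IdOf`, `TypedPtr`, `ToyTensor`, `Device`). It swaps in a different technique from Caffe2's TypeMeta machinery: each type's id is simply the address of a per-type static variable, which is unique per type and needs no RTTI.

```cpp
// Hypothetical per-type id: the address of a static member that exists once
// per instantiated type. Not Caffe2's TypeMeta; it only illustrates the idea.
template <class T>
struct TypeTag {
  static char tag;  // non-const so the linker cannot fold two tags together
};
template <class T>
char TypeTag<T>::tag = 0;

using TypeId = const void*;

template <class T>
TypeId IdOf() { return &TypeTag<T>::tag; }

// Toy tensor with a device field, standing in for caffe2::Tensor.
enum class Device { CPU, CUDA };
struct ToyTensor { Device device; };

// Mirrors Blob::IsType<T>() and Blob::IsTensorType(device_type):
// match the type id first, and only then look inside the object.
struct TypedPtr {
  TypeId id;
  void* ptr;

  template <class T>
  bool IsType() const { return id == IdOf<T>(); }

  bool IsTensorType(Device d) const {
    auto* t = static_cast<ToyTensor*>(ptr);  // only dereferenced if matched
    return IsType<ToyTensor>() && t != nullptr && t->device == d;
  }
};
```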
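Below are two ways to access the stored object: Get<T>() returns a typed const reference after checking the type, while GetRaw() returns the raw, untyped pointer without any check.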
  /**
   * @brief Gets the const reference of the stored object. The code checks if
   * the stored object is of the desired type.
   */
  // TODO(jerryzh): add a Get(DeviceType) function?
  template <class T>
  const T& Get() const {
    CAFFE_ENFORCE(
        IsType<T>(),
        "wrong type for the Blob instance. Blob contains ",
        meta_.name(),
        " while caller expects ",
        TypeMeta::TypeName<T>());
    // TODO: after we add Get<Tensor>(DeviceType)
    // and changed all the callsites, we can add
    // a static assert here to enforce T != Tensor
    return *static_cast<const T*>(pointer_);
  }
  const void* GetRaw() const {
    return pointer_;
  }
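To modify the stored object, call GetMutable<T>() (or, for tensors, GetMutableTensor(DeviceType)), as shown below. If the Blob does not currently hold a T (or a Tensor on the requested device), the old object is freed and replaced with a newly created one.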
  /**
   * @brief Gets a mutable pointer to the stored object.
   *
   * If the current object is not of the right type, a new object is created
   * and the old object is freed. Note that type T should have a default
   * constructor. Otherwise, create the object yourself first, and use
   * Reset().
   */
  template <class T>
  T* GetMutable() {
    static_assert(
        std::is_default_constructible<T>::value,
        "GetMutable can't be called with non-default-constructible types. "
        "Try using specialized methods");
    static_assert(
        !std::is_same<T, Tensor>::value,
        "Use GetMutableTensor(DeviceType) instead");
    if (IsType<T>()) {
      return static_cast<T*>(pointer_);
    } else {
      VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<T>();
      return Reset<T>(new T());
    }
  }
  inline Tensor* GetMutableTensor(DeviceType device_type) {
    if (IsTensorType(device_type)) {
      return static_cast<Tensor*>(pointer_);
    } else {
      VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
              << " DeviceType:" << device_type;
      return Reset<Tensor>(new Tensor(device_type));
    }
  }
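The contrast between read access and write access can be sketched with a small hypothetical holder, `AccessBlob`, where std::type_index stands in for TypeMeta and std::logic_error stands in for the error raised by CAFFE_ENFORCE. Get<T>() refuses a type mismatch, while GetMutable<T>() silently replaces a mismatched object with a default-constructed T, which is why GetMutable requires T to be default-constructible.

```cpp
#include <stdexcept>
#include <string>
#include <typeindex>
#include <typeinfo>

// Hypothetical minimal holder illustrating Get<T>() vs GetMutable<T>().
class AccessBlob {
 public:
  AccessBlob() = default;
  AccessBlob(const AccessBlob&) = delete;
  AccessBlob& operator=(const AccessBlob&) = delete;
  ~AccessBlob() { Free(); }

  template <class T>
  bool IsType() const {
    return pointer_ && type_ == std::type_index(typeid(T));
  }

  // Read access: refuse to reinterpret the object as the wrong type.
  template <class T>
  const T& Get() const {
    if (!IsType<T>()) {
      throw std::logic_error("wrong type for the Blob instance");
    }
    return *static_cast<const T*>(pointer_);
  }

  // Write access: reuse the object if the type matches; otherwise free it
  // and replace it with a default-constructed T.
  template <class T>
  T* GetMutable() {
    if (IsType<T>()) return static_cast<T*>(pointer_);
    Free();
    T* fresh = new T();
    pointer_ = fresh;
    type_ = std::type_index(typeid(T));
    destroy_ = [](void* p) { delete static_cast<T*>(p); };
    return fresh;
  }

 private:
  void Free() {
    if (pointer_ && destroy_) destroy_(pointer_);
    pointer_ = nullptr;
    destroy_ = nullptr;
  }

  std::type_index type_ = std::type_index(typeid(void));
  void* pointer_ = nullptr;
  void (*destroy_)(void*) = nullptr;
};
```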
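The Reset member function makes the Blob take ownership of the object passed in. Before doing so, it always frees the object it previously owned, if any. The overload taking a void*, a TypeMeta, and a DestroyCall lets the caller supply the type and the deleter explicitly, and the no-argument Reset() simply empties the Blob.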
  /**
   * Sets the underlying object to the allocated one. The Blob then takes over
   * the ownership of the passed in pointer. If there is already an object in
   * the Blob, the old object is freed.
   *
   * This is used when the underlying class T does not have a default ctor, or
   * complex initializations need to be done outside the blob.
   */
  template <class T>
  T* Reset(T* allocated) {
    if (pointer_ && destroy_) {
      destroy_(pointer_);
    }
    meta_ = TypeMeta::Make<T>();
    pointer_ = static_cast<void*>(allocated);
    destroy_ = &Destroy<T>;
    return allocated;
  }
  inline void*
  Reset(void* allocated, const TypeMeta& meta, DestroyCall* destroy) {
    if (pointer_ && destroy_) {
      destroy_(pointer_);
    }
    meta_ = meta;
    pointer_ = static_cast<void*>(allocated);
    destroy_ = destroy;
    return allocated;
  }
  /**
   * Resets the Blob to an empty one.
   */
  inline void Reset() {
    if (pointer_ && destroy_) {
      destroy_(pointer_);
    }
    pointer_ = nullptr;
    meta_ = TypeMeta();
    destroy_ = nullptr;
  }
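ShareExternal is the opposite of Reset: the Blob merely uses the object passed in and takes no responsibility for freeing it (destroy_ is set to nullptr). Any object the Blob previously owned is still freed first.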
  /**
   * Sets the underlying object to the allocated one, but does not take over
   * the ownership of the passed in pointer. If there is already an object in
   * the Blob, the old object is freed.
   *
   * Unlike Reset, this does not take over the ownership of the pointer and the
   * caller is responsible for making sure that the lifetime of the allocated
   * blob outlasts the lifetime of any access to this blob, until another Reset
   * call is made or the blob is destructed.
   */
  template <class T>
  typename std::remove_const<T>::type* ShareExternal(
      typename std::remove_const<T>::type* allocated) {
    // Cast to the non-const type so the result matches the return type
    // even when T is const-qualified.
    return static_cast<typename std::remove_const<T>::type*>(ShareExternal(
        static_cast<void*>(allocated),
        TypeMeta::Make<typename std::remove_const<T>::type>()));
  }
  void* ShareExternal(void* allocated, const TypeMeta& meta) {
    if (pointer_ && destroy_) {
      destroy_(pointer_);
    }
    meta_ = meta;
    pointer_ = static_cast<void*>(allocated);
    destroy_ = nullptr;
    return allocated;
  }
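The only difference between Reset and ShareExternal is whether a deleter is recorded. The hypothetical `OwnBlob` below makes that observable with a `Tracked` type that counts destructor calls; it is a simplified sketch, not Caffe2's implementation.

```cpp
#include <typeindex>
#include <typeinfo>

// Hypothetical holder contrasting owning Reset() with non-owning
// ShareExternal(): the only difference is whether a deleter is stored.
class OwnBlob {
 public:
  using DestroyCall = void(void*);
  OwnBlob() = default;
  OwnBlob(const OwnBlob&) = delete;
  OwnBlob& operator=(const OwnBlob&) = delete;
  ~OwnBlob() { Clear(); }

  // Owning: remember how to delete the object as a T.
  template <class T>
  T* Reset(T* allocated) {
    Clear();
    Set(allocated, std::type_index(typeid(T)), &Destroy<T>);
    return allocated;
  }

  // Non-owning: no deleter, so the caller must keep the object alive.
  template <class T>
  T* ShareExternal(T* external) {
    Clear();
    Set(external, std::type_index(typeid(T)), nullptr);
    return external;
  }

  // Frees the current object only if this blob owns it.
  void Clear() {
    if (pointer_ && destroy_) destroy_(pointer_);
    pointer_ = nullptr;
    destroy_ = nullptr;
  }

 private:
  template <class T>
  static void Destroy(void* p) { delete static_cast<T*>(p); }

  void Set(void* p, std::type_index t, DestroyCall* d) {
    pointer_ = p;
    type_ = t;
    destroy_ = d;
  }

  std::type_index type_ = std::type_index(typeid(void));
  void* pointer_ = nullptr;
  DestroyCall* destroy_ = nullptr;
};

// Counts destructor calls so the ownership behavior can be observed.
struct Tracked {
  static int destroyed;
  ~Tracked() { ++destroyed; }
};
int Tracked::destroyed = 0;
```

Because the shared object has no deleter, the holder never frees it; as the Caffe2 comment warns, the caller stays responsible for its lifetime.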
Blob also takes on part of the serialization work. This is precisely why all weights must be placed in Blobs: saving and loading checkpoints relies on this capability.
  /**
   * Serializes the current blob, if possible. Note that this serialization uses
   * the registration mechanism and one has to implement specific serialization
   * approaches for specific classes. Acceptor should take care of writing data
   * to the actual storage.
   */
  void Serialize(
      const string& name,
      BlobSerializerBase::SerializationAcceptor acceptor,
      int chunk_size = kDefaultChunkSize) const;
  /**
   * @brief Convenience function to serialize a blob to a string.
   *
   * This is a convenience function to serialize small Blobs that produce
   * manageable serialized strings. To serialize big blobs such as
   * large sparse tensors, use the fully-functional interface in
   * blob_serializer_base.h.
   *
   * NOTE: this function doesn't do chunking and might break with big tensors.
   */
  string Serialize(const string& name) const;
  /**
   * Deserializes from a string containing either BlobProto or TensorProto. If
   * the deserialization fails, the content in the blob should no longer be
   * trusted.
   */
  void Deserialize(const string& content);
  void Deserialize(const BlobProto& proto);
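Finally come Blob's private members. Destroy is a private static template member function, and it fits this role perfectly: each instantiation Destroy<T> has the plain signature void(void*), so its address can be stored in destroy_, yet it deletes the object through the correct type T.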
 private:
  /**
   * @brief A destroy call that is used to properly deconstruct objects.
   */
  template <class T>
  static void Destroy(void* pointer) {
    delete static_cast<T*>(pointer);
  }
  TypeMeta meta_;
  void* pointer_ = nullptr;
  DestroyCall* destroy_ = nullptr;
  AT_DISABLE_COPY_AND_ASSIGN(Blob);
};
BlobSerializerBase and BlobDeserializerBase
Below is an overview of BlobSerializerBase, the abstract base class for implementing Blob serialization. Each type that a Blob may hold needs its own subclass that implements the serialization for that type.
/**
 * @brief BlobSerializerBase is an abstract class that serializes a blob to a
 * string.
 *
 * This class exists purely for the purpose of registering type-specific
 * serialization code. If you need to serialize a specific type, you should
 * write your own Serializer class, and then register it using
 * REGISTER_BLOB_SERIALIZER. For a detailed example, see TensorSerializer for
 * details.
 */
class BlobSerializerBase {
 public:
  virtual ~BlobSerializerBase() {}
  using SerializationAcceptor =
      std::function<void(const std::string& blobName, const std::string& data)>;
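Its two main serialization functions follow. Serialize is pure virtual, while the default SerializeWithChunkSize ignores the chunk size and simply calls Serialize.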
  /**
   * @brief The virtual function that returns a serialized string for the input
   * blob.
   * @param blob
   *     the input blob to be serialized.
   * @param name
   *     the blob name to be used in the serialization implementation. It is up
   *     to the implementation whether this name field is going to be used or
   *     not.
   * @param acceptor
   *     a lambda which accepts key value pairs to save them to storage.
   *     serializer can use it to save blob in several chunks
   *     acceptor should be thread-safe
   */
  virtual void Serialize(
      const Blob& blob,
      const std::string& name,
      SerializationAcceptor acceptor) = 0;
  virtual void SerializeWithChunkSize(
      const Blob& blob,
      const std::string& name,
      SerializationAcceptor acceptor,
      int /*chunk_size*/) {
    // Base implementation.
    Serialize(blob, name, acceptor);
  }
};
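A type-specific BlobSerializer subclass must be provided for each type that Blobs can hold. These subclasses are registered in BlobSerializerRegistry, keyed by TypeIdentifier, and CreateSerializer looks them up by type id.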
// The Blob serialization registry and serializer creator functions.
CAFFE_DECLARE_TYPED_REGISTRY(
    BlobSerializerRegistry,
    TypeIdentifier,
    BlobSerializerBase,
    std::unique_ptr);
#define REGISTER_BLOB_SERIALIZER(id, ...) \
  CAFFE_REGISTER_TYPED_CLASS(BlobSerializerRegistry, id, __VA_ARGS__)
// Creates the serializer registered for the given type id.
inline unique_ptr<BlobSerializerBase> CreateSerializer(TypeIdentifier id) {
  return BlobSerializerRegistry()->Create(id);
}
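The registry lookup that Blob::Serialize relies on can be mimicked with a map from type to factory. In this sketch, `SerializerBase`, `SerializerRegistry`, `MakeSerializer`, and `IntSerializer` are hypothetical names, std::type_index replaces TypeIdentifier, and a static initializer plays the role of REGISTER_BLOB_SERIALIZER. An unregistered type yields nullptr, which is the case Blob::Serialize rejects with CAFFE_ENFORCE.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>

// Hypothetical analogue of BlobSerializerBase.
struct SerializerBase {
  virtual ~SerializerBase() = default;
  virtual std::string Serialize(const void* object) const = 0;
};

using Factory = std::function<std::unique_ptr<SerializerBase>()>;

// Hypothetical analogue of BlobSerializerRegistry, built on first use.
std::map<std::type_index, Factory>& SerializerRegistry() {
  static std::map<std::type_index, Factory> registry;
  return registry;
}

// Returns nullptr when no serializer is registered for the type.
std::unique_ptr<SerializerBase> MakeSerializer(std::type_index id) {
  auto& registry = SerializerRegistry();
  auto it = registry.find(id);
  if (it == registry.end()) return nullptr;
  return it->second();
}

// One concrete serializer per stored type.
struct IntSerializer : SerializerBase {
  std::string Serialize(const void* object) const override {
    return "int:" + std::to_string(*static_cast<const int*>(object));
  }
};

// Registration during static initialization, similar in spirit to
// REGISTER_BLOB_SERIALIZER.
const bool kIntRegistered = [] {
  SerializerRegistry().emplace(std::type_index(typeid(int)), [] {
    return std::unique_ptr<SerializerBase>(new IntSerializer());
  });
  return true;
}();
```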
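Correspondingly, the abstract base class BlobDeserializerBase provides the interface needed for deserialization. Its registry is keyed by a type-name string rather than by a TypeIdentifier.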
/**
 * @brief BlobDeserializerBase is an abstract class that deserializes a blob
 * from a BlobProto or a TensorProto.
 */
class CAFFE2_API BlobDeserializerBase {
 public:
  virtual ~BlobDeserializerBase() {}
  // Deserializes from a BlobProto object.
  virtual void Deserialize(const BlobProto& proto, Blob* blob) = 0;
};
CAFFE_DECLARE_REGISTRY(BlobDeserializerRegistry, BlobDeserializerBase);
#define REGISTER_BLOB_DESERIALIZER(name, ...) \
  CAFFE_REGISTER_CLASS(BlobDeserializerRegistry, name, __VA_ARGS__)
// Creates the deserializer registered under the given type name.
inline unique_ptr<BlobDeserializerBase> CreateDeserializer(const string& type) {
  return BlobDeserializerRegistry()->Create(type);
}
TensorSerializer and TensorDeserializer
TensorSerializer is a subclass of BlobSerializerBase; as the name suggests, it implements serialization for Tensor. Likewise, TensorDeserializer is a subclass of BlobDeserializerBase.
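Below are some of the implementation details used during serialization. The approach is to copy the typed data into a Protocol Buffers message and then rely on protobuf for the actual serialization and deserialization. CopyToProtoAsIs copies the bytes directly and therefore requires the source and destination types to have the same size; CopyToProtoWithCast first copies the data into a CPU buffer and then converts it element by element.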
namespace detail {
template <typename SrcType, typename DstType>
inline void CopyToProtoAsIs(
    const size_t size,
    const SrcType* src,
    google::protobuf::RepeatedField<DstType>* field,
    BaseContext* context) {
  static_assert(
      sizeof(SrcType) == sizeof(DstType),
      "The source type and dest type cannot be copied as-is. Did "
      "you mean CopyToProtoWithCast?");
  field->Reserve(size);
  for (size_t i = 0; i < size; ++i) {
    field->Add(0);
  }
  context->template CopyToCPU<SrcType>(
      size, src, reinterpret_cast<SrcType*>(field->mutable_data()));
  // Make sure that we finish the copy into the protobuf.
  context->FinishDeviceComputation();
}
template <typename SrcType, typename DstType>
inline void CopyToProtoWithCast(
    const size_t size,
    const SrcType* src,
    google::protobuf::RepeatedField<DstType>* field,
    BaseContext* context) {
  // TODO: we are having one unnecessary copy here if the context is already
  // CPUContext. Remove it if it is performance critical.
  unique_ptr<SrcType[]> buffer(new SrcType[size]);
  context->template CopyToCPU<SrcType>(size, src, buffer.get());
  context->FinishDeviceComputation();
  field->Reserve(size);
  for (size_t i = 0; i < size; ++i) {
    field->Add(static_cast<DstType>(buffer[i]));
  }
}
}  // namespace detail
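The distinction between the two helpers can be reproduced on the CPU alone. In this sketch, `CopyAsIs` and `CopyWithCast` are hypothetical names, std::vector stands in for protobuf's RepeatedField, and std::memcpy stands in for context->CopyToCPU (no device is involved).

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Bit-for-bit copy: only valid when both element types have the same size.
template <typename SrcType, typename DstType>
void CopyAsIs(std::size_t size, const SrcType* src, std::vector<DstType>* field) {
  static_assert(sizeof(SrcType) == sizeof(DstType),
                "cannot copy as-is; use CopyWithCast");
  field->resize(size);  // make room first, like the loop of Add(0) calls
  if (size > 0) {
    std::memcpy(field->data(), src, size * sizeof(SrcType));
  }
}

// Value conversion: each element goes through static_cast.
template <typename SrcType, typename DstType>
void CopyWithCast(std::size_t size, const SrcType* src,
                  std::vector<DstType>* field) {
  field->reserve(field->size() + size);
  for (std::size_t i = 0; i < size; ++i) {
    field->push_back(static_cast<DstType>(src[i]));
  }
}
```

For example, CopyAsIs from int32_t to uint32_t keeps the bit pattern (so -1 becomes 0xFFFFFFFF), whereas widening uint8_t to int32_t needs CopyWithCast: the sizes differ, so the static_assert would reject CopyAsIs at compile time.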
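Below are the implementations of Blob's two Serialize functions. Both delegate the real work to the BlobSerializer registered for the stored type; the convenience overload passes kNoChunking and collects the single resulting string through a lambda acceptor. Deserialize is implemented in a similar way and is not repeated here.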
// The blob serialization member function implementation.
void Blob::Serialize(
    const string& name,
    BlobSerializerBase::SerializationAcceptor acceptor,
    int chunk_size) const {
  std::unique_ptr<BlobSerializerBase> serializer(CreateSerializer(meta_.id()));
  CAFFE_ENFORCE(serializer, "No known serializer for ", meta_.name());
  serializer->SerializeWithChunkSize(*this, name, acceptor, chunk_size);
}
// The blob serialization member function implementation.
std::string Blob::Serialize(const string& name) const {
  std::string data;
  BlobSerializerBase::SerializationAcceptor acceptor =
      [&data](const std::string&, const std::string& blob) {
        DCHECK(data.empty()); // should be called once with kNoChunking
        data = blob;
      };
  this->Serialize(name, acceptor, kNoChunking);
  return data;
}
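The acceptor is what makes chunking possible: a serializer may call it several times, once per chunk, while the convenience overload expects exactly one call. The sketch below is hypothetical throughout: `SerializeInts`, `SerializeIntsToString`, the `name#index` key scheme, the comma-separated payload, and the `kNoChunkingSketch` sentinel are illustrative choices, not Caffe2's format or constants.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Same shape as BlobSerializerBase::SerializationAcceptor.
using Acceptor =
    std::function<void(const std::string& key, const std::string& data)>;

// Hypothetical sentinel meaning "emit everything in one call".
constexpr int kNoChunkingSketch = -1;

// Emits `values` in chunks of `chunk_size` elements (one chunk if
// chunk_size <= 0), handing each chunk to the acceptor under its own key.
void SerializeInts(const std::vector<int>& values, const std::string& name,
                   const Acceptor& acceptor, int chunk_size) {
  const std::size_t n = values.size();
  const std::size_t step =
      chunk_size > 0 ? static_cast<std::size_t>(chunk_size) : n;
  std::size_t begin = 0;
  int chunk_id = 0;
  do {
    const std::size_t end = std::min(n, begin + step);
    std::string data;
    for (std::size_t i = begin; i < end; ++i) {
      if (i > begin) data += ',';
      data += std::to_string(values[i]);
    }
    acceptor(name + "#" + std::to_string(chunk_id++), data);
    begin = end;
  } while (begin < n);
}

// Mirrors the convenience overload: collect exactly one chunk into a string.
std::string SerializeIntsToString(const std::vector<int>& values,
                                  const std::string& name) {
  std::string result;
  int calls = 0;
  Acceptor acceptor = [&](const std::string&, const std::string& data) {
    ++calls;
    result = data;
  };
  SerializeInts(values, name, acceptor, kNoChunkingSketch);
  // Without chunking the acceptor must have been called exactly once.
  return calls == 1 ? result : std::string();
}
```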
BlobStatGetter
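Blob also comes with helper classes that provide statistics and similar functionality, such as BlobStatGetter. As the code below shows, it too relies on the type id to select a type-specific StatGetter subclass, which is registered in BlobStatRegistry.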
struct BlobStatGetter {
  virtual size_t sizeBytes(const Blob& blob) const = 0;
  virtual ~BlobStatGetter() {}
};
struct BlobStatRegistry {
 private:
  std::unordered_map<TypeIdentifier, std::unique_ptr<BlobStatGetter>> map_;
  void doRegister(TypeIdentifier id, std::unique_ptr<BlobStatGetter>&& v);
 public:
  template <typename T, typename Getter>
  struct Registrar {
    Registrar() {
      BlobStatRegistry::instance().doRegister(
          TypeMeta::Id<T>(), std::unique_ptr<Getter>(new Getter));
    }
  };
  const BlobStatGetter* get(TypeIdentifier id);
  static BlobStatRegistry& instance();
};
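The implementation follows. instance() uses a function-local static, so the registry is constructed on first use even when a Registrar runs during static initialization; BlobStat::sizeBytes returns 0 for types that have no registered getter.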
const BlobStatGetter* BlobStatRegistry::get(TypeIdentifier id) {
  auto it = map_.find(id);
  if (it == map_.end()) {
    return nullptr;
  }
  return it->second.get();
}
BlobStatRegistry& BlobStatRegistry::instance() {
  static BlobStatRegistry registry;
  return registry;
}
void BlobStatRegistry::doRegister(
    TypeIdentifier id,
    std::unique_ptr<BlobStatGetter>&& v) {
  // don't use CAFFE_ENFORCE_EQ to avoid static initialization order fiasco.
  if (map_.count(id) > 0) {
    throw std::runtime_error("BlobStatRegistry: Type already registered.");
  }
  map_[id] = std::move(v);
}
namespace BlobStat {
size_t sizeBytes(const Blob& blob) {
  auto* p = BlobStatRegistry::instance().get(blob.meta().id());
  return p ? p->sizeBytes(blob) : 0;
}
}  // namespace BlobStat
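The registry pattern above can be reduced to a few standard-library pieces. In this sketch, `TinyBlob`, `StatGetter`, `StatRegistry`, `FloatVecStat`, and `SizeBytes` are hypothetical simplifications of Blob, BlobStatGetter, BlobStatRegistry, and BlobStat::sizeBytes, with std::type_index in place of TypeIdentifier. A static Registrar object registers a getter before main runs, and unknown types report 0 bytes.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

// A "blob" here is just a type tag plus an untyped pointer.
struct TinyBlob {
  std::type_index type;
  const void* ptr;
};

struct StatGetter {
  virtual ~StatGetter() = default;
  virtual std::size_t sizeBytes(const TinyBlob& blob) const = 0;
};

class StatRegistry {
 public:
  // Function-local static: constructed on first use, so Registrar objects
  // created during static initialization never see an unconstructed map.
  static StatRegistry& instance() {
    static StatRegistry registry;
    return registry;
  }
  const StatGetter* get(std::type_index id) const {
    auto it = map_.find(id);
    return it == map_.end() ? nullptr : it->second.get();
  }
  template <typename T, typename Getter>
  struct Registrar {
    Registrar() {
      instance().doRegister(std::type_index(typeid(T)),
                            std::unique_ptr<StatGetter>(new Getter));
    }
  };

 private:
  void doRegister(std::type_index id, std::unique_ptr<StatGetter> v) {
    if (map_.count(id) > 0) throw std::runtime_error("type already registered");
    map_[id] = std::move(v);
  }
  std::unordered_map<std::type_index, std::unique_ptr<StatGetter>> map_;
};

// One getter per type: bytes held by a std::vector<float>.
struct FloatVecStat : StatGetter {
  std::size_t sizeBytes(const TinyBlob& blob) const override {
    auto* v = static_cast<const std::vector<float>*>(blob.ptr);
    return v->size() * sizeof(float);
  }
};
static StatRegistry::Registrar<std::vector<float>, FloatVecStat> g_float_vec_stat;

// Like BlobStat::sizeBytes: 0 for types with no registered getter.
std::size_t SizeBytes(const TinyBlob& blob) {
  const StatGetter* g = StatRegistry::instance().get(blob.type);
  return g ? g->sizeBytes(blob) : 0;
}
```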