使用boost::asio::coroutine重構網絡程序

本文中代碼所做的改進如下:

  1. chat_message的消息字段在堆上分配,使用std::shared_ptr<char> 進行管理。這樣可以支持比較長的消息體。本文支持的消息體最大長度為655360字節,我感覺這個夠了。
    const int MAX_BODY_LEN = 655360;

  2. 使用boost::asio::spawn方法(基於Boost.Coroutine的有棧協程),將異步的網絡操作轉換成同步寫法,簡化代碼邏輯。

  3. 抽取master和slave中的公共代碼部分,提取一個虛基類 chat_client。
    子類重寫純虛函數receive_msg,實現自己的邏輯。

  4. json構造部分重構,使用對象構造法,更易于理解。

代碼如下,
CMakeLists.txt

cmake_minimum_required(VERSION 3.5)
project(perf_tool)

# Compile every target as C++14 through CMake's language-standard support
# (instead of a raw -std flag smuggled in via add_definitions) and keep
# debug info.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
add_compile_options(-g)

# Hint for the author's Homebrew Boost install. BOOST_ROOT is the variable
# FindBoost reads (a BOOST_DIR variable is not consulted by find_package).
# A BOOST_ROOT supplied by the user wins, and a missing path is ignored.
if(NOT DEFINED BOOST_ROOT AND EXISTS /usr/local/Cellar/boost/1.76.0)
    set(BOOST_ROOT /usr/local/Cellar/boost/1.76.0)
endif()

# boost::asio::spawn, used by every executable, needs Boost.Coroutine.
find_package(Boost REQUIRED COMPONENTS
    system
    filesystem
    serialization
    program_options
    coroutine
    )

# Portable thread library lookup instead of linking "pthread" by name.
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)

include_directories(${Boost_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/../../include)

# One executable per *.cpp in this directory (e.g. master, slave, server),
# each linked together with the shared parse_msg.cpp.
file(GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
foreach(sourcefile ${APP_SOURCES})
    get_filename_component(filename ${sourcefile} NAME)
    # Exact comparison: MATCHES would treat "parse_msg.cpp" as a regex
    # ("." matches any char, and the pattern is unanchored).
    if(NOT "${filename}" STREQUAL "parse_msg.cpp")
        string(REGEX REPLACE "\\.cpp$" "" exe_name "${filename}")
        add_executable(${exe_name} ${sourcefile}
                       ${CMAKE_CURRENT_SOURCE_DIR}/parse_msg.cpp)
        target_link_libraries(${exe_name} ${Boost_LIBRARIES} Threads::Threads)
    endif()
endforeach()

chat_client.h

#ifndef _FREDRIC_CHAT_CLIENT_H_
#define _FREDRIC_CHAT_CLIENT_H_

#include "chat_message.h"

#include <boost/asio.hpp>
#include <boost/asio/spawn.hpp>
#include <cstdlib>
#include <deque>
#include <iostream>
#include <thread>


using boost::asio::ip::tcp;

using chat_message_queue = std::deque<chat_message>;

class chat_client {
   public:
    chat_client(boost::asio::io_service& io_service,
                tcp::resolver::iterator endpoint_iterator)
        : io_service_(io_service), socket_(io_service) {
        do_connect(endpoint_iterator);
    }

    void write(const chat_message& msg) {
        // write是由主線程往子線程寫東西
        // 所以需要使用post提交到子線程運(yùn)行
        // 使得所有io操作都由io_service的子線程掌握
        io_service_.post([this, msg]() {
            bool write_in_progress = !write_msgs_.empty();
            write_msgs_.push_back(msg);
            if (!write_in_progress) {
                do_write();
            }
        });
    }

    void close() {
        io_service_.post([this]() { socket_.close(); });
    }

   protected:
    void do_connect(tcp::resolver::iterator endpoint_iterator) {
        boost::asio::spawn(
            socket_.get_executor(),
            [this, endpoint_iterator](boost::asio::yield_context yield) {
                boost::system::error_code conn_ec;
                boost::asio::async_connect(socket_, endpoint_iterator,
                                           yield[conn_ec]);
                if (!conn_ec) {
                    do_read_header_and_body();
                }
            });
    }

    void do_read_header_and_body() {
        boost::asio::spawn(
            socket_.get_executor(), [this](boost::asio::yield_context yield) {
                while (true) {
                    boost::system::error_code header_ec;
                    boost::asio::async_read(
                        socket_,
                        boost::asio::buffer(read_msg_.data(),
                                            chat_message::header_length),
                        yield[header_ec]);

                    if (!header_ec && read_msg_.decode_header()) {
                        boost::system::error_code body_ec;
                        // 如果沒有錯(cuò)誤,并且Decode_header成功,成功讀取到body_length
                        boost::asio::async_read(
                            socket_,
                            boost::asio::buffer(read_msg_.body(),
                                                read_msg_.body_length()),
                            yield[body_ec]);
                        bool stop = receive_msg(body_ec);
                        if(stop) {
                            break;
                        }
                    } else {
                        // 讀取失敗時(shí)關(guān)閉與服務(wù)端的連接,退出事件循環(huán)
                        socket_.close();
                    }
                }
            });
    }

    json to_json() {
        std::string buffer(read_msg_.body(),
                           read_msg_.body() + read_msg_.body_length());
        std::stringstream ss(buffer);
        auto json_obj = json::parse(ss.str());
        return json_obj;
    }

    virtual bool receive_msg(const boost::system::error_code& ec) = 0;

    // 向服務(wù)端真正發(fā)送消息的函數(shù)
    void do_write() {
        boost::asio::spawn(
            socket_.get_executor(), [this](boost::asio::yield_context yield) {
                boost::system::error_code ec;

                boost::asio::async_write(
                    socket_,
                    boost::asio::buffer(write_msgs_.front().data(),
                                        write_msgs_.front().length()),
                    yield[ec]);

                if (!ec) {
                    // 一直寫直到寫完
                    write_msgs_.pop_front();
                    if (!write_msgs_.empty()) {
                        do_write();
                    }
                } else {
                    socket_.close();
                }
            });
    }

    // 注意使用了引用類型,
    // io_service對象的生命周期必須要大于chat_client對象的生命周期
    // 否則會出現(xiàn)引用失效,導(dǎo)致異常
    boost::asio::io_service& io_service_;
    tcp::socket socket_;
    chat_message read_msg_;
    chat_message_queue write_msgs_;
};

/// Parse a text command such as "BindName master" with parseMessage() and,
/// on success, queue the encoded message on `c`.
///
/// Declared `inline` because it is defined in a header: without it, including
/// this header from two translation units of one program is an ODR violation
/// (duplicate symbol at link time).
///
/// @return true if the command parsed and the message was queued; false
///         (after printing an error) otherwise.
inline bool parse_and_send_a_message(chat_client& c,
                                     const std::string& input_msg_str) {
    int type = 0;
    std::string output;
    if (!parseMessage(input_msg_str, &type, output)) {
        std::cerr << "Parse message error!" << std::endl;
        return false;
    }

    chat_message msg;
    msg.setMessage(type, output.data(), output.size());
    c.write(msg);
    return true;
}

#endif

chat_message.h

#ifndef _CHAT_MESSAGE_H_
#define _CHAT_MESSAGE_H_

#include "parse_msg.h"
#include "const.h"

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>

/// One framed protocol message: a fixed-size Header followed by up to
/// max_body_length body bytes, stored together in a single heap buffer
/// (data() points at the header, body() just past it).
///
/// Copies are shallow. The buffer is held by std::shared_ptr, so a copied
/// chat_message shares its bytes with the original, while m_header is copied
/// by value. Treat a message as read-only once it has been copied or queued.
class chat_message {
    public:
        // Size of the wire header, computed with sizeof(Header) (two ints).
        enum { header_length = sizeof(Header) };
        // Largest body accepted by setMessage() and decode_header().
        enum { max_body_length = MAX_BODY_LEN};

        // Allocate room for a header plus the largest allowed body. The
        // array deleter is needed because the storage comes from new[].
        chat_message()
            : data_(new char[header_length + max_body_length],
                    std::default_delete<char[]>()) {}

        // Raw frame bytes (header then body); the const overload is read-only.
        const char* data() const { return data_.get(); }
        char* data() { return data_.get(); }

        // Total frame size: header plus the body size recorded in m_header.
        std::size_t length() const { return header_length + m_header.bodySize; }

        // The body starts header_length bytes into the buffer.
        const char* body() const { return data_.get() + header_length; }
        char* body() { return data_.get() + header_length; }

        int type() const { return m_header.type; }

        std::size_t body_length() const { return m_header.bodySize; }

        /// Copy `bufferSize` bytes from `buffer` into the body and encode the
        /// header in front of them.
        /// @throws std::length_error if bufferSize > max_body_length.
        void setMessage(int messageType, const void* buffer, size_t bufferSize) {
            // Checked in every build: an assert() disappears under NDEBUG and
            // the memcpy below would then overflow data_. A body of exactly
            // max_body_length fits the buffer and is accepted by decode_header().
            if (bufferSize > static_cast<std::size_t>(max_body_length)) {
                throw std::length_error(
                    "chat_message::setMessage: body exceeds max_body_length");
            }
            m_header.bodySize = static_cast<int>(bufferSize);
            m_header.type = messageType;
            if (bufferSize != 0) {
                std::memcpy(body(), buffer, bufferSize);
            }
            std::memcpy(data(), &m_header, sizeof(m_header));
            // Log exactly bufferSize bytes: the body is not NUL-terminated, so
            // streaming it as a char* would read past the copied data.
            std::cerr << "set message body=";
            std::cerr.write(body(), static_cast<std::streamsize>(bufferSize));
            std::cerr << std::endl;
        }

        void setMessage(int messageType, const std::string& buffer) {
            setMessage(messageType, buffer.data(), buffer.size());
        }

        /// Decode the header bytes at data() into m_header.
        /// @return false if the announced body size is out of range; the
        ///         caller must then not read a body.
        bool decode_header() {
            std::memcpy(&m_header, data(), header_length);
            // bodySize arrives from the network as a signed int. A negative
            // value must be rejected too: body_length() would convert it to a
            // huge size_t and the following async_read would overrun data_.
            if (m_header.bodySize < 0 || m_header.bodySize > max_body_length) {
                std::cout <<"body size: " << m_header.bodySize << " header type:" << m_header.type  << std::endl;
                return false;
            }

            return true;
        }

    private:
        std::shared_ptr<char> data_;
        // Value-initialized so length()/type() are well-defined before the
        // first setMessage()/decode_header().
        Header m_header{};
};

#endif

const.h

#ifndef _FREDRIC_CONST_H_
#define _FREDRIC_CONST_H_

// Upper bound, in bytes, on a message body; chat_message sizes its buffer
// and validates incoming headers against this value.
// A namespace-scope constexpr variable has internal linkage, so each
// translation unit that includes this header gets its own copy.
constexpr int MAX_BODY_LEN = 655360;

#endif

master.cpp

#include "chat_client.h"

// Number of slave reports the master waits for before finishing; set in main().
int slave_count = 0;

/// Master client: prints every MT_SEND_TASK_INFO_MSG report forwarded by the
/// server and closes the connection once slave_count reports have arrived.
class master : public chat_client {
   public:
    master(boost::asio::io_service& io_service,
           tcp::resolver::iterator endpoint_iterator)
        : chat_client(io_service, endpoint_iterator) {}

   private:
    /// @return true to stop the read loop, false to keep waiting.
    bool receive_msg(const boost::system::error_code& ec) override {
        // Body read failed: drop the connection and stop reading.
        if (ec) {
            boost::system::error_code ignored;
            socket_.close(ignored);
            return true;
        }

        // Only task reports are of interest; other messages (e.g. the
        // server's echo of our own LaunchTask) are ignored.
        if (read_msg_.type() != MT_SEND_TASK_INFO_MSG) {
            return false;
        }

        auto json_obj = to_json();
        std::cout << "slave ";
        std::cout << json_obj["name"].get<std::string>();
        std::cout << " says: ";
        std::cout << json_obj["information"].get<std::string>();
        std::cout << "\n";

        // Keep waiting until every slave has reported. Using '<' (not '!=')
        // also terminates when slave_count is zero or negative.
        ++received_slave_count_;
        if (received_slave_count_ < slave_count) {
            return false;
        }

        // TODO: aggregate the collected results.
        std::cerr << "開始匯總計(jì)算性能測試結(jié)果" << std::endl;
        close();
        return true;
    }

    int received_slave_count_{0};  // reports received so far
};

/// Master entry point: connect to <host> <port>, bind the name "master",
/// launch task1, and wait until all slaves have reported.
/// @return 0 on success, 1 on bad usage or an exception.
int main(int argc, char* argv[]) {
    try {
        if (argc != 3) {
            std::cerr << "Usage: chat_client <host> <port>" << std::endl;
            return 1;
        }

        // TODO: read SLAVE_COUNT from a config file or the command line.
        slave_count = 2;

        boost::asio::io_service io_service;
        tcp::resolver resolver(io_service);
        auto endpoint_iterator = resolver.resolve({argv[1], argv[2]});
        auto c = std::make_shared<master>(io_service, endpoint_iterator);

        // All socket I/O runs on this thread. Catch here: an exception that
        // escapes a std::thread's function calls std::terminate.
        std::thread t([&io_service]() {
            try {
                io_service.run();
            } catch (std::exception& ex) {
                std::cerr << "Exception: " << ex.what() << std::endl;
            }
        });

        // write() only posts to the I/O thread, so queuing from here is safe.
        const std::string commands[] = {"BindName master", "LaunchTask task1"};
        for (const auto& command : commands) {
            parse_and_send_a_message(*c, command);
        }

        t.join();
    } catch (std::exception& ex) {
        std::cerr << "Exception: " << ex.what() << std::endl;
        return 1;
    }

    return 0;
}

slave.cpp

#include "chat_client.h"

/// Slave client: waits for the master's MT_LAUNCH_TASK_MSG, runs the (for now
/// simulated) performance test, and reports back with SendTaskInfo.
class slave : public chat_client {
   public:
    slave(boost::asio::io_service& io_service,
          tcp::resolver::iterator endpoint_iterator)
        : chat_client(io_service, endpoint_iterator) {}

   private:
    /// @return true to stop the read loop, false to keep waiting.
    bool receive_msg(const boost::system::error_code& ec) override {
        // Body read failed: drop the connection and stop reading.
        if (ec) {
            boost::system::error_code ignored;
            socket_.close(ignored);
            return true;
        }

        // Not the launch command: keep waiting.
        if (read_msg_.type() != MT_LAUNCH_TASK_MSG) {
            return false;
        }

        // TODO: run the real performance test here, then report the result.
        auto json_obj = to_json();
        std::cout << "master ";
        std::cout << json_obj["name"].get<std::string>();
        std::cout << " says: ";
        std::cout << json_obj["information"].get<std::string>();
        std::cout << "\n";

        std::cerr << "開始做性能測試..." << std::endl;
        std::cerr << "結(jié)束做性能測試..." << std::endl;

        int type = 0;
        const std::string input("SendTaskInfo TaskSuccess");
        std::string output;
        if (parseMessage(input, &type, output)) {
            chat_message msg;
            msg.setMessage(type, output.data(), output.size());
            write(msg);
        }

        // Deliberately no close() here: a posted close() can run while the
        // report's async_write is still pending and abort it. Returning true
        // ends the read loop; once the write completes, io_service.run() has
        // no work left and returns, and the socket is closed on destruction.
        return true;
    }
};

/// Slave entry point: connect to <host> <port>, read a name from stdin,
/// bind it, and then serve the master's launch command.
/// @return 0 on success, 1 on bad usage, unreadable input, or an exception.
int main(int argc, char* argv[]) {
    try {
        if (argc != 3) {
            std::cerr << "Usage: chat_client <host> <port>" << std::endl;
            return 1;
        }

        boost::asio::io_service io_service;
        tcp::resolver resolver(io_service);
        auto endpoint_iterator = resolver.resolve({argv[1], argv[2]});
        auto c = std::make_shared<slave>(io_service, endpoint_iterator);

        std::string slave_name{};
        std::cout << "Pls input name: " << std::endl;
        if (!(std::cin >> slave_name)) {
            // EOF or stream error: do not bind an empty/garbage name.
            std::cerr << "Failed to read a name from stdin" << std::endl;
            return 1;
        }
        const std::string input = "BindName " + slave_name;

        // write() only posts, so the message is queued before I/O starts.
        parse_and_send_a_message(*c, input);

        // Catch inside the thread: an exception escaping a std::thread's
        // function calls std::terminate.
        std::thread t([&io_service]() {
            try {
                io_service.run();
            } catch (std::exception& ex) {
                std::cerr << "Exception: " << ex.what() << std::endl;
            }
        });

        t.join();
    } catch (std::exception& ex) {
        std::cerr << "Exception: " << ex.what() << std::endl;
        return 1;
    }

    return 0;
}

server.cpp

#include <boost/asio.hpp>
#include <boost/asio/spawn.hpp>
#include <cassert>
#include <cstdlib>
#include <deque>
#include <iostream>
#include <list>
#include <memory>
#include <set>
#include <utility>

#include "chat_message.h"

using boost::asio::ip::tcp;

// Outgoing messages waiting to be written to one session, oldest first.
using chat_message_queue = std::deque<chat_message>;

// Sessions are shared-owned: the room and each running coroutine hold one.
class chat_session;
using chat_session_ptr = std::shared_ptr<chat_session>;

// Name bound by the client that most recently sent MT_LAUNCH_TASK_MSG.
// Task reports (MT_SEND_TASK_INFO_MSG) are forwarded only to sessions with
// this name. Starts out empty.
std::string master_name;

// The set of sessions connected to one server, with broadcast and by-name
// delivery. The room keeps each joined session alive until leave().
class chat_room {
   public:
    // Add a session to the room.
    void join(chat_session_ptr);
    // Remove a session; a no-op if it is not in the room.
    void leave(chat_session_ptr);
    // Queue the message on every session in the room.
    void deliver(const chat_message&);
    // Queue the message on every session whose bound name matches.
    void deliver_to(const chat_message&, const std::string& participant_name);

   private:
    std::set<chat_session_ptr> participants_;
};

class chat_session : public std::enable_shared_from_this<chat_session> {
   public:
    chat_session(tcp::socket socket, chat_room& room)
        : socket_(std::move(socket)), room_(room) {}

    void start() {
        room_.join(shared_from_this());

        // 使用協(xié)程 同時(shí)讀取頭和body
        read_header_and_body();
    }

    void deliver(const chat_message& msg) {
        bool write_in_progress = !write_msgs_.empty();
        write_msgs_.push_back(msg);

        // 為了保護(hù)do_write線程里面的deque,避免兩個(gè)線程同時(shí)寫
        if (!write_in_progress) {
            do_write();
        }
    }

    std::string& get_client_name() { return m_name; }

   private:
    void read_header_and_body() {
        auto self(shared_from_this());
        boost::asio::spawn(
            socket_.get_executor(),
            [this, self](boost::asio::yield_context yield) {
                while (true) {
                    boost::system::error_code ec_header;
                    boost::asio::async_read(
                        socket_,
                        boost::asio::buffer(read_msg_.data(),
                                            chat_message::header_length),
                        yield[ec_header]);
                    if (!ec_header & read_msg_.decode_header()) {
                        boost::system::error_code ec_body;
                        // 在這里read body
                        boost::asio::async_read(
                            socket_,
                            boost::asio::buffer(read_msg_.body(),
                                                read_msg_.body_length()),
                            yield[ec_body]);

                        // 如果讀取消息成功,沒有error
                        if (!ec_body) {
                            // 調(diào)用各個(gè)Session的Deliver message
                            // 將消息發(fā)給對應(yīng)的client
                            // room_.deliver(read_msg_);
                            handleMessage();
                        } else {
                            room_.leave(shared_from_this());
                        }

                    } else {
                        room_.leave(shared_from_this());
                    }
                }
            });
    }

    json to_json() {
        std::string buffer(read_msg_.body(),
                           read_msg_.body() + read_msg_.body_length());
        std::cout << "raw message server: " << buffer << std::endl;
        std::stringstream ss(buffer);
        json json_obj;
        try {
            json_obj = json::parse(ss.str());
        } catch (std::exception& ex) {
            std::cerr << "解析 json對象 失敗!!" << std::endl;
            std::cerr << ex.what() << std::endl;
        }
        return json_obj;
    }

    // 處理接收到的客戶端的消息的函數(shù)
    void handleMessage() {
        // master 和 slave都會發(fā)這個(gè),注冊自己的名字
        if (read_msg_.type() == MT_BIND_NAME) {
            auto json_obj = to_json();
            m_name = json_obj["name"].get<std::string>();
            std::cerr << "Bind Name: " << m_name << std::endl;
            // 只有master會發(fā)launch task message
        } else if (read_msg_.type() == MT_LAUNCH_TASK_MSG) {
            master_name = m_name;
            std::cerr << "MT_LAUNCH_TASK_MSG: " << std::endl;

            std::cerr << "Master name: " << master_name << std::endl;
            auto json_obj = to_json();
            m_chatInformation = json_obj["information"].get<std::string>();
            auto rinfo = buildRoomInfo();
            chat_message msg;
            msg.setMessage(MT_LAUNCH_TASK_MSG, rinfo);
            room_.deliver(msg);
            // 所有slave執(zhí)行完性能測試之后,都會發(fā)這個(gè)消息
        } else if (read_msg_.type() == MT_SEND_TASK_INFO_MSG) {
            std::cerr << "send task info" << std::endl;
            std::cerr << "Master name in task info: " << master_name
                      << std::endl;
            auto json_obj = to_json();
            m_chatInformation = json_obj["information"].get<std::string>();
            auto rinfo = buildRoomInfo();
            chat_message msg;
            msg.setMessage(MT_SEND_TASK_INFO_MSG, rinfo);
            room_.deliver_to(msg, master_name);
        } else {
            // 不可用消息,啥也不做
        }
    }

    // 構(gòu)建一個(gè)RoomInformation信息
    std::string buildRoomInfo() const {
        json msg_body;
        msg_body["name"] = m_name;
        msg_body["information"] = m_chatInformation;
        std::string msg_body_str = msg_body.dump();
        std::cout << "Room info: " << msg_body_str << std::endl;
        return std::move(msg_body_str);
    }

    void do_write() {
        auto self(shared_from_this());
        boost::asio::spawn(
            socket_.get_executor(),
            [this, self](boost::asio::yield_context yield) {
                boost::system::error_code ec;
                boost::asio::async_write(
                    socket_,
                    boost::asio::buffer(write_msgs_.front().data(),
                                        write_msgs_.front().length()),
                    yield[ec]);
                if (!ec) {
                    write_msgs_.pop_front();
                    // 如果還有得寫,就接著寫
                    if (!write_msgs_.empty()) {
                        do_write();
                    }
                } else {
                    room_.leave(shared_from_this());
                }
            });
    }

    tcp::socket socket_;
    // room的生命周期必須長于session的生命周期,
    // 否則會因?yàn)槌钟袩o效的引用而翻車
    chat_room& room_;
    chat_message read_msg_;
    chat_message_queue write_msgs_;
    std::string m_name;             // 客戶端姓名
    std::string m_chatInformation;  // 客戶端當(dāng)前的消息
};

void chat_room::join(chat_session_ptr participant) {
    // No history is kept, so nothing is replayed to a newcomer.
    participants_.insert(std::move(participant));
}

void chat_room::leave(chat_session_ptr participant) {
    // Leaving twice (e.g. from both the read and write paths) is harmless.
    const auto found = participants_.find(participant);
    if (found != participants_.end()) {
        participants_.erase(found);
    }
}

// 消息分發(fā)函數(shù)
void chat_room::deliver(const chat_message& msg) {
    // 給每個(gè)群聊參與者群發(fā)消息
    for (auto& participant : participants_) {
        participant->deliver(msg);
    }
}

void chat_room::deliver_to(const chat_message& msg,
                           const std::string& paticipant_name) {
    // 給每個(gè)群聊參與者群發(fā)消息
    for (auto& participant : participants_) {
        if (participant->get_client_name() == paticipant_name) {
            participant->deliver(msg);
        }
    }
}

class chat_server {
   public:
    chat_server(boost::asio::io_service& io_service,
                const tcp::endpoint& endpoint)
        : acceptor_(io_service, endpoint), socket_(io_service) {
        do_accept();
    }

    // 接收來自客戶端的連接的函數(shù)
    void do_accept() {
        boost::asio::spawn(
            socket_.get_executor(), [this](boost::asio::yield_context yield) {
                while (true) {
                    boost::system::error_code ec;
                    acceptor_.async_accept(socket_, yield[ec]);
                    if (!ec) {
                        auto session = std::make_shared<chat_session>(
                            std::move(socket_), room_);

                        session->start();
                    }
                }
            });
    }

   private:
    tcp::acceptor acceptor_;
    tcp::socket socket_;
    chat_room room_;
};

/// Server entry point: listen on every port given on the command line.
/// @return 0 on normal shutdown, 1 on bad usage, an invalid port, or an
///         exception.
int main(int argc, char* argv[]) {
    try {
        if (argc < 2) {
            std::cerr << "Usage: chat_server <port> [<port> ...]" << std::endl;
            return 1;
        }

        boost::asio::io_service io_service;

        // std::list never relocates elements, which chat_server requires.
        std::list<chat_server> servers;

        for (int i = 1; i < argc; ++i) {
            // Validate instead of atoi(), which maps garbage to 0 and lets
            // values above 65535 wrap silently.
            char* end = nullptr;
            const long port = std::strtol(argv[i], &end, 10);
            if (end == argv[i] || *end != '\0' || port < 1 || port > 65535) {
                std::cerr << "Invalid port: " << argv[i] << std::endl;
                return 1;
            }
            tcp::endpoint endpoint(tcp::v4(),
                                   static_cast<unsigned short>(port));
            servers.emplace_back(io_service, endpoint);
        }
        io_service.run();

    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}

parse_msg.h

#ifndef _FREDRIC_PARSE_MSG_H_
#define _FREDRIC_PARSE_MSG_H_

#include "json/json.hpp"

#include <sstream>
#include <cstdlib>
#include <string>
#include <iostream>
#include <cstring>

using json = nlohmann::json;


// Fixed-size frame header that precedes every message body on the wire.
// chat_message copies it to and from the byte stream with memcpy, so its
// in-memory layout (two host-order ints) *is* the wire format; both peers
// must agree on sizeof(int) and endianness.
// Default member initializers make a fresh Header all-zero instead of
// indeterminate; the struct stays trivially copyable, so memcpy is still valid.
struct Header {
    int bodySize{0};  // body size in bytes
    int type{0};      // message type, one of MessageType
};

// Message type codes carried in Header::type.
enum MessageType {
    MT_BIND_NAME = 1,          // "BindName <name>": register a client's name
    MT_LAUNCH_TASK_MSG = 2,    // "LaunchTask <task>": master starts a task
    MT_SEND_TASK_INFO_MSG = 3  // "SendTaskInfo <result>": slave reports back
};

// Turn a text command "<Command> <argument>" into a message type (written to
// *type when type is non-null) and a JSON body (written to outbuffer).
// Returns false if the command is malformed, unknown, or invalid.
bool parseMessage(const std::string& input, int* type, std::string& outbuffer);
#endif

parse_msg.cpp

#include "parse_msg.h"
#include "const.h"

#include <sstream>


// Encode {key: value} as a JSON object into outbuffer.
// Returns false, leaving outbuffer untouched, if encoding fails (recent
// nlohmann::json versions throw from dump() on invalid UTF-8 — presumably
// possible for user-typed input) or if the encoded body would exceed
// MAX_BODY_LEN. That limit is what chat_message can carry.
static bool encodeJsonBody(const char* key, const std::string& value,
                           std::string& outbuffer) {
    std::string encoded;
    try {
        json msg_body;
        msg_body[key] = value;
        encoded = msg_body.dump();
    } catch (const std::exception& ex) {
        std::cerr << "JSON encoding failed: " << ex.what() << std::endl;
        return false;
    }

    // Check the encoded size, not just the raw argument: the JSON wrapper
    // and escaping make the body longer than the argument itself.
    if (encoded.size() > static_cast<std::size_t>(MAX_BODY_LEN)) {
        std::cerr << "消息的長度大于" << MAX_BODY_LEN << "個(gè)字節(jié)!" << std::endl;
        return false;
    }

    outbuffer.swap(encoded);
    return true;
}

// Parse a text command of the form "<Command> <argument>".
//   input     - the command line, e.g. "BindName alice" or "LaunchTask task1"
//   type      - out: message type (MT_*); may be nullptr
//   outbuffer - out: JSON-encoded message body; written only on success
// Supported commands:
//   BindName <name>        -> MT_BIND_NAME,          {"name": name} (<= 32 bytes)
//   LaunchTask <task>      -> MT_LAUNCH_TASK_MSG,    {"information": task}
//   SendTaskInfo <result>  -> MT_SEND_TASK_INFO_MSG, {"information": result}
// Returns false if there is no space separator, the command is empty or
// unknown, the name is too long, or the body cannot be encoded within
// MAX_BODY_LEN.
bool parseMessage(const std::string& input, int* type, std::string& outbuffer) {
    const auto pos = input.find(' ');
    // No separator, or an empty command before it.
    if (pos == std::string::npos || pos == 0) {
        return false;
    }

    const std::string command = input.substr(0, pos);
    const std::string argument = input.substr(pos + 1);

    int message_type = 0;
    const char* field = nullptr;
    if (command == "BindName") {
        if (argument.size() > 32) {
            std::cerr << "姓名的長度大于32個(gè)字節(jié)!" <<  std::endl;
            return false;
        }
        message_type = MT_BIND_NAME;
        field = "name";
    } else if (command == "LaunchTask") {
        message_type = MT_LAUNCH_TASK_MSG;
        field = "information";
    } else if (command == "SendTaskInfo") {
        message_type = MT_SEND_TASK_INFO_MSG;
        field = "information";
    } else {
        // Unsupported command.
        return false;
    }

    if (!encodeJsonBody(field, argument, outbuffer)) {
        return false;
    }
    if (type) {
        *type = message_type;
    }
    return true;
}
©著作權歸作者所有,轉載或內容合作請聯系作者
【社區內容提示】社區部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容