Unverified Commit f0baecee authored by kladko's avatar kladko

bug/SKALE-3751-enable-zeromq

parent ebc75da4
...@@ -104,12 +104,10 @@ void ServerWorker::work() { ...@@ -104,12 +104,10 @@ void ServerWorker::work() {
CHECK_STATE(msg.size() > 5 || msgData.at(0) == '{' || msgData[msg.size()] == '}'); CHECK_STATE(msg.size() > 5 || msgData.at(0) == '{' || msgData[msg.size()] == '}');
memcpy(msgData.data(), msg.data(), msg.size()); memcpy(msgData.data(), msg.data(), msg.size());
auto parsedMsg = ZMQMessage::parse( auto parsedMsg = ZMQMessage::parse(
(const char *) msgData.data(), msg.size(), true); (const char *) msgData.data(), msg.size(), true, checkSignature);
CHECK_STATE(parsedMsg); CHECK_STATE(parsedMsg);
result = parsedMsg->process(); result = parsedMsg->process();
......
...@@ -52,10 +52,6 @@ shared_ptr <ZMQMessage> ZMQClient::doRequestReply(Json::Value &_req) { ...@@ -52,10 +52,6 @@ shared_ptr <ZMQMessage> ZMQClient::doRequestReply(Json::Value &_req) {
string msgToSign = fastWriter.write(_req); string msgToSign = fastWriter.write(_req);
std::regex r("\\s+");
msgToSign = std::regex_replace(msgToSign, r, "");
_req["msgSig"] = signString(pkey, msgToSign); _req["msgSig"] = signString(pkey, msgToSign);
} }
...@@ -76,7 +72,7 @@ shared_ptr <ZMQMessage> ZMQClient::doRequestReply(Json::Value &_req) { ...@@ -76,7 +72,7 @@ shared_ptr <ZMQMessage> ZMQClient::doRequestReply(Json::Value &_req) {
CHECK_STATE(resultStr.back() == '}') CHECK_STATE(resultStr.back() == '}')
return ZMQMessage::parse(resultStr.c_str(), resultStr.size(), false); return ZMQMessage::parse(resultStr.c_str(), resultStr.size(), false, false);
} catch (std::exception &e) { } catch (std::exception &e) {
spdlog::error(string("Error in doRequestReply:") + e.what()); spdlog::error(string("Error in doRequestReply:") + e.what());
throw; throw;
...@@ -141,9 +137,13 @@ string ZMQClient::readFileIntoString(const string &_fileName) { ...@@ -141,9 +137,13 @@ string ZMQClient::readFileIntoString(const string &_fileName) {
} }
void ZMQClient::verifySig(EVP_PKEY* _pubkey, const string& _str, const string _sig) { void ZMQClient::verifySig(EVP_PKEY* _pubkey, const string& _str, const string& _sig) {
CHECK_STATE(_pubkey); CHECK_STATE(_pubkey);
CHECK_STATE(!_str.empty());
static std::regex r("\\s+");
auto msgToSign = std::regex_replace(_str, r, "");
vector<uint8_t> binSig(256,0); vector<uint8_t> binSig(256,0);
...@@ -163,7 +163,7 @@ void ZMQClient::verifySig(EVP_PKEY* _pubkey, const string& _str, const string _s ...@@ -163,7 +163,7 @@ void ZMQClient::verifySig(EVP_PKEY* _pubkey, const string& _str, const string _s
CHECK_STATE((EVP_DigestVerifyInit(mdctx, NULL, EVP_sha256(), NULL, _pubkey) == 1)); CHECK_STATE((EVP_DigestVerifyInit(mdctx, NULL, EVP_sha256(), NULL, _pubkey) == 1));
CHECK_STATE(EVP_DigestVerifyUpdate(mdctx, _str.c_str(), _str.size()) == 1); CHECK_STATE(EVP_DigestVerifyUpdate(mdctx, msgToSign.c_str(), msgToSign.size()) == 1);
/* First call EVP_DigestSignFinal with a NULL sig parameter to obtain the length of the /* First call EVP_DigestSignFinal with a NULL sig parameter to obtain the length of the
* signature. Length is returned in slen */ * signature. Length is returned in slen */
...@@ -181,6 +181,12 @@ void ZMQClient::verifySig(EVP_PKEY* _pubkey, const string& _str, const string _s ...@@ -181,6 +181,12 @@ void ZMQClient::verifySig(EVP_PKEY* _pubkey, const string& _str, const string _s
string ZMQClient::signString(EVP_PKEY* _pkey, const string& _str) { string ZMQClient::signString(EVP_PKEY* _pkey, const string& _str) {
CHECK_STATE(_pkey); CHECK_STATE(_pkey);
CHECK_STATE(!_str.empty());
static std::regex r("\\s+");
auto msgToSign = std::regex_replace(_str, r, "");
EVP_MD_CTX *mdctx = NULL; EVP_MD_CTX *mdctx = NULL;
int ret = 0; int ret = 0;
...@@ -194,7 +200,7 @@ string ZMQClient::signString(EVP_PKEY* _pkey, const string& _str) { ...@@ -194,7 +200,7 @@ string ZMQClient::signString(EVP_PKEY* _pkey, const string& _str) {
CHECK_STATE((EVP_DigestSignInit(mdctx, NULL, EVP_sha256(), NULL, _pkey) == 1)); CHECK_STATE((EVP_DigestSignInit(mdctx, NULL, EVP_sha256(), NULL, _pkey) == 1));
CHECK_STATE(EVP_DigestSignUpdate(mdctx, _str.c_str(), _str.size()) == 1); CHECK_STATE(EVP_DigestSignUpdate(mdctx, msgToSign.c_str(), msgToSign.size()) == 1);
/* First call EVP_DigestSignFinal with a NULL sig parameter to obtain the length of the /* First call EVP_DigestSignFinal with a NULL sig parameter to obtain the length of the
* signature. Length is returned in slen */ * signature. Length is returned in slen */
......
...@@ -92,7 +92,7 @@ public: ...@@ -92,7 +92,7 @@ public:
static string signString(EVP_PKEY* _pkey, const string& _str); static string signString(EVP_PKEY* _pkey, const string& _str);
static void verifySig(EVP_PKEY* _pubkey, const string& _str, const string _sig); static void verifySig(EVP_PKEY* _pubkey, const string& _str, const string& _sig);
string blsSignMessageHash(const std::string &keyShareName, const std::string &messageHash, int t, int n); string blsSignMessageHash(const std::string &keyShareName, const std::string &messageHash, int t, int n);
......
...@@ -51,10 +51,9 @@ string ZMQMessage::getStringRapid(const char *_name) { ...@@ -51,10 +51,9 @@ string ZMQMessage::getStringRapid(const char *_name) {
}; };
shared_ptr <ZMQMessage> ZMQMessage::parse(const char *_msg,
size_t _size, bool _isRequest,
shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg, bool _verifySig) {
size_t _size, bool _isRequest) {
CHECK_STATE(_msg); CHECK_STATE(_msg);
CHECK_STATE(_size > 5); CHECK_STATE(_size > 5);
...@@ -76,7 +75,9 @@ shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg, ...@@ -76,7 +75,9 @@ shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg,
CHECK_STATE((*d)["type"].IsString()); CHECK_STATE((*d)["type"].IsString());
string type = (*d)["type"].GetString(); string type = (*d)["type"].GetString();
if (d->HasMember("cert")) { if (_verifySig) {
CHECK_STATE(d->HasMember("cert"));
CHECK_STATE(d->HasMember("msgSig"));
CHECK_STATE((*d)["cert"].IsString()); CHECK_STATE((*d)["cert"].IsString());
auto cert = make_shared<string>((*d)["cert"].GetString()); auto cert = make_shared<string>((*d)["cert"].GetString());
...@@ -90,10 +91,15 @@ shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg, ...@@ -90,10 +91,15 @@ shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg,
outFile.close(); outFile.close();
static recursive_mutex m;
EVP_PKEY *publicKey = nullptr;
{
lock_guard <recursive_mutex> lock(m);
if (!verifiedCerts.exists(*cert)) { if (!verifiedCerts.exists(*cert)) {
CHECK_STATE(SGXWalletServer::verifyCert(filepath)); CHECK_STATE(SGXWalletServer::verifyCert(filepath));
auto handles = ZMQClient::readPublicKeyFromCertStr(*cert); auto handles = ZMQClient::readPublicKeyFromCertStr(*cert);
CHECK_STATE(handles.first); CHECK_STATE(handles.first);
CHECK_STATE(handles.second); CHECK_STATE(handles.second);
...@@ -102,13 +108,29 @@ shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg, ...@@ -102,13 +108,29 @@ shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg,
remove(cert->c_str()); remove(cert->c_str());
} }
} publicKey = verifiedCerts.get(*cert).first;
CHECK_STATE(publicKey);
if (d->HasMember("msgSig")) {
CHECK_STATE((*d)["msgSig"].IsString()); CHECK_STATE((*d)["msgSig"].IsString());
auto msgSig = make_shared<string>((*d)["msgSig"].GetString()); auto msgSig = make_shared<string>((*d)["msgSig"].GetString());
cerr << "Got msgSig:" << msgSig << endl; cerr << "Got msgSig:" << msgSig << endl;
d->RemoveMember("msgSig");
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> w(buffer);
d->Accept(w);
auto msgToVerify = buffer.GetString();
ZMQClient::verifySig(publicKey,msgToVerify, *msgSig );
} }
}
shared_ptr <ZMQMessage> result; shared_ptr <ZMQMessage> result;
...@@ -119,7 +141,7 @@ shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg, ...@@ -119,7 +141,7 @@ shared_ptr <ZMQMessage> ZMQMessage::parse(const char* _msg,
} }
} }
shared_ptr <ZMQMessage> ZMQMessage::buildRequest(string& _type, shared_ptr<rapidjson::Document> _d) { shared_ptr <ZMQMessage> ZMQMessage::buildRequest(string &_type, shared_ptr <rapidjson::Document> _d) {
if (_type == ZMQMessage::BLS_SIGN_REQ) { if (_type == ZMQMessage::BLS_SIGN_REQ) {
return make_shared<BLSSignReqMessage>(_d); return make_shared<BLSSignReqMessage>(_d);
} else if (_type == ZMQMessage::ECDSA_SIGN_REQ) { } else if (_type == ZMQMessage::ECDSA_SIGN_REQ) {
...@@ -131,7 +153,7 @@ shared_ptr <ZMQMessage> ZMQMessage::buildRequest(string& _type, shared_ptr<rapid ...@@ -131,7 +153,7 @@ shared_ptr <ZMQMessage> ZMQMessage::buildRequest(string& _type, shared_ptr<rapid
} }
} }
shared_ptr <ZMQMessage> ZMQMessage::buildResponse(string& _type, shared_ptr<rapidjson::Document> _d) { shared_ptr <ZMQMessage> ZMQMessage::buildResponse(string &_type, shared_ptr <rapidjson::Document> _d) {
if (_type == ZMQMessage::BLS_SIGN_RSP) { if (_type == ZMQMessage::BLS_SIGN_RSP) {
return return
make_shared<BLSSignRspMessage>(_d); make_shared<BLSSignRspMessage>(_d);
...@@ -145,4 +167,5 @@ shared_ptr <ZMQMessage> ZMQMessage::buildResponse(string& _type, shared_ptr<rapi ...@@ -145,4 +167,5 @@ shared_ptr <ZMQMessage> ZMQMessage::buildResponse(string& _type, shared_ptr<rapi
} }
} }
cache::lru_cache<string, pair<EVP_PKEY*, X509*>> ZMQMessage::verifiedCerts(256); cache::lru_cache<string, pair < EVP_PKEY * , X509 *>>
\ No newline at end of file ZMQMessage::verifiedCerts(256);
\ No newline at end of file
...@@ -38,6 +38,9 @@ ...@@ -38,6 +38,9 @@
#include "abstractstubserver.h" #include "abstractstubserver.h"
#include "document.h" #include "document.h"
#include "stringbuffer.h"
#include "writer.h"
#include "SGXException.h" #include "SGXException.h"
using namespace std; using namespace std;
...@@ -54,7 +57,6 @@ protected: ...@@ -54,7 +57,6 @@ protected:
public: public:
void verifySig();
static constexpr const char *BLS_SIGN_REQ = "BLSSignReq"; static constexpr const char *BLS_SIGN_REQ = "BLSSignReq";
static constexpr const char *BLS_SIGN_RSP = "BLSSignRsp"; static constexpr const char *BLS_SIGN_RSP = "BLSSignRsp";
...@@ -72,8 +74,8 @@ public: ...@@ -72,8 +74,8 @@ public:
return getUint64Rapid("status"); return getUint64Rapid("status");
} }
static shared_ptr<ZMQMessage> parse(vector<uint8_t> &_msg, bool _isRequest); static shared_ptr <ZMQMessage> parse(const char* _msg, size_t _size, bool _isRequest,
static shared_ptr <ZMQMessage> parse(const char* _msg, size_t _size, bool _isRequest); bool _verifySig);
static shared_ptr<ZMQMessage> buildRequest(string& type, shared_ptr<rapidjson::Document> _d); static shared_ptr<ZMQMessage> buildRequest(string& type, shared_ptr<rapidjson::Document> _d);
static shared_ptr<ZMQMessage> buildResponse(string& type, shared_ptr<rapidjson::Document> _d); static shared_ptr<ZMQMessage> buildResponse(string& type, shared_ptr<rapidjson::Document> _d);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment