Handshake response object if handshake failed (#29)

* added handshake response object, so if error occurs on websocket handhsake stage, server's response can be reachable for further decisions

* fixed typo; removed c++11 style constructors
This commit is contained in:
Alexander Palagin 2021-02-04 09:30:40 +02:00 committed by GitHub
parent 39abd7cdea
commit ac230a6278
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 197 additions and 94 deletions

View File

@ -1,5 +1,4 @@
local URL="://echo.websocket.org" local URL="://echo.websocket.org"
local function click_button(node, action) local function click_button(node, action)
return gui.is_enabled(node) and action.pressed and gui.pick_node(node, action.x, action.y) return gui.is_enabled(node) and action.pressed and gui.pick_node(node, action.x, action.y)
end end
@ -73,7 +72,14 @@ local function websocket_callback(self, conn, data)
update_gui(self) update_gui(self)
log("Connected: " .. tostring(conn)) log("Connected: " .. tostring(conn))
elseif data.event == websocket.EVENT_ERROR then elseif data.event == websocket.EVENT_ERROR then
log("Error: '" .. tostring(data.error) .. "'") log("Error: '" .. tostring(data.message) .. "'")
if data.handshake_response then
log("Handshake response status: '" .. tostring(data.handshake_response.status) .. "'")
for key, value in pairs(data.handshake_response.headers) do
log("Handshake response header: '" .. key .. ": " .. value .. "'")
end
log("Handshake response body: '" .. tostring(data.handshake_response.response) .. "'")
end
elseif data.event == websocket.EVENT_MESSAGE then elseif data.event == websocket.EVENT_MESSAGE then
log("Receiving: '" .. tostring(data.message) .. "'") log("Receiving: '" .. tostring(data.message) .. "'")
end end

View File

@ -1,5 +1,6 @@
#include "websocket.h" #include "websocket.h"
#include <dmsdk/dlib/socket.h> #include <dmsdk/dlib/socket.h>
#include <dmsdk/dlib/http_client.h>
#include <ctype.h> // tolower #include <ctype.h> // tolower
namespace dmWebsocket namespace dmWebsocket
@ -170,17 +171,62 @@ Result ReceiveHeaders(WebsocketConnection* conn)
} }
#endif #endif
static int dmStriCmp(const char* s1, const char* s2) static void HandleVersion(void* user_data, int major, int minor, int status, const char* status_str)
{ {
for (;;) HandshakeResponse* response = (HandshakeResponse*)user_data;
response->m_HttpMajor = major;
response->m_HttpMinor = minor;
response->m_ResponseStatusCode = status;
}
static void HandleHeader(void* user_data, const char* key, const char* value)
{
HandshakeResponse* response = (HandshakeResponse*)user_data;
if (response->m_Headers.Remaining() == 0)
{ {
if (!*s1 || !*s2 || tolower((unsigned char) *s1) != tolower((unsigned char) *s2)) response->m_Headers.OffsetCapacity(4);
{
return (unsigned char) *s1 - (unsigned char) *s2;
}
s1++;
s2++;
} }
HttpHeader* new_header = new HttpHeader(key, value);
response->m_Headers.Push(new_header);
}
static void HandleContent(void* user_data, int offset)
{
HandshakeResponse* response = (HandshakeResponse*)user_data;
response->m_BodyOffset = offset;
}
bool ValidateSecretKey(WebsocketConnection* conn, const char* server_key)
{
uint8_t client_key[32 + 40];
uint32_t client_key_len = sizeof(client_key);
dmCrypt::Base64Encode(conn->m_Key, sizeof(conn->m_Key), client_key, &client_key_len);
client_key[client_key_len] = 0;
DebugLog(2, "Secret key (base64): %s", client_key);
memcpy(client_key + client_key_len, RFC_MAGIC, strlen(RFC_MAGIC));
client_key_len += strlen(RFC_MAGIC);
client_key[client_key_len] = 0;
DebugLog(2, "Secret key + RFC_MAGIC: %s", client_key);
uint8_t client_key_sha1[20];
dmCrypt::HashSha1(client_key, client_key_len, client_key_sha1);
DebugPrint(2, "Hashed key (sha1):", client_key_sha1, sizeof(client_key_sha1));
client_key_len = sizeof(client_key);
dmCrypt::Base64Encode(client_key_sha1, sizeof(client_key_sha1), client_key, &client_key_len);
client_key[client_key_len] = 0;
DebugLog(2, "Client key (base64): %s", client_key);
DebugLog(2, "Server key (base64): %s", server_key);
return strcmp(server_key, (const char*)client_key) == 0;
} }
@ -194,81 +240,31 @@ Result VerifyHeaders(WebsocketConnection* conn)
{ {
char* r = conn->m_Buffer; char* r = conn->m_Buffer;
// According to protocol, the response should start with "HTTP/1.1 <statuscode> <message>" // Find start of payload now because dmHttpClient::ParseHeader is destructive
const char* http_version_and_status_protocol = "HTTP/1.1 101"; const char* start_of_payload = strstr(conn->m_Buffer, "\r\n\r\n");
if (strstr(r, http_version_and_status_protocol) != r) { start_of_payload += 4;
return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Missing: '%s' in header", http_version_and_status_protocol);
}
const char* endtag = strstr(conn->m_Buffer, "\r\n\r\n"); HandshakeResponse* response = new HandshakeResponse();
conn->m_HandshakeResponse = response;
r = strstr(r, "\r\n") + 2; dmHttpClient::ParseResult parse_result = dmHttpClient::ParseHeader(r, response, true, &HandleVersion, &HandleHeader, &HandleContent);
if (parse_result != dmHttpClient::ParseResult::PARSE_RESULT_OK)
bool connection = false;
bool upgrade = false;
bool valid_key = false;
// parse the headers in place
while (r < endtag)
{ {
// Tokenize the each header line: "Key: Value\r\n" return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Failed to parse handshake response. 'dmHttpClient::ParseResult=%i'", parse_result);
const char* key = r;
r = strchr(r, ':');
*r = 0;
++r;
const char* value = r;
while(*value == ' ')
++value;
r = strstr(r, "\r\n");
*r = 0;
r += 2;
// Page 18 in https://tools.ietf.org/html/rfc6455#section-11.3.3
if (dmStriCmp(key, "Connection") == 0 && dmStriCmp(value, "Upgrade") == 0)
connection = true;
else if (dmStriCmp(key, "Upgrade") == 0 && dmStriCmp(value, "websocket") == 0)
upgrade = true;
else if (dmStriCmp(key, "Sec-WebSocket-Accept") == 0)
{
uint8_t client_key[32 + 40];
uint32_t client_key_len = sizeof(client_key);
dmCrypt::Base64Encode(conn->m_Key, sizeof(conn->m_Key), client_key, &client_key_len);
client_key[client_key_len] = 0;
DebugLog(2, "Secret key (base64): %s", client_key);
memcpy(client_key + client_key_len, RFC_MAGIC, strlen(RFC_MAGIC));
client_key_len += strlen(RFC_MAGIC);
client_key[client_key_len] = 0;
DebugLog(2, "Secret key + RFC_MAGIC: %s", client_key);
uint8_t client_key_sha1[20];
dmCrypt::HashSha1(client_key, client_key_len, client_key_sha1);
DebugPrint(2, "Hashed key (sha1):", client_key_sha1, sizeof(client_key_sha1));
client_key_len = sizeof(client_key);
dmCrypt::Base64Encode(client_key_sha1, sizeof(client_key_sha1), client_key, &client_key_len);
client_key[client_key_len] = 0;
DebugLog(2, "Client key (base64): %s", client_key);
DebugLog(2, "Server key (base64): %s", value);
if (strcmp(value, (const char*)client_key) == 0)
valid_key = true;
}
} }
// The response might contain both the headers, but also (if successful) the first batch of data if (response->m_ResponseStatusCode != 101) {
endtag += 4; return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Wrong response status: %i", response->m_ResponseStatusCode);
uint32_t size = conn->m_BufferSize - (endtag - conn->m_Buffer); }
conn->m_BufferSize = size;
memmove(conn->m_Buffer, endtag, size);
conn->m_Buffer[size] = 0;
conn->m_HasHandshakeData = conn->m_BufferSize != 0 ? 1 : 0;
HttpHeader *connection_header, *upgrade_header, *websocket_secret_header;
connection_header = response->GetHeader("Connection");
upgrade_header = response->GetHeader("Upgrade");
websocket_secret_header = response->GetHeader("Sec-WebSocket-Accept");
bool connection = connection_header && dmStriCmp(connection_header->m_Value, "Upgrade") == 0;
bool upgrade = upgrade_header && dmStriCmp(upgrade_header->m_Value, "websocket") == 0;
bool valid_key = websocket_secret_header && ValidateSecretKey(conn, websocket_secret_header->m_Value);
// Send error to lua?
if (!connection) if (!connection)
dmLogError("Failed to find the Connection keyword in the response headers"); dmLogError("Failed to find the Connection keyword in the response headers");
if (!upgrade) if (!upgrade)
@ -277,11 +273,21 @@ Result VerifyHeaders(WebsocketConnection* conn)
dmLogError("Failed to find valid key in the response headers"); dmLogError("Failed to find valid key in the response headers");
bool ok = connection && upgrade && valid_key; bool ok = connection && upgrade && valid_key;
if (!ok) { if(!ok)
dmLogError("Response:\n\"%s\"\n", conn->m_Buffer); {
return RESULT_HANDSHAKE_FAILED;
} }
return ok ? RESULT_OK : RESULT_HANDSHAKE_FAILED; delete conn->m_HandshakeResponse;
conn->m_HandshakeResponse = 0;
// The response might contain both the headers, but also (if successful) the first batch of data
uint32_t size = conn->m_BufferSize - (start_of_payload - conn->m_Buffer);
conn->m_BufferSize = size;
memmove(conn->m_Buffer, start_of_payload, size);
conn->m_Buffer[size] = 0;
conn->m_HasHandshakeData = conn->m_BufferSize != 0 ? 1 : 0;
return RESULT_OK;
} }
#endif #endif

View File

@ -67,6 +67,19 @@ const char* StateToString(State err)
#undef STRING_CASE #undef STRING_CASE
int dmStriCmp(const char* s1, const char* s2)
{
for (;;)
{
if (!*s1 || !*s2 || tolower((unsigned char) *s1) != tolower((unsigned char) *s2))
{
return (unsigned char) *s1 - (unsigned char) *s2;
}
s1++;
s2++;
}
}
void DebugLog(int level, const char* fmt, ...) void DebugLog(int level, const char* fmt, ...)
{ {
if (level > g_DebugWebSocket) if (level > g_DebugWebSocket)
@ -195,6 +208,7 @@ static WebsocketConnection* CreateConnection(const char* url)
conn->m_SSLSocket = 0; conn->m_SSLSocket = 0;
conn->m_Status = RESULT_OK; conn->m_Status = RESULT_OK;
conn->m_HasHandshakeData = 0; conn->m_HasHandshakeData = 0;
conn->m_HandshakeResponse = 0;
#if defined(HAVE_WSLAY) #if defined(HAVE_WSLAY)
conn->m_Ctx = 0; conn->m_Ctx = 0;
@ -227,6 +241,9 @@ static void DestroyConnection(WebsocketConnection* conn)
dmConnectionPool::Return(g_Websocket.m_Pool, conn->m_Connection); dmConnectionPool::Return(g_Websocket.m_Pool, conn->m_Connection);
#endif #endif
if (conn->m_HandshakeResponse)
delete conn->m_HandshakeResponse;
free((void*)conn->m_Buffer); free((void*)conn->m_Buffer);
delete conn; delete conn;
@ -383,12 +400,74 @@ static void HandleCallback(WebsocketConnection* conn, int event, int msg_offset,
lua_pushlstring(L, conn->m_Buffer + msg_offset, msg_length); lua_pushlstring(L, conn->m_Buffer + msg_offset, msg_length);
lua_setfield(L, -2, "message"); lua_setfield(L, -2, "message");
if(conn->m_HandshakeResponse)
{
HandshakeResponse* response = conn->m_HandshakeResponse;
lua_newtable(L);
lua_pushnumber(L, response->m_ResponseStatusCode);
lua_setfield(L, -2, "status");
lua_pushstring(L, &conn->m_Buffer[response->m_BodyOffset]);
lua_setfield(L, -2, "response");
lua_newtable(L);
for (uint32_t i = 0; i < response->m_Headers.Size(); ++i)
{
lua_pushstring(L, response->m_Headers[i]->m_Value);
lua_setfield(L, -2, response->m_Headers[i]->m_Key);
}
lua_setfield(L, -2, "headers");
lua_setfield(L, -2, "handshake_response");
delete conn->m_HandshakeResponse;
conn->m_HandshakeResponse = 0;
}
dmScript::PCall(L, 3, 0); dmScript::PCall(L, 3, 0);
dmScript::TeardownCallback(conn->m_Callback); dmScript::TeardownCallback(conn->m_Callback);
} }
HttpHeader::HttpHeader(const char* key, const char* value)
{
m_Key = strdup(key);
m_Value = strdup(value);
}
HttpHeader::~HttpHeader()
{
free((void*)m_Key);
free((void*)m_Value);
m_Key = 0;
m_Value = 0;
}
HttpHeader* HandshakeResponse::GetHeader(const char* header_key)
{
for(uint32_t i = 0; i < m_Headers.Size(); ++i)
{
if (dmStriCmp(m_Headers[i]->m_Key, header_key) == 0)
{
return m_Headers[i];
}
}
return 0;
}
HandshakeResponse::~HandshakeResponse()
{
for(uint32_t i = 0; i < m_Headers.Size(); ++i)
{
delete m_Headers[i];
}
}
// *************************************************************************************************** // ***************************************************************************************************
// Life cycle functions // Life cycle functions

View File

@ -78,6 +78,27 @@ namespace dmWebsocket
uint32_t m_Type:2; uint32_t m_Type:2;
}; };
struct HttpHeader
{
const char* m_Key;
const char* m_Value;
HttpHeader(const char* key, const char* value);
~HttpHeader();
};
struct HandshakeResponse
{
int m_HttpMajor;
int m_HttpMinor;
int m_ResponseStatusCode;
int m_BodyOffset;
dmArray<HttpHeader*> m_Headers;
~HandshakeResponse();
HttpHeader* GetHeader(const char* header);
};
struct WebsocketConnection struct WebsocketConnection
{ {
dmScript::LuaCallbackInfo* m_Callback; dmScript::LuaCallbackInfo* m_Callback;
@ -101,6 +122,7 @@ namespace dmWebsocket
uint8_t m_SSL:1; uint8_t m_SSL:1;
uint8_t m_HasHandshakeData:1; uint8_t m_HasHandshakeData:1;
uint8_t :7; uint8_t :7;
HandshakeResponse* m_HandshakeResponse;
}; };
// Set error message // Set error message
@ -148,16 +170,6 @@ namespace dmWebsocket
void DebugLog(int level, const char* fmt, ...); void DebugLog(int level, const char* fmt, ...);
#endif #endif
int dmStriCmp(const char* s1, const char* s2);
void DebugPrint(int level, const char* msg, const void* _bytes, uint32_t num_bytes); void DebugPrint(int level, const char* msg, const void* _bytes, uint32_t num_bytes);
} }