diff --git a/examples/websocket.gui_script b/examples/websocket.gui_script index 1cd8f17..282446d 100644 --- a/examples/websocket.gui_script +++ b/examples/websocket.gui_script @@ -1,5 +1,4 @@ local URL="://echo.websocket.org" - local function click_button(node, action) return gui.is_enabled(node) and action.pressed and gui.pick_node(node, action.x, action.y) end @@ -73,7 +72,14 @@ local function websocket_callback(self, conn, data) update_gui(self) log("Connected: " .. tostring(conn)) elseif data.event == websocket.EVENT_ERROR then - log("Error: '" .. tostring(data.error) .. "'") + log("Error: '" .. tostring(data.message) .. "'") + if data.handshake_response then + log("Handshake response status: '" .. tostring(data.handshake_response.status) .. "'") + for key, value in pairs(data.handshake_response.headers) do + log("Handshake response header: '" .. key .. ": " .. value .. "'") + end + log("Handshake response body: '" .. tostring(data.handshake_response.response) .. "'") + end elseif data.event == websocket.EVENT_MESSAGE then log("Receiving: '" .. tostring(data.message) .. "'") end diff --git a/websocket/src/handshake.cpp b/websocket/src/handshake.cpp index 4395269..88461d0 100644 --- a/websocket/src/handshake.cpp +++ b/websocket/src/handshake.cpp @@ -1,5 +1,6 @@ #include "websocket.h" #include +#include #include // tolower namespace dmWebsocket @@ -170,17 +171,62 @@ Result ReceiveHeaders(WebsocketConnection* conn) } #endif -static int dmStriCmp(const char* s1, const char* s2) +static void HandleVersion(void* user_data, int major, int minor, int status, const char* status_str) { - for (;;) + HandshakeResponse* response = (HandshakeResponse*)user_data; + response->m_HttpMajor = major; + response->m_HttpMinor = minor; + response->m_ResponseStatusCode = status; +} + + +static void HandleHeader(void* user_data, const char* key, const char* value) +{ + HandshakeResponse* response = (HandshakeResponse*)user_data; + if (response->m_Headers.Remaining() == 0) { - if (!*s1 || !*s2 || tolower((unsigned char) *s1) != tolower((unsigned char) *s2)) - { - return (unsigned char) *s1 - (unsigned char) *s2; - } - s1++; - s2++; + response->m_Headers.OffsetCapacity(4); } + HttpHeader* new_header = new HttpHeader(key, value); + + response->m_Headers.Push(new_header); +} + + +static void HandleContent(void* user_data, int offset) +{ + HandshakeResponse* response = (HandshakeResponse*)user_data; + response->m_BodyOffset = offset; +} + + +bool ValidateSecretKey(WebsocketConnection* conn, const char* server_key) +{ + uint8_t client_key[32 + 40]; + uint32_t client_key_len = sizeof(client_key); + dmCrypt::Base64Encode(conn->m_Key, sizeof(conn->m_Key), client_key, &client_key_len); + client_key[client_key_len] = 0; + + DebugLog(2, "Secret key (base64): %s", client_key); + + memcpy(client_key + client_key_len, RFC_MAGIC, strlen(RFC_MAGIC)); + client_key_len += strlen(RFC_MAGIC); + client_key[client_key_len] = 0; + + DebugLog(2, "Secret key + RFC_MAGIC: %s", client_key); + + uint8_t client_key_sha1[20]; + dmCrypt::HashSha1(client_key, client_key_len, client_key_sha1); + + DebugPrint(2, "Hashed key (sha1):", client_key_sha1, sizeof(client_key_sha1)); + + client_key_len = sizeof(client_key); + dmCrypt::Base64Encode(client_key_sha1, sizeof(client_key_sha1), client_key, &client_key_len); + client_key[client_key_len] = 0; + DebugLog(2, "Client key (base64): %s", client_key); + DebugLog(2, "Server key (base64): %s", server_key); + + return strcmp(server_key, (const char*)client_key) == 0; } @@ -194,81 +240,31 @@ Result VerifyHeaders(WebsocketConnection* conn) { char* r = conn->m_Buffer; - // According to protocol, the response should start with "HTTP/1.1 " - const char* http_version_and_status_protocol = "HTTP/1.1 101"; - if (strstr(r, http_version_and_status_protocol) != r) { - return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Missing: '%s' in header", http_version_and_status_protocol); - } + // Find start of payload now because dmHttpClient::ParseHeader is destructive + const char* start_of_payload = strstr(conn->m_Buffer, "\r\n\r\n"); + start_of_payload += 4; - const char* endtag = strstr(conn->m_Buffer, "\r\n\r\n"); - - r = strstr(r, "\r\n") + 2; - - bool connection = false; - bool upgrade = false; - bool valid_key = false; - - // parse the headers in place - while (r < endtag) + HandshakeResponse* response = new HandshakeResponse(); + conn->m_HandshakeResponse = response; + dmHttpClient::ParseResult parse_result = dmHttpClient::ParseHeader(r, response, true, &HandleVersion, &HandleHeader, &HandleContent); + if (parse_result != dmHttpClient::ParseResult::PARSE_RESULT_OK) { - // Tokenize the each header line: "Key: Value\r\n" - const char* key = r; - r = strchr(r, ':'); - *r = 0; - ++r; - const char* value = r; - while(*value == ' ') - ++value; - r = strstr(r, "\r\n"); - *r = 0; - r += 2; - - // Page 18 in https://tools.ietf.org/html/rfc6455#section-11.3.3 - if (dmStriCmp(key, "Connection") == 0 && dmStriCmp(value, "Upgrade") == 0) - connection = true; - else if (dmStriCmp(key, "Upgrade") == 0 && dmStriCmp(value, "websocket") == 0) - upgrade = true; - else if (dmStriCmp(key, "Sec-WebSocket-Accept") == 0) - { - uint8_t client_key[32 + 40]; - uint32_t client_key_len = sizeof(client_key); - dmCrypt::Base64Encode(conn->m_Key, sizeof(conn->m_Key), client_key, &client_key_len); - client_key[client_key_len] = 0; - - DebugLog(2, "Secret key (base64): %s", client_key); - - memcpy(client_key + client_key_len, RFC_MAGIC, strlen(RFC_MAGIC)); - client_key_len += strlen(RFC_MAGIC); - client_key[client_key_len] = 0; - - DebugLog(2, "Secret key + RFC_MAGIC: %s", client_key); - - uint8_t client_key_sha1[20]; - dmCrypt::HashSha1(client_key, client_key_len, client_key_sha1); - - DebugPrint(2, "Hashed key (sha1):", client_key_sha1, sizeof(client_key_sha1)); - - client_key_len = sizeof(client_key); - dmCrypt::Base64Encode(client_key_sha1, sizeof(client_key_sha1), client_key, &client_key_len); - client_key[client_key_len] = 0; - - DebugLog(2, "Client key (base64): %s", client_key); - - DebugLog(2, "Server key (base64): %s", value); - - if (strcmp(value, (const char*)client_key) == 0) - valid_key = true; - } + return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Failed to parse handshake response. 'dmHttpClient::ParseResult=%i'", parse_result); } - // The response might contain both the headers, but also (if successful) the first batch of data - endtag += 4; - uint32_t size = conn->m_BufferSize - (endtag - conn->m_Buffer); - conn->m_BufferSize = size; - memmove(conn->m_Buffer, endtag, size); - conn->m_Buffer[size] = 0; - conn->m_HasHandshakeData = conn->m_BufferSize != 0 ? 1 : 0; + if (response->m_ResponseStatusCode != 101) { + return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Wrong response status: %i", response->m_ResponseStatusCode); + } + HttpHeader *connection_header, *upgrade_header, *websocket_secret_header; + connection_header = response->GetHeader("Connection"); + upgrade_header = response->GetHeader("Upgrade"); + websocket_secret_header = response->GetHeader("Sec-WebSocket-Accept"); + bool connection = connection_header && dmStriCmp(connection_header->m_Value, "Upgrade") == 0; + bool upgrade = upgrade_header && dmStriCmp(upgrade_header->m_Value, "websocket") == 0; + bool valid_key = websocket_secret_header && ValidateSecretKey(conn, websocket_secret_header->m_Value); + + // Send error to lua? if (!connection) dmLogError("Failed to find the Connection keyword in the response headers"); if (!upgrade) @@ -277,11 +273,21 @@ Result VerifyHeaders(WebsocketConnection* conn) dmLogError("Failed to find valid key in the response headers"); bool ok = connection && upgrade && valid_key; - if (!ok) { - dmLogError("Response:\n\"%s\"\n", conn->m_Buffer); + if(!ok) + { + return RESULT_HANDSHAKE_FAILED; } - return ok ? RESULT_OK : RESULT_HANDSHAKE_FAILED; + delete conn->m_HandshakeResponse; + conn->m_HandshakeResponse = 0; + + // The response might contain both the headers, but also (if successful) the first batch of data + uint32_t size = conn->m_BufferSize - (start_of_payload - conn->m_Buffer); + conn->m_BufferSize = size; + memmove(conn->m_Buffer, start_of_payload, size); + conn->m_Buffer[size] = 0; + conn->m_HasHandshakeData = conn->m_BufferSize != 0 ? 1 : 0; + return RESULT_OK; } #endif diff --git a/websocket/src/websocket.cpp b/websocket/src/websocket.cpp index 00b84b0..df00f75 100644 --- a/websocket/src/websocket.cpp +++ b/websocket/src/websocket.cpp @@ -67,6 +67,19 @@ const char* StateToString(State err) #undef STRING_CASE +int dmStriCmp(const char* s1, const char* s2) +{ + for (;;) + { + if (!*s1 || !*s2 || tolower((unsigned char) *s1) != tolower((unsigned char) *s2)) + { + return (unsigned char) *s1 - (unsigned char) *s2; + } + s1++; + s2++; + } +} + void DebugLog(int level, const char* fmt, ...) { if (level > g_DebugWebSocket) @@ -195,6 +208,7 @@ static WebsocketConnection* CreateConnection(const char* url) conn->m_SSLSocket = 0; conn->m_Status = RESULT_OK; conn->m_HasHandshakeData = 0; + conn->m_HandshakeResponse = 0; #if defined(HAVE_WSLAY) conn->m_Ctx = 0; @@ -227,6 +241,9 @@ static void DestroyConnection(WebsocketConnection* conn) dmConnectionPool::Return(g_Websocket.m_Pool, conn->m_Connection); #endif + if (conn->m_HandshakeResponse) + delete conn->m_HandshakeResponse; + free((void*)conn->m_Buffer); delete conn; @@ -383,12 +400,74 @@ static void HandleCallback(WebsocketConnection* conn, int event, int msg_offset, lua_pushlstring(L, conn->m_Buffer + msg_offset, msg_length); lua_setfield(L, -2, "message"); + if(conn->m_HandshakeResponse) + { + HandshakeResponse* response = conn->m_HandshakeResponse; + + lua_newtable(L); + + lua_pushnumber(L, response->m_ResponseStatusCode); + lua_setfield(L, -2, "status"); + + lua_pushstring(L, &conn->m_Buffer[response->m_BodyOffset]); + lua_setfield(L, -2, "response"); + + lua_newtable(L); + for (uint32_t i = 0; i < response->m_Headers.Size(); ++i) + { + lua_pushstring(L, response->m_Headers[i]->m_Value); + lua_setfield(L, -2, response->m_Headers[i]->m_Key); + } + lua_setfield(L, -2, "headers"); + + lua_setfield(L, -2, "handshake_response"); + + delete conn->m_HandshakeResponse; + conn->m_HandshakeResponse = 0; + } + dmScript::PCall(L, 3, 0); dmScript::TeardownCallback(conn->m_Callback); } +HttpHeader::HttpHeader(const char* key, const char* value) +{ + m_Key = strdup(key); + m_Value = strdup(value); +} + +HttpHeader::~HttpHeader() +{ + free((void*)m_Key); + free((void*)m_Value); + m_Key = 0; + m_Value = 0; +} + +HttpHeader* HandshakeResponse::GetHeader(const char* header_key) +{ + for(uint32_t i = 0; i < m_Headers.Size(); ++i) + { + if (dmStriCmp(m_Headers[i]->m_Key, header_key) == 0) + { + return m_Headers[i]; + } + } + + return 0; +} + +HandshakeResponse::~HandshakeResponse() +{ + for(uint32_t i = 0; i < m_Headers.Size(); ++i) + { + delete m_Headers[i]; + } +} + + // *************************************************************************************************** // Life cycle functions diff --git a/websocket/src/websocket.h b/websocket/src/websocket.h index b54c30b..8adffb2 100644 --- a/websocket/src/websocket.h +++ b/websocket/src/websocket.h @@ -78,6 +78,27 @@ namespace dmWebsocket uint32_t m_Type:2; }; + struct HttpHeader + { + const char* m_Key; + const char* m_Value; + HttpHeader(const char* key, const char* value); + ~HttpHeader(); + }; + + struct HandshakeResponse + { + int m_HttpMajor; + int m_HttpMinor; + int m_ResponseStatusCode; + int m_BodyOffset; + dmArray m_Headers; + + ~HandshakeResponse(); + HttpHeader* GetHeader(const char* header); + }; + + struct WebsocketConnection { dmScript::LuaCallbackInfo* m_Callback; @@ -101,6 +122,7 @@ namespace dmWebsocket uint8_t m_SSL:1; uint8_t m_HasHandshakeData:1; uint8_t :7; + HandshakeResponse* m_HandshakeResponse; }; // Set error message @@ -148,16 +170,6 @@ namespace dmWebsocket void DebugLog(int level, const char* fmt, ...); #endif + int dmStriCmp(const char* s1, const char* s2); void DebugPrint(int level, const char* msg, const void* _bytes, uint32_t num_bytes); } - - - - - - - - - - -