diff --git a/websocket/src/handshake.cpp b/websocket/src/handshake.cpp index 0b81930..59749fe 100644 --- a/websocket/src/handshake.cpp +++ b/websocket/src/handshake.cpp @@ -146,7 +146,8 @@ Result ReceiveHeaders(WebsocketConnection* conn) conn->m_Buffer[conn->m_BufferSize] = '\0'; // Check if the end of the response has arrived - if (conn->m_BufferSize >= 4 && strcmp(conn->m_Buffer + conn->m_BufferSize - 4, "\r\n\r\n") == 0) + const char* endtag = strstr(conn->m_Buffer, "\r\n\r\n"); + if (endtag != 0) { return RESULT_OK; } @@ -171,6 +172,8 @@ Result VerifyHeaders(WebsocketConnection* conn) return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Missing: '%s' in header", http_version_and_status_protocol); } + const char* endtag = strstr(conn->m_Buffer, "\r\n\r\n"); + r = strstr(r, "\r\n") + 2; bool upgraded = false; @@ -180,7 +183,7 @@ Result VerifyHeaders(WebsocketConnection* conn) // TODO: Perhaps also support the Sec-WebSocket-Protocol // parse the headers in place - while (r) + while (r < endtag) { // Tokenize the each header line: "Key: Value\r\n" const char* key = r; @@ -218,11 +221,16 @@ Result VerifyHeaders(WebsocketConnection* conn) if (strcmp(value, (const char*)client_key) == 0) valid_key = true; } - - if (strcmp(r, "\r\n") == 0) - break; } + // The response might contain both the headers, but also (if successful) the first batch of data + endtag += 4; + uint32_t size = conn->m_BufferSize - (endtag - conn->m_Buffer); + conn->m_BufferSize = size; + memmove(conn->m_Buffer, endtag, size); + conn->m_Buffer[size] = 0; + conn->m_HasHandshakeData = conn->m_BufferSize != 0 ? 1 : 0; + if (!upgraded) dmLogError("Failed to find the Upgrade keyword in the response headers"); if (!valid_key) diff --git a/websocket/src/socket.cpp b/websocket/src/socket.cpp index 2935c45..72eb6a0 100644 --- a/websocket/src/socket.cpp +++ b/websocket/src/socket.cpp @@ -42,15 +42,24 @@ dmSocket::Result Send(WebsocketConnection* conn, const char* buffer, int length, } if (out_sent_bytes) *out_sent_bytes = total_sent_bytes; + + DebugPrint(2, "Sent buffer:", buffer, length); return dmSocket::RESULT_OK; } dmSocket::Result Receive(WebsocketConnection* conn, void* buffer, int length, int* received_bytes) { + dmSocket::Result sr; if (conn->m_SSLSocket) - return dmSSLSocket::Receive(conn->m_SSLSocket, buffer, length, received_bytes); + sr = dmSSLSocket::Receive(conn->m_SSLSocket, buffer, length, received_bytes); else - return dmSocket::Receive(conn->m_Socket, buffer, length, received_bytes); + sr = dmSocket::Receive(conn->m_Socket, buffer, length, received_bytes); + + int num_bytes = received_bytes ? (uint32_t)*received_bytes : 0; + if (sr == dmSocket::RESULT_OK && num_bytes > 0) + DebugPrint(2, "Received bytes:", buffer, num_bytes); + + return sr; } } // namespace \ No newline at end of file diff --git a/websocket/src/websocket.cpp b/websocket/src/websocket.cpp index 0c42535..6615641 100644 --- a/websocket/src/websocket.cpp +++ b/websocket/src/websocket.cpp @@ -9,13 +9,20 @@ #include #include #include +#include // isprint et al #if defined(__EMSCRIPTEN__) #include // for EM_ASM #endif +#if defined(WIN32) +#include +#define alloca _alloca +#endif + namespace dmWebsocket { +int g_DebugWebSocket = 0; struct WebsocketContext { @@ -60,8 +67,45 @@ const char* StateToString(State err) #undef STRING_CASE -#define WS_DEBUG(...) -//#define WS_DEBUG(...) dmLogWarning(__VA_ARGS__); +void DebugLog(int level, const char* fmt, ...) +{ + if (level > g_DebugWebSocket) + return; + + size_t buffer_size = 4096; + char* buffer = (char*)alloca(buffer_size); + va_list lst; + va_start(lst, fmt); + + buffer_size = vsnprintf(buffer, buffer_size, fmt, lst); + dmLogWarning("%s", buffer); + va_end(lst); +} + +void DebugPrint(int level, const char* msg, const void* _bytes, uint32_t num_bytes) +{ + if (level > g_DebugWebSocket) + return; + + const uint8_t* bytes = (const uint8_t*)_bytes; + printf("%s '", msg); + for (uint32_t i = 0; i < num_bytes; ++i) + { + int c = bytes[i]; + if (isprint(c)) + printf("%c", c); + else if (c == '\r') + printf("\\r"); + else if (c == '\n') + printf("\\n"); + else if (c == '\t') + printf("\\t"); + else + printf("\\%02x", c); + } + printf("' %u bytes\n", num_bytes); +} + #define CLOSE_CONN(...) \ SetStatus(conn, RESULT_ERROR, __VA_ARGS__); \ @@ -74,7 +118,7 @@ static void SetState(WebsocketConnection* conn, State state) if (prev_state != state) { conn->m_State = state; - WS_DEBUG("%s -> %s", StateToString(prev_state), StateToString(conn->m_State)); + DebugLog(1, "%s -> %s", StateToString(prev_state), StateToString(conn->m_State)); } } @@ -100,10 +144,11 @@ Result SetStatus(WebsocketConnection* conn, Result status, const char* format, . static WebsocketConnection* CreateConnection(const char* url) { - WebsocketConnection* conn = (WebsocketConnection*)malloc(sizeof(WebsocketConnection)); - memset(conn, 0, sizeof(WebsocketConnection)); + WebsocketConnection* conn = new WebsocketConnection; conn->m_BufferCapacity = g_Websocket.m_BufferSize; conn->m_Buffer = (char*)malloc(conn->m_BufferCapacity); + conn->m_Buffer[0] = 0; + conn->m_BufferSize = 0; dmURI::Parts uri; dmURI::Parse(url, &conn->m_Url); @@ -114,6 +159,17 @@ static WebsocketConnection* CreateConnection(const char* url) conn->m_SSL = strcmp(conn->m_Url.m_Scheme, "wss") == 0 ? 1 : 0; conn->m_State = STATE_CONNECTING; + conn->m_Callback = 0; + conn->m_Connection = 0; + conn->m_Socket = 0; + conn->m_SSLSocket = 0; + conn->m_Status = RESULT_OK; + conn->m_HasHandshakeData = 0; + +#if defined(HAVE_WSLAY) + conn->m_Ctx = 0; +#endif + return conn; } @@ -138,8 +194,9 @@ static void DestroyConnection(WebsocketConnection* conn) dmConnectionPool::Return(g_Websocket.m_Pool, conn->m_Connection); #endif + free((void*)conn->m_Buffer); - free((void*)conn); + delete conn; } @@ -239,7 +296,7 @@ static int LuaSend(lua_State* L) const char* string = luaL_checklstring(L, 2, &string_length); #if defined(HAVE_WSLAY) - int write_mode = WSLAY_BINARY_FRAME; // WSLAY_TEXT_FRAME + int write_mode = WSLAY_BINARY_FRAME; // or WSLAY_TEXT_FRAME struct wslay_event_msg msg; msg.opcode = write_mode; @@ -259,7 +316,7 @@ static int LuaSend(lua_State* L) return 0; } -static void HandleCallback(WebsocketConnection* conn, int event) +static void HandleCallback(WebsocketConnection* conn, int event, int msg_offset, int msg_length) { if (!dmScript::IsCallbackValid(conn->m_Callback)) return; @@ -285,7 +342,7 @@ static void HandleCallback(WebsocketConnection* conn, int event) lua_setfield(L, -2, "error"); } else if (EVENT_MESSAGE == event) { - lua_pushlstring(L, conn->m_Buffer, conn->m_BufferSize); + lua_pushlstring(L, conn->m_Buffer + msg_offset, msg_length); lua_setfield(L, -2, "message"); } @@ -329,7 +386,7 @@ static void LuaInit(lua_State* L) assert(top == lua_gettop(L)); } -static dmExtension::Result WebsocketAppInitialize(dmExtension::AppParams* params) +static dmExtension::Result AppInitialize(dmExtension::AppParams* params) { g_Websocket.m_BufferSize = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.buffer_size", 64 * 1024); g_Websocket.m_Timeout = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.socket_timeout", 500 * 1000); @@ -341,6 +398,10 @@ static dmExtension::Result WebsocketAppInitialize(dmExtension::AppParams* params pool_params.m_MaxConnections = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.max_connections", 2); dmConnectionPool::Result result = dmConnectionPool::New(&pool_params, &g_Websocket.m_Pool); + g_DebugWebSocket = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.debug", 0); + if (g_DebugWebSocket) + dmLogInfo("dmWebSocket::g_DebugWebSocket == %d", g_DebugWebSocket); + if (dmConnectionPool::RESULT_OK != result) { dmLogError("Failed to create connection pool: %d", result); @@ -380,7 +441,7 @@ static dmExtension::Result WebsocketAppInitialize(dmExtension::AppParams* params return dmExtension::RESULT_OK; } -static dmExtension::Result WebsocketInitialize(dmExtension::Params* params) +static dmExtension::Result Initialize(dmExtension::Params* params) { if (!g_Websocket.m_Initialized) return dmExtension::RESULT_OK; @@ -391,19 +452,27 @@ static dmExtension::Result WebsocketInitialize(dmExtension::Params* params) return dmExtension::RESULT_OK; } -static dmExtension::Result WebsocketAppFinalize(dmExtension::AppParams* params) +static dmExtension::Result AppFinalize(dmExtension::AppParams* params) { dmConnectionPool::Shutdown(g_Websocket.m_Pool, dmSocket::SHUTDOWNTYPE_READWRITE); return dmExtension::RESULT_OK; } -static dmExtension::Result WebsocketFinalize(dmExtension::Params* params) +static dmExtension::Result Finalize(dmExtension::Params* params) { return dmExtension::RESULT_OK; } -static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params) +Result PushMessage(WebsocketConnection* conn, int length) +{ + if (conn->m_Messages.Full()) + conn->m_Messages.OffsetCapacity(4); + conn->m_Messages.Push(length); + return dmWebsocket::RESULT_OK; +} + +static dmExtension::Result OnUpdate(dmExtension::Params* params) { uint32_t size = g_Websocket.m_Connections.Size(); @@ -415,10 +484,10 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params) { if (RESULT_OK != conn->m_Status) { - HandleCallback(conn, EVENT_ERROR); + HandleCallback(conn, EVENT_ERROR, 0, 0); } - HandleCallback(conn, EVENT_DISCONNECTED); + HandleCallback(conn, EVENT_DISCONNECTED, 0, 0); g_Websocket.m_Connections.EraseSwap(i); --i; @@ -450,9 +519,9 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params) if (dmSocket::RESULT_OK == sr) { + PushMessage(conn, recv_bytes); conn->m_BufferSize += recv_bytes; conn->m_Buffer[conn->m_BufferCapacity-1] = 0; - conn->m_HasMessage = 1; } else { @@ -461,12 +530,15 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params) } #endif - if (conn->m_HasMessage) + uint32_t offset = 0; + for (uint32_t i = 0; i < conn->m_Messages.Size(); ++i) { - HandleCallback(conn, EVENT_MESSAGE); - conn->m_HasMessage = 0; - conn->m_BufferSize = 0; + uint32_t length = conn->m_Messages[i]; + HandleCallback(conn, EVENT_MESSAGE, offset, length); + offset += length; } + conn->m_Messages.SetSize(0); + conn->m_BufferSize = 0; } else if (STATE_HANDSHAKE_READ == conn->m_State) { @@ -482,6 +554,7 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params) continue; } + // Verifies headers, and also stages any initial sent data result = VerifyHeaders(conn); if (RESULT_OK != result) { @@ -505,11 +578,8 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params) #endif dmSocket::SetBlocking(conn->m_Socket, false); - conn->m_Buffer[0] = 0; - conn->m_BufferSize = 0; - SetState(conn, STATE_CONNECTED); - HandleCallback(conn, EVENT_CONNECTED); + HandleCallback(conn, EVENT_CONNECTED, 0, 0); } else if (STATE_HANDSHAKE_WRITE == conn->m_State) { @@ -580,6 +650,6 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params) } // dmWebsocket -DM_DECLARE_EXTENSION(Websocket, LIB_NAME, dmWebsocket::WebsocketAppInitialize, dmWebsocket::WebsocketAppFinalize, dmWebsocket::WebsocketInitialize, dmWebsocket::WebsocketOnUpdate, 0, dmWebsocket::WebsocketFinalize) +DM_DECLARE_EXTENSION(Websocket, LIB_NAME, dmWebsocket::AppInitialize, dmWebsocket::AppFinalize, dmWebsocket::Initialize, dmWebsocket::OnUpdate, 0, dmWebsocket::Finalize) #undef CLOSE_CONN diff --git a/websocket/src/websocket.h b/websocket/src/websocket.h index b14a43b..abcebd4 100644 --- a/websocket/src/websocket.h +++ b/websocket/src/websocket.h @@ -19,6 +19,7 @@ #include #include #include +#include namespace dmCrypt { @@ -69,14 +70,16 @@ namespace dmWebsocket dmConnectionPool::HConnection m_Connection; dmSocket::Socket m_Socket; dmSSLSocket::Socket m_SSLSocket; + dmArray m_Messages; // lengths of the messages in the data buffer uint8_t m_Key[16]; State m_State; - uint32_t m_SSL:1; - uint32_t m_HasMessage:1; char* m_Buffer; int m_BufferSize; uint32_t m_BufferCapacity; Result m_Status; + uint8_t m_SSL:1; + uint8_t m_HasHandshakeData:1; + uint8_t :6; }; // Set error message @@ -96,6 +99,9 @@ namespace dmWebsocket Result ReceiveHeaders(WebsocketConnection* conn); Result VerifyHeaders(WebsocketConnection* conn); + // Messages + Result PushMessage(WebsocketConnection* conn, int length); + #if defined(HAVE_WSLAY) // Wslay callbacks int WSL_Init(wslay_event_context_ptr* ctx, ssize_t buffer_size, void* userctx); @@ -114,6 +120,15 @@ namespace dmWebsocket typedef struct { uint64_t state; uint64_t inc; } pcg32_random_t; void pcg32_srandom_r(pcg32_random_t* rng, uint64_t initstate, uint64_t initseq); uint32_t pcg32_random_r(pcg32_random_t* rng); + + // If level <= dmWebSocket::g_DebugWebSocket, then it outputs the message +#ifdef __GNUC__ + void DebugLog(int level, const char* fmt, ...) __attribute__ ((format (printf, 2, 3))); +#else + void DebugLog(int level, const char* fmt, ...); +#endif + + void DebugPrint(int level, const char* msg, const void* _bytes, uint32_t num_bytes); } diff --git a/websocket/src/wslay_callbacks.cpp b/websocket/src/wslay_callbacks.cpp index a2ffc25..1e1b48c 100644 --- a/websocket/src/wslay_callbacks.cpp +++ b/websocket/src/wslay_callbacks.cpp @@ -82,6 +82,15 @@ ssize_t WSL_RecvCallback(wslay_event_context_ptr ctx, uint8_t *buf, size_t len, int r = -1; // received bytes if >=0, error if < 0 + if (conn->m_HasHandshakeData) + { + r = conn->m_BufferSize; + memcpy(buf, conn->m_Buffer, r); + conn->m_BufferSize = 0; + conn->m_HasHandshakeData = 0; + return r; + } + dmSocket::Result socket_result = Receive(conn, buf, len, &r); if (dmSocket::RESULT_OK == socket_result && r == 0) @@ -117,16 +126,23 @@ ssize_t WSL_SendCallback(wslay_event_context_ptr ctx, const uint8_t *data, size_ return (ssize_t)sent_bytes; } +// Might be called multiple times for a connection receiving multiple events void WSL_OnMsgRecvCallback(wslay_event_context_ptr ctx, const struct wslay_event_on_msg_recv_arg *arg, void *user_data) { WebsocketConnection* conn = (WebsocketConnection*)user_data; if (arg->opcode == WSLAY_TEXT_FRAME || arg->opcode == WSLAY_BINARY_FRAME) { - if (arg->msg_length >= conn->m_BufferCapacity) - conn->m_Buffer = (char*)realloc(conn->m_Buffer, arg->msg_length + 1); - memcpy(conn->m_Buffer, arg->msg, arg->msg_length); - conn->m_BufferSize = arg->msg_length; - conn->m_HasMessage = 1; + if ((conn->m_BufferSize + arg->msg_length) >= conn->m_BufferCapacity) + { + conn->m_BufferCapacity = conn->m_BufferSize + arg->msg_length + 1; + conn->m_Buffer = (char*)realloc(conn->m_Buffer, conn->m_BufferCapacity); + } + // append to the end of the buffer + memcpy(conn->m_Buffer + conn->m_BufferSize, arg->msg, arg->msg_length); + conn->m_BufferSize += arg->msg_length; + + PushMessage(conn, arg->msg_length); + DebugPrint(2, __FUNCTION__, conn->m_Buffer+conn->m_BufferSize-arg->msg_length, arg->msg_length); } else if (arg->opcode == WSLAY_CONNECTION_CLOSE) {