10 Commits
1.0.0 ... 1.2.0

Author SHA1 Message Date
Mathias Westerdahl
dc1d57d661 Merge pull request #10 from defold/issue-4-case-insensitive-header
Issue 4: Compare header keys without case sensitivity
2020-10-12 09:59:55 +02:00
JCash
5b0a9960a8 Check the response headers more thoroughly 2020-10-12 09:45:05 +02:00
JCash
40ba1b334c Issue 4: Compare header keys without case sensitivity 2020-10-12 09:33:47 +02:00
Mathias Westerdahl
23ba179e2a Merge pull request #9 from defold/issue-8-close-event
Issue 8: Handle the websocket close event
2020-09-28 09:30:12 +02:00
JCash
36bf5d1c03 code cleanup 2020-09-27 17:04:00 +02:00
JCash
274f29d7e4 moved error checking code outside of socket implementation 2020-09-27 17:02:14 +02:00
JCash
bd8569f49a Issue 8: Handle the websocket close event 2020-09-27 16:54:22 +02:00
Mathias Westerdahl
18a768774f Merge pull request #7 from defold/issue-6-multiple-messages
Issue 6: Handle multiple messages per frame
2020-09-27 10:14:08 +02:00
JCash
8e32fa3c76 compile fixes 2020-09-26 15:50:24 +02:00
JCash
832a156395 Issue 6: Handle multiple messages per frame 2020-09-26 12:51:43 +02:00
7 changed files with 281 additions and 78 deletions

View File

@@ -10,3 +10,13 @@ We recommend using a link to a zip file of a [specific release](https://github.c
## API reference ## API reference
https://defold.com/extension-websocket/api/ https://defold.com/extension-websocket/api/
## Debugging
In order to make it easier to debug this extension, we provide a `game.project` setting `websocket.debug`.
Set it to:
* `0` to disable debugging (i.e. no debug output).
* `1` to display state changes.
* `2` to display the messages sent and received.

View File

@@ -20,10 +20,13 @@ local function websocket_callback(self, conn, data)
print("Connected " .. conn) print("Connected " .. conn)
-- self.connection = conn -- self.connection = conn
elseif data.event == websocket.EVENT_ERROR then elseif data.event == websocket.EVENT_ERROR then
print("Error:", data.error) print("Error:", data.message)
elseif data.event == websocket.EVENT_MESSAGE then elseif data.event == websocket.EVENT_MESSAGE then
print("Receiving: '" .. tostring(data.message) .. "'") print("Receiving: '" .. tostring(data.message) .. "'")
end end
elseif data.event == websocket.EVENT_DISCONNECTED then
print("Disconnected: '" .. tostring(data.message) .. "'")
end
end end
function init(self) function init(self)

View File

@@ -1,5 +1,6 @@
#include "websocket.h" #include "websocket.h"
#include <dmsdk/dlib/socket.h> #include <dmsdk/dlib/socket.h>
#include <ctype.h> // tolower
namespace dmWebsocket namespace dmWebsocket
{ {
@@ -146,7 +147,8 @@ Result ReceiveHeaders(WebsocketConnection* conn)
conn->m_Buffer[conn->m_BufferSize] = '\0'; conn->m_Buffer[conn->m_BufferSize] = '\0';
// Check if the end of the response has arrived // Check if the end of the response has arrived
if (conn->m_BufferSize >= 4 && strcmp(conn->m_Buffer + conn->m_BufferSize - 4, "\r\n\r\n") == 0) const char* endtag = strstr(conn->m_Buffer, "\r\n\r\n");
if (endtag != 0)
{ {
return RESULT_OK; return RESULT_OK;
} }
@@ -155,6 +157,20 @@ Result ReceiveHeaders(WebsocketConnection* conn)
} }
#endif #endif
static int dmStriCmp(const char* s1, const char* s2)
{
for (;;)
{
if (!*s1 || !*s2 || tolower((unsigned char) *s1) != tolower((unsigned char) *s2))
{
return (unsigned char) *s1 - (unsigned char) *s2;
}
s1++;
s2++;
}
}
#if defined(__EMSCRIPTEN__) #if defined(__EMSCRIPTEN__)
Result VerifyHeaders(WebsocketConnection* conn) Result VerifyHeaders(WebsocketConnection* conn)
{ {
@@ -171,16 +187,19 @@ Result VerifyHeaders(WebsocketConnection* conn)
return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Missing: '%s' in header", http_version_and_status_protocol); return SetStatus(conn, RESULT_HANDSHAKE_FAILED, "Missing: '%s' in header", http_version_and_status_protocol);
} }
const char* endtag = strstr(conn->m_Buffer, "\r\n\r\n");
r = strstr(r, "\r\n") + 2; r = strstr(r, "\r\n") + 2;
bool upgraded = false; bool connection = false;
bool upgrade = false;
bool valid_key = false; bool valid_key = false;
const char* protocol = ""; const char* protocol = "";
// TODO: Perhaps also support the Sec-WebSocket-Protocol // TODO: Perhaps also support the Sec-WebSocket-Protocol
// parse the headers in place // parse the headers in place
while (r) while (r < endtag)
{ {
// Tokenize the each header line: "Key: Value\r\n" // Tokenize the each header line: "Key: Value\r\n"
const char* key = r; const char* key = r;
@@ -194,45 +213,65 @@ Result VerifyHeaders(WebsocketConnection* conn)
*r = 0; *r = 0;
r += 2; r += 2;
if (strcmp(key, "Connection") == 0 && strcmp(value, "Upgrade") == 0) // Page 18 in https://tools.ietf.org/html/rfc6455#section-11.3.3
upgraded = true; if (dmStriCmp(key, "Connection") == 0 && dmStriCmp(value, "Upgrade") == 0)
else if (strcmp(key, "Sec-WebSocket-Accept") == 0) connection = true;
else if (dmStriCmp(key, "Upgrade") == 0 && dmStriCmp(value, "websocket") == 0)
upgrade = true;
else if (dmStriCmp(key, "Sec-WebSocket-Accept") == 0)
{ {
uint8_t client_key[32 + 40]; uint8_t client_key[32 + 40];
uint32_t client_key_len = sizeof(client_key); uint32_t client_key_len = sizeof(client_key);
dmCrypt::Base64Encode(conn->m_Key, sizeof(conn->m_Key), client_key, &client_key_len); dmCrypt::Base64Encode(conn->m_Key, sizeof(conn->m_Key), client_key, &client_key_len);
client_key[client_key_len] = 0; client_key[client_key_len] = 0;
DebugLog(2, "Secret key (base64): %s", client_key);
memcpy(client_key + client_key_len, RFC_MAGIC, strlen(RFC_MAGIC)); memcpy(client_key + client_key_len, RFC_MAGIC, strlen(RFC_MAGIC));
client_key_len += strlen(RFC_MAGIC); client_key_len += strlen(RFC_MAGIC);
client_key[client_key_len] = 0; client_key[client_key_len] = 0;
DebugLog(2, "Secret key + RFC_MAGIC: %s", client_key);
uint8_t client_key_sha1[20]; uint8_t client_key_sha1[20];
dmCrypt::HashSha1(client_key, client_key_len, client_key_sha1); dmCrypt::HashSha1(client_key, client_key_len, client_key_sha1);
DebugPrint(2, "Hashed key (sha1):", client_key_sha1, sizeof(client_key_sha1));
client_key_len = sizeof(client_key); client_key_len = sizeof(client_key);
dmCrypt::Base64Encode(client_key_sha1, sizeof(client_key_sha1), client_key, &client_key_len); dmCrypt::Base64Encode(client_key_sha1, sizeof(client_key_sha1), client_key, &client_key_len);
client_key[client_key_len] = 0; client_key[client_key_len] = 0;
DebugLog(2, "Client key (base64): %s", client_key);
DebugLog(2, "Server key (base64): %s", value);
if (strcmp(value, (const char*)client_key) == 0) if (strcmp(value, (const char*)client_key) == 0)
valid_key = true; valid_key = true;
} }
if (strcmp(r, "\r\n") == 0)
break;
} }
if (!upgraded) // The response might contain both the headers, but also (if successful) the first batch of data
endtag += 4;
uint32_t size = conn->m_BufferSize - (endtag - conn->m_Buffer);
conn->m_BufferSize = size;
memmove(conn->m_Buffer, endtag, size);
conn->m_Buffer[size] = 0;
conn->m_HasHandshakeData = conn->m_BufferSize != 0 ? 1 : 0;
if (!connection)
dmLogError("Failed to find the Connection keyword in the response headers");
if (!upgrade)
dmLogError("Failed to find the Upgrade keyword in the response headers"); dmLogError("Failed to find the Upgrade keyword in the response headers");
if (!valid_key) if (!valid_key)
dmLogError("Failed to find valid key in the response headers"); dmLogError("Failed to find valid key in the response headers");
if (!(upgraded && valid_key)) { bool ok = connection && upgrade && valid_key;
if (!ok) {
dmLogError("Response:\n\"%s\"\n", conn->m_Buffer); dmLogError("Response:\n\"%s\"\n", conn->m_Buffer);
} }
return (upgraded && valid_key) ? RESULT_OK : RESULT_HANDSHAKE_FAILED; return ok ? RESULT_OK : RESULT_HANDSHAKE_FAILED;
} }
#endif #endif

View File

@@ -42,15 +42,24 @@ dmSocket::Result Send(WebsocketConnection* conn, const char* buffer, int length,
} }
if (out_sent_bytes) if (out_sent_bytes)
*out_sent_bytes = total_sent_bytes; *out_sent_bytes = total_sent_bytes;
DebugPrint(2, "Sent buffer:", buffer, length);
return dmSocket::RESULT_OK; return dmSocket::RESULT_OK;
} }
dmSocket::Result Receive(WebsocketConnection* conn, void* buffer, int length, int* received_bytes) dmSocket::Result Receive(WebsocketConnection* conn, void* buffer, int length, int* received_bytes)
{ {
dmSocket::Result sr;
if (conn->m_SSLSocket) if (conn->m_SSLSocket)
return dmSSLSocket::Receive(conn->m_SSLSocket, buffer, length, received_bytes); sr = dmSSLSocket::Receive(conn->m_SSLSocket, buffer, length, received_bytes);
else else
return dmSocket::Receive(conn->m_Socket, buffer, length, received_bytes); sr = dmSocket::Receive(conn->m_Socket, buffer, length, received_bytes);
int num_bytes = received_bytes ? (uint32_t)*received_bytes : 0;
if (sr == dmSocket::RESULT_OK && num_bytes > 0)
DebugPrint(2, "Received bytes:", buffer, num_bytes);
return sr;
} }
} // namespace } // namespace

View File

@@ -9,13 +9,20 @@
#include <dmsdk/dlib/connection_pool.h> #include <dmsdk/dlib/connection_pool.h>
#include <dmsdk/dlib/dns.h> #include <dmsdk/dlib/dns.h>
#include <dmsdk/dlib/sslsocket.h> #include <dmsdk/dlib/sslsocket.h>
#include <ctype.h> // isprint et al
#if defined(__EMSCRIPTEN__) #if defined(__EMSCRIPTEN__)
#include <emscripten/emscripten.h> // for EM_ASM #include <emscripten/emscripten.h> // for EM_ASM
#endif #endif
#if defined(WIN32)
#include <malloc.h>
#define alloca _alloca
#endif
namespace dmWebsocket { namespace dmWebsocket {
int g_DebugWebSocket = 0;
struct WebsocketContext struct WebsocketContext
{ {
@@ -60,8 +67,45 @@ const char* StateToString(State err)
#undef STRING_CASE #undef STRING_CASE
#define WS_DEBUG(...) void DebugLog(int level, const char* fmt, ...)
//#define WS_DEBUG(...) dmLogWarning(__VA_ARGS__); {
if (level > g_DebugWebSocket)
return;
size_t buffer_size = 4096;
char* buffer = (char*)alloca(buffer_size);
va_list lst;
va_start(lst, fmt);
buffer_size = vsnprintf(buffer, buffer_size, fmt, lst);
dmLogWarning("%s", buffer);
va_end(lst);
}
void DebugPrint(int level, const char* msg, const void* _bytes, uint32_t num_bytes)
{
if (level > g_DebugWebSocket)
return;
const uint8_t* bytes = (const uint8_t*)_bytes;
printf("%s '", msg);
for (uint32_t i = 0; i < num_bytes; ++i)
{
int c = bytes[i];
if (isprint(c))
printf("%c", c);
else if (c == '\r')
printf("\\r");
else if (c == '\n')
printf("\\n");
else if (c == '\t')
printf("\\t");
else
printf("\\%02x", c);
}
printf("' %u bytes\n", num_bytes);
}
#define CLOSE_CONN(...) \ #define CLOSE_CONN(...) \
SetStatus(conn, RESULT_ERROR, __VA_ARGS__); \ SetStatus(conn, RESULT_ERROR, __VA_ARGS__); \
@@ -74,7 +118,7 @@ static void SetState(WebsocketConnection* conn, State state)
if (prev_state != state) if (prev_state != state)
{ {
conn->m_State = state; conn->m_State = state;
WS_DEBUG("%s -> %s", StateToString(prev_state), StateToString(conn->m_State)); DebugLog(1, "%s -> %s", StateToString(prev_state), StateToString(conn->m_State));
} }
} }
@@ -89,6 +133,8 @@ Result SetStatus(WebsocketConnection* conn, Result status, const char* format, .
conn->m_BufferSize = vsnprintf(conn->m_Buffer, conn->m_BufferCapacity, format, lst); conn->m_BufferSize = vsnprintf(conn->m_Buffer, conn->m_BufferCapacity, format, lst);
va_end(lst); va_end(lst);
conn->m_Status = status; conn->m_Status = status;
DebugLog(1, "STATUS: '%s' len: %u", conn->m_Buffer, conn->m_BufferSize);
} }
return status; return status;
} }
@@ -100,10 +146,11 @@ Result SetStatus(WebsocketConnection* conn, Result status, const char* format, .
static WebsocketConnection* CreateConnection(const char* url) static WebsocketConnection* CreateConnection(const char* url)
{ {
WebsocketConnection* conn = (WebsocketConnection*)malloc(sizeof(WebsocketConnection)); WebsocketConnection* conn = new WebsocketConnection;
memset(conn, 0, sizeof(WebsocketConnection));
conn->m_BufferCapacity = g_Websocket.m_BufferSize; conn->m_BufferCapacity = g_Websocket.m_BufferSize;
conn->m_Buffer = (char*)malloc(conn->m_BufferCapacity); conn->m_Buffer = (char*)malloc(conn->m_BufferCapacity);
conn->m_Buffer[0] = 0;
conn->m_BufferSize = 0;
dmURI::Parts uri; dmURI::Parts uri;
dmURI::Parse(url, &conn->m_Url); dmURI::Parse(url, &conn->m_Url);
@@ -114,6 +161,17 @@ static WebsocketConnection* CreateConnection(const char* url)
conn->m_SSL = strcmp(conn->m_Url.m_Scheme, "wss") == 0 ? 1 : 0; conn->m_SSL = strcmp(conn->m_Url.m_Scheme, "wss") == 0 ? 1 : 0;
conn->m_State = STATE_CONNECTING; conn->m_State = STATE_CONNECTING;
conn->m_Callback = 0;
conn->m_Connection = 0;
conn->m_Socket = 0;
conn->m_SSLSocket = 0;
conn->m_Status = RESULT_OK;
conn->m_HasHandshakeData = 0;
#if defined(HAVE_WSLAY)
conn->m_Ctx = 0;
#endif
return conn; return conn;
} }
@@ -138,8 +196,9 @@ static void DestroyConnection(WebsocketConnection* conn)
dmConnectionPool::Return(g_Websocket.m_Pool, conn->m_Connection); dmConnectionPool::Return(g_Websocket.m_Pool, conn->m_Connection);
#endif #endif
free((void*)conn->m_Buffer); free((void*)conn->m_Buffer);
free((void*)conn); delete conn;
} }
@@ -239,7 +298,7 @@ static int LuaSend(lua_State* L)
const char* string = luaL_checklstring(L, 2, &string_length); const char* string = luaL_checklstring(L, 2, &string_length);
#if defined(HAVE_WSLAY) #if defined(HAVE_WSLAY)
int write_mode = WSLAY_BINARY_FRAME; // WSLAY_TEXT_FRAME int write_mode = WSLAY_BINARY_FRAME; // or WSLAY_TEXT_FRAME
struct wslay_event_msg msg; struct wslay_event_msg msg;
msg.opcode = write_mode; msg.opcode = write_mode;
@@ -259,7 +318,7 @@ static int LuaSend(lua_State* L)
return 0; return 0;
} }
static void HandleCallback(WebsocketConnection* conn, int event) static void HandleCallback(WebsocketConnection* conn, int event, int msg_offset, int msg_length)
{ {
if (!dmScript::IsCallbackValid(conn->m_Callback)) if (!dmScript::IsCallbackValid(conn->m_Callback))
return; return;
@@ -280,14 +339,8 @@ static void HandleCallback(WebsocketConnection* conn, int event)
lua_pushinteger(L, event); lua_pushinteger(L, event);
lua_setfield(L, -2, "event"); lua_setfield(L, -2, "event");
if (EVENT_ERROR == event) { lua_pushlstring(L, conn->m_Buffer + msg_offset, msg_length);
lua_pushlstring(L, conn->m_Buffer, conn->m_BufferSize); lua_setfield(L, -2, "message");
lua_setfield(L, -2, "error");
}
else if (EVENT_MESSAGE == event) {
lua_pushlstring(L, conn->m_Buffer, conn->m_BufferSize);
lua_setfield(L, -2, "message");
}
dmScript::PCall(L, 3, 0); dmScript::PCall(L, 3, 0);
@@ -329,7 +382,7 @@ static void LuaInit(lua_State* L)
assert(top == lua_gettop(L)); assert(top == lua_gettop(L));
} }
static dmExtension::Result WebsocketAppInitialize(dmExtension::AppParams* params) static dmExtension::Result AppInitialize(dmExtension::AppParams* params)
{ {
g_Websocket.m_BufferSize = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.buffer_size", 64 * 1024); g_Websocket.m_BufferSize = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.buffer_size", 64 * 1024);
g_Websocket.m_Timeout = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.socket_timeout", 500 * 1000); g_Websocket.m_Timeout = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.socket_timeout", 500 * 1000);
@@ -341,6 +394,10 @@ static dmExtension::Result WebsocketAppInitialize(dmExtension::AppParams* params
pool_params.m_MaxConnections = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.max_connections", 2); pool_params.m_MaxConnections = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.max_connections", 2);
dmConnectionPool::Result result = dmConnectionPool::New(&pool_params, &g_Websocket.m_Pool); dmConnectionPool::Result result = dmConnectionPool::New(&pool_params, &g_Websocket.m_Pool);
g_DebugWebSocket = dmConfigFile::GetInt(params->m_ConfigFile, "websocket.debug", 0);
if (g_DebugWebSocket)
dmLogInfo("dmWebSocket::g_DebugWebSocket == %d", g_DebugWebSocket);
if (dmConnectionPool::RESULT_OK != result) if (dmConnectionPool::RESULT_OK != result)
{ {
dmLogError("Failed to create connection pool: %d", result); dmLogError("Failed to create connection pool: %d", result);
@@ -380,7 +437,7 @@ static dmExtension::Result WebsocketAppInitialize(dmExtension::AppParams* params
return dmExtension::RESULT_OK; return dmExtension::RESULT_OK;
} }
static dmExtension::Result WebsocketInitialize(dmExtension::Params* params) static dmExtension::Result Initialize(dmExtension::Params* params)
{ {
if (!g_Websocket.m_Initialized) if (!g_Websocket.m_Initialized)
return dmExtension::RESULT_OK; return dmExtension::RESULT_OK;
@@ -391,19 +448,50 @@ static dmExtension::Result WebsocketInitialize(dmExtension::Params* params)
return dmExtension::RESULT_OK; return dmExtension::RESULT_OK;
} }
static dmExtension::Result WebsocketAppFinalize(dmExtension::AppParams* params) static dmExtension::Result AppFinalize(dmExtension::AppParams* params)
{ {
dmConnectionPool::Shutdown(g_Websocket.m_Pool, dmSocket::SHUTDOWNTYPE_READWRITE); dmConnectionPool::Shutdown(g_Websocket.m_Pool, dmSocket::SHUTDOWNTYPE_READWRITE);
return dmExtension::RESULT_OK; return dmExtension::RESULT_OK;
} }
static dmExtension::Result WebsocketFinalize(dmExtension::Params* params) static dmExtension::Result Finalize(dmExtension::Params* params)
{ {
return dmExtension::RESULT_OK; return dmExtension::RESULT_OK;
} }
static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params) Result PushMessage(WebsocketConnection* conn, MessageType type, int length, const uint8_t* buffer)
{
if (conn->m_Messages.Full())
conn->m_Messages.OffsetCapacity(4);
Message msg;
msg.m_Type = (uint32_t)type;
msg.m_Length = length;
conn->m_Messages.Push(msg);
// No need to copy itself (html5)
if (buffer != (const uint8_t*)conn->m_Buffer)
{
if ((conn->m_BufferSize + length) >= conn->m_BufferCapacity)
{
conn->m_BufferCapacity = conn->m_BufferSize + length + 1;
conn->m_Buffer = (char*)realloc(conn->m_Buffer, conn->m_BufferCapacity);
}
// append to the end of the buffer
memcpy(conn->m_Buffer + conn->m_BufferSize, buffer, length);
}
conn->m_BufferSize += length;
conn->m_Buffer[conn->m_BufferCapacity-1] = 0;
// Instead of printing from the incoming buffer, we print from our own, to make sure it looks ok
DebugPrint(2, __FUNCTION__, conn->m_Buffer+conn->m_BufferSize-length, length);
return dmWebsocket::RESULT_OK;
}
static dmExtension::Result OnUpdate(dmExtension::Params* params)
{ {
uint32_t size = g_Websocket.m_Connections.Size(); uint32_t size = g_Websocket.m_Connections.Size();
@@ -415,10 +503,11 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params)
{ {
if (RESULT_OK != conn->m_Status) if (RESULT_OK != conn->m_Status)
{ {
HandleCallback(conn, EVENT_ERROR); HandleCallback(conn, EVENT_ERROR, 0, conn->m_BufferSize);
conn->m_BufferSize = 0;
} }
HandleCallback(conn, EVENT_DISCONNECTED); HandleCallback(conn, EVENT_DISCONNECTED, 0, conn->m_BufferSize);
g_Websocket.m_Connections.EraseSwap(i); g_Websocket.m_Connections.EraseSwap(i);
--i; --i;
@@ -434,12 +523,6 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params)
CLOSE_CONN("Websocket closing for %s (%s)", conn->m_Url.m_Hostname, WSL_ResultToString(r)); CLOSE_CONN("Websocket closing for %s (%s)", conn->m_Url.m_Hostname, WSL_ResultToString(r));
continue; continue;
} }
r = WSL_WantsExit(conn->m_Ctx);
if (0 != r)
{
CLOSE_CONN("Websocket received close event for %s", conn->m_Url.m_Hostname);
continue;
}
#else #else
int recv_bytes = 0; int recv_bytes = 0;
dmSocket::Result sr = Receive(conn, conn->m_Buffer, conn->m_BufferCapacity-1, &recv_bytes); dmSocket::Result sr = Receive(conn, conn->m_Buffer, conn->m_BufferCapacity-1, &recv_bytes);
@@ -450,9 +533,7 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params)
if (dmSocket::RESULT_OK == sr) if (dmSocket::RESULT_OK == sr)
{ {
conn->m_BufferSize += recv_bytes; PushMessage(conn, MESSAGE_TYPE_NORMAL, recv_bytes, (const uint8_t*)conn->m_Buffer);
conn->m_Buffer[conn->m_BufferCapacity-1] = 0;
conn->m_HasMessage = 1;
} }
else else
{ {
@@ -461,10 +542,31 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params)
} }
#endif #endif
if (conn->m_HasMessage) uint32_t offset = 0;
bool close_received = false;
for (uint32_t i = 0; i < conn->m_Messages.Size(); ++i)
{ {
HandleCallback(conn, EVENT_MESSAGE); const Message& msg = conn->m_Messages[i];
conn->m_HasMessage = 0;
if (EVENT_DISCONNECTED == msg.m_Type)
{
conn->m_Status = RESULT_OK;
CloseConnection(conn);
// Put the message at the front of the buffer
conn->m_Messages.SetSize(0);
conn->m_BufferSize = 0;
PushMessage(conn, MESSAGE_TYPE_CLOSE, msg.m_Length, (const uint8_t*)conn->m_Buffer+offset);
close_received = true;
break;
}
HandleCallback(conn, EVENT_MESSAGE, offset, msg.m_Length);
offset += msg.m_Length;
}
if (!close_received) // saving the close message for next step
{
conn->m_Messages.SetSize(0);
conn->m_BufferSize = 0; conn->m_BufferSize = 0;
} }
} }
@@ -482,6 +584,7 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params)
continue; continue;
} }
// Verifies headers, and also stages any initial sent data
result = VerifyHeaders(conn); result = VerifyHeaders(conn);
if (RESULT_OK != result) if (RESULT_OK != result)
{ {
@@ -505,11 +608,8 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params)
#endif #endif
dmSocket::SetBlocking(conn->m_Socket, false); dmSocket::SetBlocking(conn->m_Socket, false);
conn->m_Buffer[0] = 0;
conn->m_BufferSize = 0;
SetState(conn, STATE_CONNECTED); SetState(conn, STATE_CONNECTED);
HandleCallback(conn, EVENT_CONNECTED); HandleCallback(conn, EVENT_CONNECTED, 0, 0);
} }
else if (STATE_HANDSHAKE_WRITE == conn->m_State) else if (STATE_HANDSHAKE_WRITE == conn->m_State)
{ {
@@ -580,6 +680,6 @@ static dmExtension::Result WebsocketOnUpdate(dmExtension::Params* params)
} // dmWebsocket } // dmWebsocket
DM_DECLARE_EXTENSION(Websocket, LIB_NAME, dmWebsocket::WebsocketAppInitialize, dmWebsocket::WebsocketAppFinalize, dmWebsocket::WebsocketInitialize, dmWebsocket::WebsocketOnUpdate, 0, dmWebsocket::WebsocketFinalize) DM_DECLARE_EXTENSION(Websocket, LIB_NAME, dmWebsocket::AppInitialize, dmWebsocket::AppFinalize, dmWebsocket::Initialize, dmWebsocket::OnUpdate, 0, dmWebsocket::Finalize)
#undef CLOSE_CONN #undef CLOSE_CONN

View File

@@ -19,6 +19,7 @@
#include <dmsdk/dlib/socket.h> #include <dmsdk/dlib/socket.h>
#include <dmsdk/dlib/dns.h> #include <dmsdk/dlib/dns.h>
#include <dmsdk/dlib/uri.h> #include <dmsdk/dlib/uri.h>
#include <dmsdk/dlib/array.h>
namespace dmCrypt namespace dmCrypt
{ {
@@ -59,6 +60,18 @@ namespace dmWebsocket
EVENT_ERROR, EVENT_ERROR,
}; };
enum MessageType
{
MESSAGE_TYPE_NORMAL = 0,
MESSAGE_TYPE_CLOSE = 1,
};
struct Message
{
uint32_t m_Length:30;
uint32_t m_Type:2;
};
struct WebsocketConnection struct WebsocketConnection
{ {
dmScript::LuaCallbackInfo* m_Callback; dmScript::LuaCallbackInfo* m_Callback;
@@ -69,14 +82,16 @@ namespace dmWebsocket
dmConnectionPool::HConnection m_Connection; dmConnectionPool::HConnection m_Connection;
dmSocket::Socket m_Socket; dmSocket::Socket m_Socket;
dmSSLSocket::Socket m_SSLSocket; dmSSLSocket::Socket m_SSLSocket;
dmArray<Message> m_Messages; // lengths of the messages in the data buffer
uint8_t m_Key[16]; uint8_t m_Key[16];
State m_State; State m_State;
uint32_t m_SSL:1;
uint32_t m_HasMessage:1;
char* m_Buffer; char* m_Buffer;
int m_BufferSize; int m_BufferSize;
uint32_t m_BufferCapacity; uint32_t m_BufferCapacity;
Result m_Status; Result m_Status;
uint8_t m_SSL:1;
uint8_t m_HasHandshakeData:1;
uint8_t :6;
}; };
// Set error message // Set error message
@@ -96,13 +111,15 @@ namespace dmWebsocket
Result ReceiveHeaders(WebsocketConnection* conn); Result ReceiveHeaders(WebsocketConnection* conn);
Result VerifyHeaders(WebsocketConnection* conn); Result VerifyHeaders(WebsocketConnection* conn);
// Messages
Result PushMessage(WebsocketConnection* conn, MessageType type, int length, const uint8_t* msg);
#if defined(HAVE_WSLAY) #if defined(HAVE_WSLAY)
// Wslay callbacks // Wslay callbacks
int WSL_Init(wslay_event_context_ptr* ctx, ssize_t buffer_size, void* userctx); int WSL_Init(wslay_event_context_ptr* ctx, ssize_t buffer_size, void* userctx);
void WSL_Exit(wslay_event_context_ptr ctx); void WSL_Exit(wslay_event_context_ptr ctx);
int WSL_Close(wslay_event_context_ptr ctx); int WSL_Close(wslay_event_context_ptr ctx);
int WSL_Poll(wslay_event_context_ptr ctx); int WSL_Poll(wslay_event_context_ptr ctx);
int WSL_WantsExit(wslay_event_context_ptr ctx);
ssize_t WSL_RecvCallback(wslay_event_context_ptr ctx, uint8_t *buf, size_t len, int flags, void *user_data); ssize_t WSL_RecvCallback(wslay_event_context_ptr ctx, uint8_t *buf, size_t len, int flags, void *user_data);
ssize_t WSL_SendCallback(wslay_event_context_ptr ctx, const uint8_t *data, size_t len, int flags, void *user_data); ssize_t WSL_SendCallback(wslay_event_context_ptr ctx, const uint8_t *data, size_t len, int flags, void *user_data);
void WSL_OnMsgRecvCallback(wslay_event_context_ptr ctx, const struct wslay_event_on_msg_recv_arg *arg, void *user_data); void WSL_OnMsgRecvCallback(wslay_event_context_ptr ctx, const struct wslay_event_on_msg_recv_arg *arg, void *user_data);
@@ -114,6 +131,15 @@ namespace dmWebsocket
typedef struct { uint64_t state; uint64_t inc; } pcg32_random_t; typedef struct { uint64_t state; uint64_t inc; } pcg32_random_t;
void pcg32_srandom_r(pcg32_random_t* rng, uint64_t initstate, uint64_t initseq); void pcg32_srandom_r(pcg32_random_t* rng, uint64_t initstate, uint64_t initseq);
uint32_t pcg32_random_r(pcg32_random_t* rng); uint32_t pcg32_random_r(pcg32_random_t* rng);
// If level <= dmWebSocket::g_DebugWebSocket, then it outputs the message
#ifdef __GNUC__
void DebugLog(int level, const char* fmt, ...) __attribute__ ((format (printf, 2, 3)));
#else
void DebugLog(int level, const char* fmt, ...);
#endif
void DebugPrint(int level, const char* msg, const void* _bytes, uint32_t num_bytes);
} }

View File

@@ -54,8 +54,8 @@ void WSL_Exit(wslay_event_context_ptr ctx)
int WSL_Close(wslay_event_context_ptr ctx) int WSL_Close(wslay_event_context_ptr ctx)
{ {
const char* reason = "Client wants to close"; const char* reason = "";
wslay_event_queue_close(ctx, 0, (const uint8_t*)reason, strlen(reason)); wslay_event_queue_close(ctx, WSLAY_CODE_NORMAL_CLOSURE, (const uint8_t*)reason, 0);
return 0; return 0;
} }
@@ -68,24 +68,25 @@ int WSL_Poll(wslay_event_context_ptr ctx)
return r; return r;
} }
int WSL_WantsExit(wslay_event_context_ptr ctx)
{
if ((wslay_event_get_close_sent(ctx) && wslay_event_get_close_received(ctx))) {
return 1;
}
return 0;
}
ssize_t WSL_RecvCallback(wslay_event_context_ptr ctx, uint8_t *buf, size_t len, int flags, void *user_data) ssize_t WSL_RecvCallback(wslay_event_context_ptr ctx, uint8_t *buf, size_t len, int flags, void *user_data)
{ {
WebsocketConnection* conn = (WebsocketConnection*)user_data; WebsocketConnection* conn = (WebsocketConnection*)user_data;
int r = -1; // received bytes if >=0, error if < 0 int r = -1; // received bytes if >=0, error if < 0
if (conn->m_HasHandshakeData)
{
r = conn->m_BufferSize;
memcpy(buf, conn->m_Buffer, r);
conn->m_BufferSize = 0;
conn->m_HasHandshakeData = 0;
return r;
}
dmSocket::Result socket_result = Receive(conn, buf, len, &r); dmSocket::Result socket_result = Receive(conn, buf, len, &r);
if (dmSocket::RESULT_OK == socket_result && r == 0) if (dmSocket::RESULT_OK == socket_result && r == 0)
socket_result = dmSocket::RESULT_WOULDBLOCK; socket_result = dmSocket::RESULT_CONNABORTED;
if (dmSocket::RESULT_OK != socket_result) if (dmSocket::RESULT_OK != socket_result)
{ {
@@ -117,20 +118,35 @@ ssize_t WSL_SendCallback(wslay_event_context_ptr ctx, const uint8_t *data, size_
return (ssize_t)sent_bytes; return (ssize_t)sent_bytes;
} }
// Might be called multiple times for a connection receiving multiple events
void WSL_OnMsgRecvCallback(wslay_event_context_ptr ctx, const struct wslay_event_on_msg_recv_arg *arg, void *user_data) void WSL_OnMsgRecvCallback(wslay_event_context_ptr ctx, const struct wslay_event_on_msg_recv_arg *arg, void *user_data)
{ {
WebsocketConnection* conn = (WebsocketConnection*)user_data; WebsocketConnection* conn = (WebsocketConnection*)user_data;
if (arg->opcode == WSLAY_TEXT_FRAME || arg->opcode == WSLAY_BINARY_FRAME) if (arg->opcode == WSLAY_TEXT_FRAME || arg->opcode == WSLAY_BINARY_FRAME)
{ {
if (arg->msg_length >= conn->m_BufferCapacity) PushMessage(conn, MESSAGE_TYPE_NORMAL, arg->msg_length, arg->msg);
conn->m_Buffer = (char*)realloc(conn->m_Buffer, arg->msg_length + 1);
memcpy(conn->m_Buffer, arg->msg, arg->msg_length);
conn->m_BufferSize = arg->msg_length;
conn->m_HasMessage = 1;
} else if (arg->opcode == WSLAY_CONNECTION_CLOSE) } else if (arg->opcode == WSLAY_CONNECTION_CLOSE)
{ {
// TODO: Store the reason // The first two bytes is the close code
const uint8_t* reason = (const uint8_t*)"";
size_t len = arg->msg_length;
if (arg->msg_length > 2)
{
reason = arg->msg + 2;
len -= 2;
}
char buffer[1024];
len = dmSnPrintf(buffer, sizeof(buffer), "Server closing (%u). Reason: '%s'", wslay_event_get_status_code_received(ctx), reason);
PushMessage(conn, MESSAGE_TYPE_CLOSE, len, (const uint8_t*)buffer);
if (!wslay_event_get_close_sent(ctx))
{
wslay_event_queue_close(ctx, arg->status_code, (const uint8_t*)buffer, len);
}
DebugLog(1, "%s", buffer);
} }
} }