diff --git a/monarch/utils/callback_tracker.lua b/monarch/utils/callback_tracker.lua index c02ba08..6151340 100644 --- a/monarch/utils/callback_tracker.lua +++ b/monarch/utils/callback_tracker.lua @@ -18,9 +18,15 @@ function M.create() -- @return Callback function function instance.track() callback_count = callback_count + 1 + local done = false return function() + if done then + return false, "The callback has already been invoked once" + end + done = true callback_count = callback_count - 1 invoke_if_done() + return true end end diff --git a/test/test.script b/test/test.script index e9477ed..9fb53ba 100644 --- a/test/test.script +++ b/test/test.script @@ -1,10 +1,12 @@ local deftest = require "deftest.deftest" local test_monarch = require "test.test_monarch" +local test_callback_tracker = require "test.test_callback_tracker" function init(self) deftest.add(test_monarch) + deftest.add(test_callback_tracker) deftest.run({ coverage = { enabled = true }, --pattern = "preload", diff --git a/test/test_callback_tracker.lua b/test/test_callback_tracker.lua new file mode 100644 index 0000000..e31f39b --- /dev/null +++ b/test/test_callback_tracker.lua @@ -0,0 +1,69 @@ +local unload = require "deftest.util.unload" +local cowat = require "test.cowait" +local callback_tracker = require "monarch.utils.callback_tracker" + +return function() + + describe("callback tracker", function() + before(function() + callback_tracker = require "monarch.utils.callback_tracker" + end) + + after(function() + unload.unload("monarch%..*") + end) + + + it("should be able to tell when all callbacks are done", function() + local tracker = callback_tracker.create() + local t1 = tracker.track() + local t2 = tracker.track() + + local done = false + tracker.when_done(function() done = true end) + + assert(not done, "It should not be done yet - No callback has completed") + t1() + assert(not done, "It should not be done yet - Only one callback has completed") + t2() + assert(done, "It should be done now - All callbacks have completed") + end) + + it("should indicate if a tracked callback has been invoked more than once", function() + local tracker = callback_tracker.create() + local t = tracker.track() + local ok, err = t() + assert(ok == true and err == nil, "It should return true when successful") + ok, err = t() + assert(ok == false and err, "It should return false and a message when invoked multiple times") + end) + + it("should not be possible to track the same callback more than one time", function() + local tracker = callback_tracker.create() + local t1 = tracker.track() + local t2 = tracker.track() + + local done = false + tracker.when_done(function() done = true end) + + assert(not done, "It should not be done yet - No callback has completed") + t1() + t1() + assert(not done, "It should not be done yet - Even if one callback has been invoked twice") + t2() + assert(done, "It should be done now - All callbacks have completed") + end) + + it("should handle when callbacks are done before calling when_done()", function() + local tracker = callback_tracker.create() + local t1 = tracker.track() + local t2 = tracker.track() + t1() + t2() + + local done = false + tracker.when_done(function() done = true end) + assert(done, "It should be done now - All callbacks have completed") + end) + end) +end