diff --git a/include/mgba/script/context.h b/include/mgba/script/context.h index 80682ee29..753c356ba 100644 --- a/include/mgba/script/context.h +++ b/include/mgba/script/context.h @@ -29,6 +29,8 @@ struct mScriptContext { struct Table weakrefs; uint32_t nextWeakref; struct Table callbacks; + struct Table callbackId; + uint32_t nextCallbackId; struct mScriptValue* constants; struct Table docstrings; }; @@ -85,7 +87,8 @@ void mScriptContextExportConstants(struct mScriptContext* context, const char* n void mScriptContextExportNamespace(struct mScriptContext* context, const char* nspace, struct mScriptKVPair* value); void mScriptContextTriggerCallback(struct mScriptContext*, const char* callback); -void mScriptContextAddCallback(struct mScriptContext*, const char* callback, struct mScriptValue* value); +uint32_t mScriptContextAddCallback(struct mScriptContext*, const char* callback, struct mScriptValue* value); +void mScriptContextRemoveCallback(struct mScriptContext*, uint32_t cbid); void mScriptContextSetDocstring(struct mScriptContext*, const char* key, const char* docstring); const char* mScriptContextGetDocstring(struct mScriptContext*, const char* key); diff --git a/src/script/context.c b/src/script/context.c index 7c9b32dab..ad4e445fc 100644 --- a/src/script/context.c +++ b/src/script/context.c @@ -14,6 +14,11 @@ struct mScriptFileInfo { struct mScriptEngineContext* context; }; +struct mScriptCallbackInfo { + const char* callback; + size_t id; +}; + static void _engineContextDestroy(void* ctx) { struct mScriptEngineContext* context = ctx; context->destroy(context); @@ -56,6 +61,8 @@ void mScriptContextInit(struct mScriptContext* context) { TableInit(&context->weakrefs, 0, (void (*)(void*)) mScriptValueDeref); context->nextWeakref = 1; HashTableInit(&context->callbacks, 0, (void (*)(void*)) mScriptValueDeref); + TableInit(&context->callbackId, 0, free); + context->nextCallbackId = 1; context->constants = NULL; HashTableInit(&context->docstrings, 0, NULL); } @@ -66,6 +73,7 @@ void mScriptContextDeinit(struct mScriptContext* context) { mScriptContextDrainPool(context); mScriptListDeinit(&context->refPool); HashTableDeinit(&context->callbacks); + TableDeinit(&context->callbackId); HashTableDeinit(&context->engines); HashTableDeinit(&context->docstrings); } @@ -199,8 +207,11 @@ void mScriptContextTriggerCallback(struct mScriptContext* context, const char* c size_t i; for (i = 0; i < mScriptListSize(list->value.list); ++i) { struct mScriptFrame frame; - mScriptFrameInit(&frame); struct mScriptValue* fn = mScriptListGetPointer(list->value.list, i); + if (!fn->type) { + continue; + } + mScriptFrameInit(&frame); if (fn->type->base == mSCRIPT_TYPE_WRAPPER) { fn = mScriptValueUnwrap(fn); } @@ -209,16 +220,48 @@ void mScriptContextTriggerCallback(struct mScriptContext* context, const char* c } } -void mScriptContextAddCallback(struct mScriptContext* context, const char* callback, struct mScriptValue* fn) { +uint32_t mScriptContextAddCallback(struct mScriptContext* context, const char* callback, struct mScriptValue* fn) { if (fn->type->base != mSCRIPT_TYPE_FUNCTION) { - return; + return 0; } struct mScriptValue* list = HashTableLookup(&context->callbacks, callback); if (!list) { list = mScriptValueAlloc(mSCRIPT_TYPE_MS_LIST); HashTableInsert(&context->callbacks, callback, list); } + struct mScriptCallbackInfo* info = malloc(sizeof(*info)); + // Steal the string from the table key, since it's guaranteed to outlive this struct + struct TableIterator iter; + HashTableIteratorLookup(&context->callbacks, &iter, callback); + info->callback = HashTableIteratorGetKey(&context->callbacks, &iter); + info->id = mScriptListSize(list->value.list); mScriptValueWrap(fn, mScriptListAppend(list->value.list)); + while (true) { + uint32_t id = context->nextCallbackId; + ++context->nextCallbackId; + if (TableLookup(&context->callbackId, id)) { + continue; + } + TableInsert(&context->callbackId, id, info); + return id; + } +} + +void mScriptContextRemoveCallback(struct mScriptContext* context, uint32_t cbid) { + struct mScriptCallbackInfo* info = TableLookup(&context->callbackId, cbid); + if (!info) { + return; + } + struct mScriptValue* list = HashTableLookup(&context->callbacks, info->callback); + if (!list) { + return; + } + if (info->id >= mScriptListSize(list->value.list)) { + return; + } + struct mScriptValue* fn = mScriptValueUnwrap(mScriptListGetPointer(list->value.list, info->id)); + mScriptValueDeref(fn); + mScriptListGetPointer(list->value.list, info->id)->type = NULL; } void mScriptContextExportConstants(struct mScriptContext* context, const char* nspace, struct mScriptKVPair* constants) { diff --git a/src/script/stdlib.c b/src/script/stdlib.c index 695946d05..839673cdb 100644 --- a/src/script/stdlib.c +++ b/src/script/stdlib.c @@ -19,16 +19,22 @@ struct mScriptCallbackManager { struct mScriptContext* context; }; -static void _mScriptCallbackAdd(struct mScriptCallbackManager* adapter, struct mScriptString* name, struct mScriptValue* fn) { +static uint32_t _mScriptCallbackAdd(struct mScriptCallbackManager* adapter, struct mScriptString* name, struct mScriptValue* fn) { if (fn->type->base == mSCRIPT_TYPE_WRAPPER) { fn = mScriptValueUnwrap(fn); } - mScriptContextAddCallback(adapter->context, name->buffer, fn); + uint32_t id = mScriptContextAddCallback(adapter->context, name->buffer, fn); mScriptValueDeref(fn); + return id; +} + +static void _mScriptCallbackRemove(struct mScriptCallbackManager* adapter, uint32_t id) { + mScriptContextRemoveCallback(adapter->context, id); } mSCRIPT_DECLARE_STRUCT(mScriptCallbackManager); -mSCRIPT_DECLARE_STRUCT_VOID_METHOD(mScriptCallbackManager, add, _mScriptCallbackAdd, 2, STR, callback, WRAPPER, function); +mSCRIPT_DECLARE_STRUCT_METHOD(mScriptCallbackManager, U32, add, _mScriptCallbackAdd, 2, STR, callback, WRAPPER, function); +mSCRIPT_DECLARE_STRUCT_VOID_METHOD(mScriptCallbackManager, remove, _mScriptCallbackRemove, 1, U32, cbid); static uint64_t mScriptMakeBitmask(struct mScriptList* list) { size_t i; @@ -76,8 +82,10 @@ mSCRIPT_DEFINE_STRUCT(mScriptCallbackManager) "- **start**: The emulation has started\n" "- **stop**: The emulation has voluntarily shut down\n" ) - mSCRIPT_DEFINE_DOCSTRING("Add a callback of the named type") + mSCRIPT_DEFINE_DOCSTRING("Add a callback of the named type. The returned id can be used to remove it later") mSCRIPT_DEFINE_STRUCT_METHOD(mScriptCallbackManager, add) + mSCRIPT_DEFINE_DOCSTRING("Remove a callback with the previously retuned id") + mSCRIPT_DEFINE_STRUCT_METHOD(mScriptCallbackManager, remove) mSCRIPT_DEFINE_END; void mScriptContextAttachStdlib(struct mScriptContext* context) { diff --git a/src/script/test/stdlib.c b/src/script/test/stdlib.c index c6e9fd3fe..d4c57605b 100644 --- a/src/script/test/stdlib.c +++ b/src/script/test/stdlib.c @@ -7,6 +7,7 @@ #include #include +#include #include #define SETUP_LUA \ @@ -85,7 +86,36 @@ M_TEST_DEFINE(bitUnmask) { mScriptContextDeinit(&context); } +M_TEST_DEFINE(callbacks) { + SETUP_LUA; + + TEST_PROGRAM( + "val = 0\n" + "function cb()\n" + " val = val + 1\n" + "end\n" + "id = callbacks:add('test', cb)\n" + "assert(id)" + ); + + TEST_VALUE(S32, "val", 0); + + mScriptContextTriggerCallback(&context, "test"); + TEST_VALUE(S32, "val", 1); + + mScriptContextTriggerCallback(&context, "test"); + TEST_VALUE(S32, "val", 2); + + TEST_PROGRAM("callbacks:remove(id)"); + + mScriptContextTriggerCallback(&context, "test"); + TEST_VALUE(S32, "val", 2); + + mScriptContextDeinit(&context); +} + M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptStdlib, cmocka_unit_test(bitMask), cmocka_unit_test(bitUnmask), + cmocka_unit_test(callbacks), ) diff --git a/src/script/types.c b/src/script/types.c index 61fcadcd2..52d0425b6 100644 --- a/src/script/types.c +++ b/src/script/types.c @@ -249,6 +249,9 @@ void _allocList(struct mScriptValue* val) { void _freeList(struct mScriptValue* val) { size_t i; for (i = 0; i < mScriptListSize(val->value.list); ++i) { + if (val->type) { + continue; + } struct mScriptValue* unwrapped = mScriptValueUnwrap(mScriptListGetPointer(val->value.list, i)); if (unwrapped) { mScriptValueDeref(unwrapped);