From 5164b888d895f2f91bb40290f6b2ef0e35b89571 Mon Sep 17 00:00:00 2001 From: Vicki Pfau Date: Tue, 31 Jan 2023 17:22:29 -0800 Subject: [PATCH] Scripting: Allow Lua to pass nested tables to the scripting subsystem --- include/mgba/script/types.h | 3 ++- src/script/engines/lua.c | 25 +++++++++++++++---- src/script/test/lua.c | 50 ++++++++++++++++++++++++++++++++++++- 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/include/mgba/script/types.h b/include/mgba/script/types.h index ed8ae02e3..0be1c41e3 100644 --- a/include/mgba/script/types.h +++ b/include/mgba/script/types.h @@ -35,7 +35,7 @@ CXX_GUARD_START #define mSCRIPT_TYPE_C_PTR void* #define mSCRIPT_TYPE_C_CPTR const void* #define mSCRIPT_TYPE_C_LIST struct mScriptList* -#define mSCRIPT_TYPE_C_TABLE Table* +#define mSCRIPT_TYPE_C_TABLE struct Table* #define mSCRIPT_TYPE_C_WRAPPER struct mScriptValue* #define mSCRIPT_TYPE_C_WEAKREF uint32_t #define mSCRIPT_TYPE_C_S(STRUCT) struct STRUCT* @@ -120,6 +120,7 @@ CXX_GUARD_START #define mSCRIPT_TYPE_CMP_STR(TYPE) mSCRIPT_TYPE_CMP_GENERIC(mSCRIPT_TYPE_MS_STR, TYPE) #define mSCRIPT_TYPE_CMP_CHARP(TYPE) mSCRIPT_TYPE_CMP_GENERIC(mSCRIPT_TYPE_MS_CHARP, TYPE) #define mSCRIPT_TYPE_CMP_LIST(TYPE) mSCRIPT_TYPE_CMP_GENERIC(mSCRIPT_TYPE_MS_LIST, TYPE) +#define mSCRIPT_TYPE_CMP_TABLE(TYPE) mSCRIPT_TYPE_CMP_GENERIC(mSCRIPT_TYPE_MS_TABLE, TYPE) #define mSCRIPT_TYPE_CMP_PTR(TYPE) ((TYPE)->base >= mSCRIPT_TYPE_OPAQUE) #define mSCRIPT_TYPE_CMP_WRAPPER(TYPE) (true) #define mSCRIPT_TYPE_CMP_S(STRUCT) mSCRIPT_TYPE_MS_S(STRUCT)->name == _mSCRIPT_FIELD_NAME diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index bac841f1b..95ac70eba 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -538,11 +538,13 @@ struct mScriptValue* _luaCoerceFunction(struct mScriptEngineContextLua* luaConte return value; } -struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext) { +struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext, struct Table* markedObjects) { struct mScriptValue* table = mScriptValueAlloc(mSCRIPT_TYPE_MS_TABLE); bool isList = true; lua_pushnil(luaContext->lua); + + void* tablePointer; while (lua_next(luaContext->lua, -2) != 0) { struct mScriptValue* value = NULL; int type = lua_type(luaContext->lua, -1); @@ -553,14 +555,23 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext) case LUA_TFUNCTION: value = _luaCoerce(luaContext, true); break; + case LUA_TTABLE: + tablePointer = lua_topointer(luaContext->lua, -1); + // Ensure this table doesn't contain any cycles + if (!HashTableLookupBinary(markedObjects, &tablePointer, sizeof(tablePointer))) { + HashTableInsertBinary(markedObjects, &tablePointer, sizeof(tablePointer), tablePointer); + value = _luaCoerceTable(luaContext, markedObjects); + if (value) { + mScriptValueRef(value); + } + } default: - // Don't let values be something that could contain themselves break; } if (!value) { - lua_pop(luaContext->lua, 3); + lua_pop(luaContext->lua, type == LUA_TTABLE ? 2 : 3); mScriptValueDeref(table); - return false; + return NULL; } struct mScriptValue* key = NULL; @@ -628,6 +639,7 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool size_t size; const void* buffer; + struct Table markedObjects; struct mScriptValue* value = NULL; switch (lua_type(luaContext->lua, -1)) { case LUA_TNIL: @@ -664,7 +676,10 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool if (!pop) { break; } - return _luaCoerceTable(luaContext); + HashTableInit(&markedObjects, 0, NULL); + value = _luaCoerceTable(luaContext, &markedObjects); + HashTableDeinit(&markedObjects); + return value; case LUA_TUSERDATA: if (!lua_getmetatable(luaContext->lua, -1)) { break; diff --git a/src/script/test/lua.c b/src/script/test/lua.c index 5731b0705..eab8550dc 100644 --- a/src/script/test/lua.c +++ b/src/script/test/lua.c @@ -66,9 +66,14 @@ static int32_t sum(struct mScriptList* list) { return sum; } +static unsigned tableSize(struct Table* table) { + return TableSize(table); +} + mSCRIPT_BIND_FUNCTION(boundIdentityInt, S32, identityInt, 1, S32, a); mSCRIPT_BIND_FUNCTION(boundAddInts, S32, addInts, 2, S32, a, S32, b); mSCRIPT_BIND_FUNCTION(boundSum, S32, sum, 1, LIST, list); +mSCRIPT_BIND_FUNCTION(boundTableSize, U32, tableSize, 1, TABLE, table); mSCRIPT_DECLARE_STRUCT(Test); mSCRIPT_DECLARE_STRUCT_D_METHOD(Test, S32, ifn0, 0); @@ -371,8 +376,50 @@ M_TEST_DEFINE(callCFunc) { assert_true(a.type->equal(&a, val)); mScriptValueDeref(val); + LOAD_PROGRAM("b('a')"); + assert_false(lua->run(lua)); + mScriptContextDeinit(&context); } + +M_TEST_DEFINE(callCTable) { + SETUP_LUA; + + struct mScriptValue a = mSCRIPT_MAKE_S32(1); + struct mScriptValue* val; + + assert_true(lua->setGlobal(lua, "b", &boundTableSize)); + + TEST_PROGRAM("assert(b({}) == 0)"); + assert_null(lua->getError(lua)); + + TEST_PROGRAM("assert(b({[2]=1}) == 1)"); + assert_null(lua->getError(lua)); + + TEST_PROGRAM("assert(b({a=1}) == 1)"); + assert_null(lua->getError(lua)); + + TEST_PROGRAM("assert(b({a={}}) == 1)"); + assert_null(lua->getError(lua)); + + LOAD_PROGRAM( + "a = {}\n" + "a.b = a\n" + "assert(b(a) == 1)\n" + ); + assert_false(lua->run(lua)); + + LOAD_PROGRAM( + "a = {}\n" + "a.b = {}\n" + "a.b.c = a\n" + "assert(b(a) == 1)\n" + ); + assert_false(lua->run(lua)); + + mScriptContextDeinit(&context); +} + M_TEST_DEFINE(globalNull) { SETUP_LUA; @@ -774,6 +821,7 @@ M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(rootScope), cmocka_unit_test(callLuaFunc), cmocka_unit_test(callCFunc), + cmocka_unit_test(callCTable), cmocka_unit_test(globalNull), cmocka_unit_test(globalStructFieldGet), cmocka_unit_test(globalStructFieldSet), @@ -782,5 +830,5 @@ M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(tableLookup), cmocka_unit_test(tableIterate), cmocka_unit_test(callList), - cmocka_unit_test(linkedList) + cmocka_unit_test(linkedList), )