diff --git a/include/mgba/script/types.h b/include/mgba/script/types.h index fd7dbe5d8..489b5eee4 100644 --- a/include/mgba/script/types.h +++ b/include/mgba/script/types.h @@ -108,6 +108,7 @@ CXX_GUARD_START #define mSCRIPT_TYPE_CMP_F64(TYPE) mSCRIPT_TYPE_CMP_GENERIC(mSCRIPT_TYPE_MS_F64, TYPE) #define mSCRIPT_TYPE_CMP_STR(TYPE) mSCRIPT_TYPE_CMP_GENERIC(mSCRIPT_TYPE_MS_STR, TYPE) #define mSCRIPT_TYPE_CMP_CHARP(TYPE) mSCRIPT_TYPE_CMP_GENERIC(mSCRIPT_TYPE_MS_CHARP, TYPE) +#define mSCRIPT_TYPE_CMP_LIST(TYPE) mSCRIPT_TYPE_CMP_GENERIC(mSCRIPT_TYPE_MS_LIST, TYPE) #define mSCRIPT_TYPE_CMP_PTR(TYPE) ((TYPE)->base >= mSCRIPT_TYPE_OPAQUE) #define mSCRIPT_TYPE_CMP_WRAPPER(TYPE) (true) #define mSCRIPT_TYPE_CMP_S(STRUCT) mSCRIPT_TYPE_MS_S(STRUCT)->name == _mSCRIPT_FIELD_NAME diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index 5c8906531..db1178993 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -30,7 +30,7 @@ static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptList*, static bool _luaPopFrame(struct mScriptEngineContextLua*, struct mScriptList*); static bool _luaInvoke(struct mScriptEngineContextLua*, struct mScriptFrame*); -static struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext); +static struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool pop); static bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue*); static void _luaDeref(struct mScriptValue*); @@ -186,7 +186,7 @@ bool _luaIsScript(struct mScriptEngineContext* ctx, const char* name, struct VFi struct mScriptValue* _luaGetGlobal(struct mScriptEngineContext* ctx, const char* name) { struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) ctx; lua_getglobal(luaContext->lua, name); - return _luaCoerce(luaContext); + return _luaCoerce(luaContext, true); } bool _luaSetGlobal(struct mScriptEngineContext* ctx, const char* name, struct mScriptValue* value) { @@ -212,7 +212,83 @@ struct mScriptValue* _luaCoerceFunction(struct mScriptEngineContextLua* luaConte return value; } -struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext) { +struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext) { + struct mScriptValue* table = mScriptValueAlloc(mSCRIPT_TYPE_MS_TABLE); + bool isList = true; + + lua_pushnil(luaContext->lua); + while (lua_next(luaContext->lua, -2) != 0) { + struct mScriptValue* value = NULL; + int type = lua_type(luaContext->lua, -1); + switch (type) { + case LUA_TNUMBER: + case LUA_TBOOLEAN: + case LUA_TSTRING: + case LUA_TFUNCTION: + value = _luaCoerce(luaContext, true); + break; + default: + // Don't let values be something that could contain themselves + break; + } + if (!value) { + lua_pop(luaContext->lua, 3); + mScriptValueDeref(table); + return false; + } + + struct mScriptValue* key = NULL; + type = lua_type(luaContext->lua, -1); + switch (type) { + case LUA_TBOOLEAN: + case LUA_TSTRING: + isList = false; + // Fall through + case LUA_TNUMBER: + key = _luaCoerce(luaContext, false); + break; + default: + // Limit keys to hashable types + break; + } + + if (!key) { + lua_pop(luaContext->lua, 2); + mScriptValueDeref(table); + return false; + } + mScriptTableInsert(table, key, value); + mScriptValueDeref(key); + mScriptValueDeref(value); + } + lua_pop(luaContext->lua, 1); + + size_t len = mScriptTableSize(table); + if (!isList || !len) { + return table; + } + + struct mScriptValue* list = mScriptValueAlloc(mSCRIPT_TYPE_MS_LIST); + size_t i; + for (i = 1; i <= len; ++i) { + struct mScriptValue* value = mScriptTableLookup(table, &mSCRIPT_MAKE_S64(i)); + if (!value) { + mScriptValueDeref(list); + return table; + } + mScriptValueWrap(value, mScriptListAppend(list->value.list)); + } + if (i != len + 1) { + mScriptValueDeref(list); + mScriptContextFillPool(luaContext->d.context, table); + return table; + } + mScriptValueDeref(table); + mScriptContextFillPool(luaContext->d.context, list); + return list; +} + +struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool pop) { if (lua_isnone(luaContext->lua, -1)) { lua_pop(luaContext->lua, 1); return NULL; @@ -244,7 +320,16 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext) { break; case LUA_TFUNCTION: // This function pops the value internally via luaL_ref + if (!pop) { + break; + } return _luaCoerceFunction(luaContext); + case LUA_TTABLE: + // This function pops the value internally via luaL_ref + if (!pop) { + break; + } + return _luaCoerceTable(luaContext); case LUA_TUSERDATA: if (!lua_getmetatable(luaContext->lua, -1)) { break; @@ -259,7 +344,9 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext) { value = mScriptContextAccessWeakref(luaContext->d.context, value); break; } - lua_pop(luaContext->lua, 1); + if (pop) { + lua_pop(luaContext->lua, 1); + } return value; } @@ -467,7 +554,7 @@ bool _luaPopFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList if (frame) { int i; for (i = 0; i < count; ++i) { - struct mScriptValue* value = _luaCoerce(luaContext); + struct mScriptValue* value = _luaCoerce(luaContext, true); if (!value) { ok = false; break; @@ -640,7 +727,7 @@ int _luaSetObject(lua_State* lua) { char key[MAX_KEY_SIZE]; const char* keyPtr = lua_tostring(lua, -2); struct mScriptValue* obj = lua_touserdata(lua, -3); - struct mScriptValue* val = _luaCoerce(luaContext); + struct mScriptValue* val = _luaCoerce(luaContext, true); if (!keyPtr) { lua_pop(lua, 2); diff --git a/src/script/test/lua.c b/src/script/test/lua.c index 71b719966..bf3bf099a 100644 --- a/src/script/test/lua.c +++ b/src/script/test/lua.c @@ -61,8 +61,22 @@ static void testV1(struct Test* a, int b) { a->i += b; } +static int32_t sum(struct mScriptList* list) { + int32_t sum = 0; + size_t i; + for (i = 0; i < mScriptListSize(list); ++i) { + struct mScriptValue value; + if (!mScriptCast(mSCRIPT_TYPE_MS_S32, mScriptListGetPointer(list, i), &value)) { + continue; + } + sum += value.value.s32; + } + return sum; +} + mSCRIPT_BIND_FUNCTION(boundIdentityInt, S32, identityInt, 1, S32, a); mSCRIPT_BIND_FUNCTION(boundAddInts, S32, addInts, 2, S32, a, S32, b); +mSCRIPT_BIND_FUNCTION(boundSum, S32, sum, 1, LIST, list); mSCRIPT_DECLARE_STRUCT(Test); mSCRIPT_DECLARE_STRUCT_D_METHOD(Test, S32, ifn0, 0); @@ -619,6 +633,24 @@ M_TEST_DEFINE(tableIterate) { mScriptContextDeinit(&context); } +M_TEST_DEFINE(callList) { + SETUP_LUA; + + struct mScriptValue a = mSCRIPT_MAKE_S32(6); + struct mScriptValue* val; + + assert_true(lua->setGlobal(lua, "sum", &boundSum)); + TEST_PROGRAM("a = sum({1, 2, 3})"); + assert_null(lua->getError(lua)); + + val = lua->getGlobal(lua, "a"); + assert_non_null(val); + assert_true(mSCRIPT_TYPE_MS_S32->equal(&a, val)); + mScriptValueDeref(val); + + mScriptContextDeinit(&context); +} + M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(create), cmocka_unit_test(loadGood), @@ -634,4 +666,5 @@ M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(errorReporting), cmocka_unit_test(tableLookup), cmocka_unit_test(tableIterate), + cmocka_unit_test(callList), ) diff --git a/src/script/test/types.c b/src/script/test/types.c index 45a9a2d25..9fa83a8a5 100644 --- a/src/script/test/types.c +++ b/src/script/test/types.c @@ -52,6 +52,30 @@ static int isHello(const char* str) { return strcmp(str, "hello") == 0; } +static int isSequential(struct mScriptList* list) { + int last; + if (mScriptListSize(list) == 0) { + return true; + } + size_t i; + for (i = 0; i < mScriptListSize(list); ++i) { + struct mScriptValue* value = mScriptListGetPointer(list, i); + struct mScriptValue intValue; + if (!mScriptCast(mSCRIPT_TYPE_MS_S32, value, &intValue)) { + return false; + } + if (!i) { + last = intValue.value.s32; + } else { + if (intValue.value.s32 != last + 1) { + return false; + } + ++last; + } + } + return true; +} + mSCRIPT_BIND_FUNCTION(boundVoidOne, S32, voidOne, 0); mSCRIPT_BIND_VOID_FUNCTION(boundDiscard, discard, 1, S32, ignored); mSCRIPT_BIND_FUNCTION(boundIdentityInt, S32, identityInt, 1, S32, in); @@ -61,6 +85,7 @@ mSCRIPT_BIND_FUNCTION(boundIdentityStruct, S(Test), identityStruct, 1, S(Test), mSCRIPT_BIND_FUNCTION(boundAddInts, S32, addInts, 2, S32, a, S32, b); mSCRIPT_BIND_FUNCTION(boundSubInts, S32, subInts, 2, S32, a, S32, b); mSCRIPT_BIND_FUNCTION(boundIsHello, S32, isHello, 1, CHARP, str); +mSCRIPT_BIND_FUNCTION(boundIsSequential, S32, isSequential, 1, LIST, list); M_TEST_DEFINE(voidArgs) { struct mScriptFrame frame; @@ -919,6 +944,47 @@ M_TEST_DEFINE(stringIsNotHello) { mScriptFrameDeinit(&frame); } +M_TEST_DEFINE(invokeList) { + struct mScriptFrame frame; + struct mScriptList list; + int val; + + mScriptListInit(&list, 0); + + mScriptFrameInit(&frame); + mSCRIPT_PUSH(&frame.arguments, LIST, &list); + assert_true(mScriptInvoke(&boundIsSequential, &frame)); + assert_true(mScriptPopS32(&frame.returnValues, &val)); + assert_int_equal(val, 1); + mScriptFrameDeinit(&frame); + + *mScriptListAppend(&list) = mSCRIPT_MAKE_S32(1); + mScriptFrameInit(&frame); + mSCRIPT_PUSH(&frame.arguments, LIST, &list); + assert_true(mScriptInvoke(&boundIsSequential, &frame)); + assert_true(mScriptPopS32(&frame.returnValues, &val)); + assert_int_equal(val, 1); + mScriptFrameDeinit(&frame); + + *mScriptListAppend(&list) = mSCRIPT_MAKE_S32(2); + mScriptFrameInit(&frame); + mSCRIPT_PUSH(&frame.arguments, LIST, &list); + assert_true(mScriptInvoke(&boundIsSequential, &frame)); + assert_true(mScriptPopS32(&frame.returnValues, &val)); + assert_int_equal(val, 1); + mScriptFrameDeinit(&frame); + + *mScriptListAppend(&list) = mSCRIPT_MAKE_S32(4); + mScriptFrameInit(&frame); + mSCRIPT_PUSH(&frame.arguments, LIST, &list); + assert_true(mScriptInvoke(&boundIsSequential, &frame)); + assert_true(mScriptPopS32(&frame.returnValues, &val)); + assert_int_equal(val, 0); + mScriptFrameDeinit(&frame); + + mScriptListDeinit(&list); +} + M_TEST_SUITE_DEFINE(mScript, cmocka_unit_test(voidArgs), cmocka_unit_test(voidFunc), @@ -948,4 +1014,6 @@ M_TEST_SUITE_DEFINE(mScript, cmocka_unit_test(hashTableBasic), cmocka_unit_test(hashTableString), cmocka_unit_test(stringIsHello), - cmocka_unit_test(stringIsNotHello)) + cmocka_unit_test(stringIsNotHello), + cmocka_unit_test(invokeList), +)