diff --git a/include/mgba/script/types.h b/include/mgba/script/types.h index 239772483..7a6bc6513 100644 --- a/include/mgba/script/types.h +++ b/include/mgba/script/types.h @@ -284,6 +284,7 @@ bool mScriptTableInsert(struct mScriptValue* table, struct mScriptValue* key, st bool mScriptTableRemove(struct mScriptValue* table, struct mScriptValue* key); struct mScriptValue* mScriptTableLookup(struct mScriptValue* table, struct mScriptValue* key); bool mScriptTableClear(struct mScriptValue* table); +size_t mScriptTableSize(struct mScriptValue* table); bool mScriptTableIteratorStart(struct mScriptValue* table, struct TableIterator*); bool mScriptTableIteratorNext(struct mScriptValue* table, struct TableIterator*); struct mScriptValue* mScriptTableIteratorGetKey(struct mScriptValue* table, struct TableIterator*); diff --git a/src/core/test/scripting.c b/src/core/test/scripting.c index 3d2b8c4fa..308f053d3 100644 --- a/src/core/test/scripting.c +++ b/src/core/test/scripting.c @@ -186,7 +186,9 @@ M_TEST_DEFINE(detach) { LOAD_PROGRAM( "assert(emu)\n" + "assert(emu.memory)\n" "a = emu\n" + "b = emu.memory\n" ); assert_true(lua->run(lua)); @@ -202,6 +204,11 @@ M_TEST_DEFINE(detach) { ); assert_false(lua->run(lua)); + LOAD_PROGRAM( + "assert(memory.cart0)\n" + ); + assert_false(lua->run(lua)); + TEARDOWN_CORE; mScriptContextDeinit(&context); } diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index 8710d14e5..e5de3a05c 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -40,6 +40,7 @@ static int _luaGetObject(lua_State* lua); static int _luaSetObject(lua_State* lua); static int _luaGcObject(lua_State* lua); static int _luaGetTable(lua_State* lua); +static int _luaLenTable(lua_State* lua); static int _luaPairsTable(lua_State* lua); static int _luaGetList(lua_State* lua); static int _luaLenList(lua_State* lua); @@ -103,6 +104,7 @@ static const luaL_Reg _mSTStruct[] = { static const luaL_Reg _mSTTable[] = { { "__index", _luaGetTable }, + { "__len", _luaLenTable }, { "__pairs", _luaPairsTable }, { NULL, NULL } }; @@ -662,14 +664,24 @@ static int _luaGcObject(lua_State* lua) { int _luaGetTable(lua_State* lua) { struct mScriptEngineContextLua* luaContext = _luaGetContext(lua); char key[MAX_KEY_SIZE]; - const char* keyPtr = lua_tostring(lua, -1); - struct mScriptValue* obj = lua_touserdata(lua, -2); - - if (!keyPtr) { + int type = lua_type(luaContext->lua, -1); + const char* keyPtr = NULL; + int64_t intKey; + switch (type) { + case LUA_TNUMBER: + intKey = lua_tointeger(luaContext->lua, -1); + break; + case LUA_TSTRING: + keyPtr = lua_tostring(lua, -1); + break; + default: lua_pop(lua, 2); return 0; } - strlcpy(key, keyPtr, sizeof(key)); + struct mScriptValue* obj = lua_touserdata(lua, -2); + if (keyPtr) { + strlcpy(key, keyPtr, sizeof(key)); + } lua_pop(lua, 2); obj = mScriptContextAccessWeakref(luaContext->d.context, obj); @@ -678,7 +690,15 @@ int _luaGetTable(lua_State* lua) { lua_error(lua); } - struct mScriptValue keyVal = mSCRIPT_MAKE_CHARP(key); + struct mScriptValue keyVal; + switch (type) { + case LUA_TNUMBER: + keyVal = mSCRIPT_MAKE_S64(intKey); + break; + case LUA_TSTRING: + keyVal = mSCRIPT_MAKE_CHARP(key); + break; + } struct mScriptValue* val = mScriptTableLookup(obj, &keyVal); if (!val) { return 0; @@ -691,13 +711,41 @@ int _luaGetTable(lua_State* lua) { return 1; } +int _luaLenTable(lua_State* lua) { + struct mScriptEngineContextLua* luaContext = _luaGetContext(lua); + struct mScriptValue* obj = lua_touserdata(lua, -1); + lua_pop(lua, 1); + + obj = mScriptContextAccessWeakref(luaContext->d.context, obj); + if (!obj) { + luaL_traceback(lua, lua, "Invalid table", 1); + lua_error(lua); + } + + struct mScriptValue val = mSCRIPT_MAKE_U64(mScriptTableSize(obj)); + + if (!_luaWrap(luaContext, &val)) { + luaL_traceback(lua, lua, "Error translating value from runtime", 1); + lua_error(lua); + } + return 1; +} + static int _luaNextTable(lua_State* lua) { struct mScriptEngineContextLua* luaContext = _luaGetContext(lua); char key[MAX_KEY_SIZE]; - const char* keyPtr = lua_tostring(lua, -1); + int type = lua_type(luaContext->lua, -1); + const char* keyPtr = NULL; + struct mScriptValue keyVal = {0}; + switch (type) { + case LUA_TNUMBER: + keyVal = mSCRIPT_MAKE_S64(lua_tointeger(luaContext->lua, -1)); + break; + case LUA_TSTRING: + keyPtr = lua_tostring(lua, -1); + break; + } struct mScriptValue* table = lua_touserdata(lua, -2); - struct mScriptValue keyVal; - if (keyPtr) { strlcpy(key, keyPtr, sizeof(key)); keyVal = mSCRIPT_MAKE_CHARP(key); @@ -711,7 +759,7 @@ static int _luaNextTable(lua_State* lua) { } struct TableIterator iter; - if (keyPtr) { + if (keyVal.type) { if (!mScriptTableIteratorLookup(table, &iter, &keyVal)) { return 0; } diff --git a/src/script/test/lua.c b/src/script/test/lua.c index 7ef6a1ee0..fc91ad742 100644 --- a/src/script/test/lua.c +++ b/src/script/test/lua.c @@ -20,6 +20,10 @@ vf->close(vf); \ } while(0) +#define TEST_PROGRAM(PROG) \ + LOAD_PROGRAM(PROG); \ + assert_true(lua->run(lua)); \ + struct Test { int32_t i; int32_t (*ifn0)(struct Test*); @@ -154,16 +158,14 @@ M_TEST_DEFINE(getGlobal) { struct mScriptValue a = mSCRIPT_MAKE_S32(1); struct mScriptValue* val; - LOAD_PROGRAM("a = 1"); - assert_true(lua->run(lua)); + TEST_PROGRAM("a = 1"); val = lua->getGlobal(lua, "a"); assert_non_null(val); assert_true(a.type->equal(&a, val)); mScriptValueDeref(val); - LOAD_PROGRAM("b = 1"); - assert_true(lua->run(lua)); + TEST_PROGRAM("b = 1"); val = lua->getGlobal(lua, "a"); assert_non_null(val); @@ -176,8 +178,7 @@ M_TEST_DEFINE(getGlobal) { mScriptValueDeref(val); a = mSCRIPT_MAKE_S32(2); - LOAD_PROGRAM("a = 2"); - assert_true(lua->run(lua)); + TEST_PROGRAM("a = 2"); val = lua->getGlobal(lua, "a"); assert_non_null(val); @@ -185,8 +186,7 @@ M_TEST_DEFINE(getGlobal) { mScriptValueDeref(val); a = mSCRIPT_MAKE_S32(3); - LOAD_PROGRAM("b = a + b"); - assert_true(lua->run(lua)); + TEST_PROGRAM("b = a + b"); val = lua->getGlobal(lua, "b"); assert_non_null(val); @@ -250,8 +250,8 @@ M_TEST_DEFINE(callLuaFunc) { struct mScriptValue* fn; - LOAD_PROGRAM("function a(b) return b + 1 end; function c(d, e) return d + e end"); - assert_true(lua->run(lua)); + TEST_PROGRAM("function a(b) return b + 1 end; function c(d, e) return d + e end"); + assert_null(lua->getError(lua)); fn = lua->getGlobal(lua, "a"); assert_non_null(fn); @@ -291,11 +291,10 @@ M_TEST_DEFINE(callCFunc) { struct mScriptValue a = mSCRIPT_MAKE_S32(1); struct mScriptValue* val; - LOAD_PROGRAM("a = b(1); c = d(1, 2)"); - assert_true(lua->setGlobal(lua, "b", &boundIdentityInt)); assert_true(lua->setGlobal(lua, "d", &boundAddInts)); - assert_true(lua->run(lua)); + TEST_PROGRAM("a = b(1); c = d(1, 2)"); + assert_null(lua->getError(lua)); val = lua->getGlobal(lua, "a"); assert_non_null(val); @@ -535,6 +534,91 @@ M_TEST_DEFINE(errorReporting) { mScriptContextDeinit(&context); } +M_TEST_DEFINE(tableLookup) { + SETUP_LUA; + + assert_null(lua->getError(lua)); + struct mScriptValue* table = mScriptValueAlloc(mSCRIPT_TYPE_MS_TABLE); + assert_non_null(table); + struct mScriptValue* val; + + mScriptContextSetGlobal(&context, "t", table); + + val = mScriptValueAlloc(mSCRIPT_TYPE_MS_S64); + val->value.s64 = 0; + assert_true(mScriptTableInsert(table, &mSCRIPT_MAKE_S64(0), val)); + mScriptValueDeref(val); + + val = mScriptStringCreateFromASCII("t"); + assert_true(mScriptTableInsert(table, &mSCRIPT_MAKE_CHARP("t"), val)); + mScriptValueDeref(val); + + val = mScriptValueAlloc(mSCRIPT_TYPE_MS_TABLE); + assert_true(mScriptTableInsert(table, &mSCRIPT_MAKE_CHARP("sub"), val)); + mScriptValueDeref(val); + + table = val; + val = mScriptStringCreateFromASCII("t"); + assert_true(mScriptTableInsert(table, &mSCRIPT_MAKE_CHARP("t"), val)); + mScriptValueDeref(val); + + TEST_PROGRAM("assert(t)"); + TEST_PROGRAM("assert(t['t'] ~= nil)"); + TEST_PROGRAM("assert(t['t'] == 't')"); + TEST_PROGRAM("assert(t.t == 't')"); + TEST_PROGRAM("assert(t['x'] == nil)"); + TEST_PROGRAM("assert(t.x == nil)"); + TEST_PROGRAM("assert(t.sub ~= nil)"); + TEST_PROGRAM("assert(t.sub.t ~= nil)"); + TEST_PROGRAM("assert(t.sub.t == 't')"); + TEST_PROGRAM("assert(t[0] ~= nil)"); + TEST_PROGRAM("assert(t[0] == 0)"); + TEST_PROGRAM("assert(t[1] == nil)"); + + mScriptContextDeinit(&context); +} + +M_TEST_DEFINE(tableIterate) { + SETUP_LUA; + + assert_null(lua->getError(lua)); + struct mScriptValue* table = mScriptValueAlloc(mSCRIPT_TYPE_MS_TABLE); + assert_non_null(table); + struct mScriptValue* val; + struct mScriptValue* key; + + mScriptContextSetGlobal(&context, "t", table); + + int i; + for (i = 0; i < 50; ++i) { + val = mScriptValueAlloc(mSCRIPT_TYPE_MS_S64); + val->value.s64 = 1LL << i; + key = mScriptValueAlloc(mSCRIPT_TYPE_MS_S32); + key->value.s32 = i; + assert_true(mScriptTableInsert(table, key, val)); + mScriptValueDeref(key); + mScriptValueDeref(val); + } + assert_int_equal(mScriptTableSize(table), 50); + + TEST_PROGRAM("assert(t)"); + TEST_PROGRAM("assert(#t == 50)"); + TEST_PROGRAM( + "i = 0\n" + "z = 0\n" + "for k, v in pairs(t) do\n" + " i = i + 1\n" + " z = z + v\n" + " assert((1 << k) == v)\n" + "end\n" + ); + + TEST_PROGRAM("assert(i == #t)"); + TEST_PROGRAM("assert(z == (1 << #t) - 1)"); + + mScriptContextDeinit(&context); +} + M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(create), cmocka_unit_test(loadGood), @@ -548,4 +632,6 @@ M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(globalStructFieldSet), cmocka_unit_test(globalStructMethods), cmocka_unit_test(errorReporting), + cmocka_unit_test(tableLookup), + cmocka_unit_test(tableIterate), ) diff --git a/src/script/types.c b/src/script/types.c index f4a268ac4..0c20b9257 100644 --- a/src/script/types.c +++ b/src/script/types.c @@ -853,6 +853,13 @@ bool mScriptTableClear(struct mScriptValue* table) { return true; } +size_t mScriptTableSize(struct mScriptValue* table) { + if (table->type != mSCRIPT_TYPE_MS_TABLE) { + return 0; + } + return HashTableSize(table->value.table); +} + bool mScriptTableIteratorStart(struct mScriptValue* table, struct TableIterator* iter) { if (table->type == mSCRIPT_TYPE_MS_WRAPPER) { table = mScriptValueUnwrap(table);