diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index e7da8c8bf..eb1db84a8 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -27,6 +27,8 @@ static bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptV static void _luaDeref(struct mScriptValue*); static int _luaThunk(lua_State* lua); +static int _luaGetObject(lua_State* lua); +static int _luaSetObject(lua_State* lua); #if LUA_VERSION_NUM < 503 #define lua_pushinteger lua_pushnumber @@ -90,6 +92,16 @@ struct mScriptEngineContext* _luaCreate(struct mScriptEngine2* engine, struct mS }; luaContext->lua = luaL_newstate(); luaContext->func = -1; + + luaL_newmetatable(luaContext->lua, "mSTStruct"); + lua_pushliteral(luaContext->lua, "__index"); + lua_pushcfunction(luaContext->lua, _luaGetObject); + lua_rawset(luaContext->lua, -3); + lua_pushliteral(luaContext->lua, "__newindex"); + lua_pushcfunction(luaContext->lua, _luaSetObject); + lua_rawset(luaContext->lua, -3); + lua_pop(luaContext->lua, 1); + return &luaContext->d; } @@ -149,6 +161,8 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext) { value->value.f64 = lua_tonumber(luaContext->lua, -1); break; case LUA_TBOOLEAN: + value = mScriptValueAlloc(mSCRIPT_TYPE_MS_S32); + value->value.s32 = lua_toboolean(luaContext->lua, -1); break; case LUA_TSTRING: break; @@ -195,15 +209,18 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v case mSCRIPT_TYPE_FUNCTION: lua_pushlightuserdata(luaContext->lua, value); lua_pushcclosure(luaContext->lua, _luaThunk, 1); + mScriptValueRef(value); + break; + case mSCRIPT_TYPE_OBJECT: + lua_pushlightuserdata(luaContext->lua, value); + luaL_setmetatable(luaContext->lua, "mSTStruct"); + mScriptValueRef(value); break; default: ok = false; break; } - if (ok) { - mScriptValueDeref(value); - } return ok; } @@ -349,7 +366,7 @@ void _luaDeref(struct mScriptValue* value) { free(ref); } -int _luaThunk(lua_State* lua) { +static struct mScriptEngineContextLua* _luaGetContext(lua_State* lua) { lua_pushliteral(lua, "mCtx"); int type = lua_rawget(lua, LUA_REGISTRYINDEX); if (type != LUA_TLIGHTUSERDATA) { @@ -364,7 +381,11 @@ int _luaThunk(lua_State* lua) { lua_pushliteral(lua, "Function called from invalid context"); lua_error(lua); } + return luaContext; +} +int _luaThunk(lua_State* lua) { + struct mScriptEngineContextLua* luaContext = _luaGetContext(lua); struct mScriptFrame frame; mScriptFrameInit(&frame); if (!_luaPopFrame(luaContext, &frame.arguments)) { @@ -389,3 +410,45 @@ int _luaThunk(lua_State* lua) { return lua_gettop(luaContext->lua); } + +int _luaGetObject(lua_State* lua) { + struct mScriptEngineContextLua* luaContext = _luaGetContext(lua); + const char* key = lua_tostring(lua, -1); + struct mScriptValue* obj = lua_touserdata(lua, -2); + struct mScriptValue val; + + if (!mScriptObjectGet(obj, key, &val)) { + lua_pop(lua, 2); + lua_pushliteral(lua, "Invalid key"); + lua_error(lua); + } + + lua_pop(lua, 2); + if (!_luaWrap(luaContext, &val)) { + lua_pushliteral(lua, "Invalid value"); + lua_error(lua); + } + return 1; +} + + +int _luaSetObject(lua_State* lua) { + struct mScriptEngineContextLua* luaContext = _luaGetContext(lua); + const char* key = lua_tostring(lua, -2); + struct mScriptValue* obj = lua_touserdata(lua, -3); + struct mScriptValue* val = _luaCoerce(luaContext); + + lua_pop(lua, 2); + if (!val) { + lua_pushliteral(lua, "Invalid value"); + lua_error(lua); + } + + if (!mScriptObjectSet(obj, key, val)) { + mScriptValueDeref(val); + lua_pushliteral(lua, "Invalid key"); + lua_error(lua); + } + mScriptValueDeref(val); + return 0; +} diff --git a/src/script/test/lua.c b/src/script/test/lua.c index c8656476e..a48d6b5be 100644 --- a/src/script/test/lua.c +++ b/src/script/test/lua.c @@ -7,6 +7,15 @@ #include +struct Test { + int32_t i; + int32_t (*ifn0)(struct Test*); + int32_t (*ifn1)(struct Test*, int); + void (*vfn0)(struct Test*); + void (*vfn1)(struct Test*, int); + int32_t (*icfn0)(const struct Test*); +}; + static int identityInt(int in) { return in; } @@ -15,9 +24,57 @@ static int addInts(int a, int b) { return a + b; } +static int32_t testI0(struct Test* a) { + return a->i; +} + +static int32_t testI1(struct Test* a, int b) { + return a->i + b; +} + +static int32_t testIC0(const struct Test* a) { + return a->i; +} + +static void testV0(struct Test* a) { + ++a->i; +} + +static void testV1(struct Test* a, int b) { + a->i += b; +} + mSCRIPT_BIND_FUNCTION(boundIdentityInt, S32, identityInt, 1, S32); mSCRIPT_BIND_FUNCTION(boundAddInts, S32, addInts, 2, S32, S32); +mSCRIPT_DECLARE_STRUCT(Test); +mSCRIPT_DECLARE_STRUCT_D_METHOD(Test, S32, ifn0, 0); +mSCRIPT_DECLARE_STRUCT_D_METHOD(Test, S32, ifn1, 1, S32); +mSCRIPT_DECLARE_STRUCT_CD_METHOD(Test, S32, icfn0, 0); +mSCRIPT_DECLARE_STRUCT_VOID_D_METHOD(Test, vfn0, 0); +mSCRIPT_DECLARE_STRUCT_VOID_D_METHOD(Test, vfn1, 1, S32); +mSCRIPT_DECLARE_STRUCT_METHOD(Test, S32, i0, testI0, 0); +mSCRIPT_DECLARE_STRUCT_METHOD(Test, S32, i1, testI1, 1, S32); +mSCRIPT_DECLARE_STRUCT_C_METHOD(Test, S32, ic0, testIC0, 0); +mSCRIPT_DECLARE_STRUCT_VOID_METHOD(Test, v0, testV0, 0); +mSCRIPT_DECLARE_STRUCT_VOID_METHOD(Test, v1, testV1, 1, S32); + +mSCRIPT_DEFINE_STRUCT(Test) + mSCRIPT_DEFINE_STRUCT_MEMBER(Test, S32, i) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, ifn0) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, ifn1) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, icfn0) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, vfn0) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, vfn1) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, i0) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, i1) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, ic0) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, v0) + mSCRIPT_DEFINE_STRUCT_METHOD(Test, v1) +mSCRIPT_DEFINE_END; + +mSCRIPT_EXPORT_STRUCT(Test); + M_TEST_SUITE_SETUP(mScriptLua) { if (mSCRIPT_ENGINE_LUA->init) { mSCRIPT_ENGINE_LUA->init(mSCRIPT_ENGINE_LUA); @@ -310,6 +367,72 @@ M_TEST_DEFINE(callCFunc) { vf->close(vf); } +M_TEST_DEFINE(setGlobalStruct) { + struct mScriptContext context; + mScriptContextInit(&context); + struct mScriptEngineContext* lua = mSCRIPT_ENGINE_LUA->create(mSCRIPT_ENGINE_LUA, &context); + + struct Test s = { + .i = 1 + }; + + struct mScriptValue a = mSCRIPT_MAKE_S(Test, &s); + struct mScriptValue b; + struct mScriptValue* val; + const char* program; + struct VFile* vf; + const char* error; + + program = "b = a.i"; + vf = VFileFromConstMemory(program, strlen(program)); + error = NULL; + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + assert_true(lua->setGlobal(lua, "a", &a)); + + assert_true(lua->run(lua)); + + b = mSCRIPT_MAKE_S32(1); + val = lua->getGlobal(lua, "b"); + assert_non_null(val); + assert_true(b.type->equal(&b, val)); + mScriptValueDeref(val); + + s.i = 2; + assert_true(lua->run(lua)); + + b = mSCRIPT_MAKE_S32(2); + val = lua->getGlobal(lua, "b"); + assert_non_null(val); + assert_true(b.type->equal(&b, val)); + mScriptValueDeref(val); + vf->close(vf); + + program = "a.i = b"; + vf = VFileFromConstMemory(program, strlen(program)); + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + b = mSCRIPT_MAKE_S32(3); + assert_true(lua->setGlobal(lua, "b", &b)); + + assert_true(lua->run(lua)); + + assert_int_equal(s.i, 3); + + a = mSCRIPT_MAKE_CS(Test, &s); + assert_true(lua->setGlobal(lua, "a", &a)); + b = mSCRIPT_MAKE_S32(4); + assert_true(lua->setGlobal(lua, "b", &b)); + + assert_false(lua->run(lua)); + + assert_int_equal(s.i, 3); + + lua->destroy(lua); + mScriptContextDeinit(&context); + vf->close(vf); +} + M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(create), cmocka_unit_test(loadGood), @@ -318,4 +441,5 @@ M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(getGlobal), cmocka_unit_test(setGlobal), cmocka_unit_test(callLuaFunc), - cmocka_unit_test(callCFunc)) + cmocka_unit_test(callCFunc), + cmocka_unit_test(setGlobalStruct))