diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index cc22b4e14..80728540f 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -10,6 +10,7 @@ static struct mScriptEngineContext* _luaCreate(struct mScriptEngine2*, struct mS static void _luaDestroy(struct mScriptEngineContext*); static struct mScriptValue* _luaGetGlobal(struct mScriptEngineContext*, const char* name); +static bool _luaSetGlobal(struct mScriptEngineContext*, const char* name, struct mScriptValue*); static bool _luaLoad(struct mScriptEngineContext*, struct VFile*, const char** error); static bool _luaRun(struct mScriptEngineContext*); @@ -53,6 +54,7 @@ const struct mScriptType mSTLuaFunc = { struct mScriptEngineContextLua { struct mScriptEngineContext d; lua_State* lua; + int func; }; struct mScriptEngineContextLuaRef { @@ -80,15 +82,20 @@ struct mScriptEngineContext* _luaCreate(struct mScriptEngine2* engine, struct mS .context = context, .destroy = _luaDestroy, .getGlobal = _luaGetGlobal, + .setGlobal = _luaSetGlobal, .load = _luaLoad, .run = _luaRun }; luaContext->lua = luaL_newstate(); + luaContext->func = -1; return &luaContext->d; } void _luaDestroy(struct mScriptEngineContext* ctx) { struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) ctx; + if (luaContext->func > 0) { + luaL_unref(luaContext->lua, LUA_REGISTRYINDEX, luaContext->func); + } lua_close(luaContext->lua); free(luaContext); } @@ -99,6 +106,15 @@ struct mScriptValue* _luaGetGlobal(struct mScriptEngineContext* ctx, const char* return _luaCoerce(luaContext); } +bool _luaSetGlobal(struct mScriptEngineContext* ctx, const char* name, struct mScriptValue* value) { + struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) ctx; + if (!_luaWrap(luaContext, value)) { + return false; + } + lua_setglobal(luaContext->lua, name); + return true; +} + struct mScriptValue* _luaWrapFunction(struct mScriptEngineContextLua* luaContext) { struct mScriptValue* value = mScriptValueAlloc(&mSTLuaFunc); struct mScriptFunction* fn = calloc(1, sizeof(*fn)); @@ -211,6 +227,7 @@ bool _luaLoad(struct mScriptEngineContext* ctx, struct VFile* vf, const char** e if (error) { *error = NULL; } + luaContext->func = luaL_ref(luaContext->lua, LUA_REGISTRYINDEX); return true; case LUA_ERRSYNTAX: if (error) { @@ -228,6 +245,7 @@ bool _luaLoad(struct mScriptEngineContext* ctx, struct VFile* vf, const char** e bool _luaRun(struct mScriptEngineContext* context) { struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) context; + lua_rawgeti(luaContext->lua, LUA_REGISTRYINDEX, luaContext->func); return _luaInvoke(luaContext, NULL); } diff --git a/src/script/test/lua.c b/src/script/test/lua.c index b160877b1..cfe9071e5 100644 --- a/src/script/test/lua.c +++ b/src/script/test/lua.c @@ -71,6 +71,9 @@ M_TEST_DEFINE(runNop) { assert_null(error); assert_true(lua->run(lua)); + // Make sure we can run it twice + assert_true(lua->run(lua)); + lua->destroy(lua); mScriptContextDeinit(&context); } @@ -145,6 +148,62 @@ M_TEST_DEFINE(getGlobal) { mScriptContextDeinit(&context); } + +M_TEST_DEFINE(setGlobal) { + struct mScriptContext context; + mScriptContextInit(&context); + struct mScriptEngineContext* lua = mSCRIPT_ENGINE_LUA->create(mSCRIPT_ENGINE_LUA, &context); + + struct mScriptValue a = mSCRIPT_MAKE_S32(1); + struct mScriptValue* val; + const char* program; + struct VFile* vf; + const char* error; + + program = "a = b"; + vf = VFileFromConstMemory(program, strlen(program)); + error = NULL; + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + assert_true(lua->setGlobal(lua, "b", &a)); + + val = lua->getGlobal(lua, "b"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + mScriptValueDeref(val); + + assert_true(lua->run(lua)); + + val = lua->getGlobal(lua, "a"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + a = mSCRIPT_MAKE_S32(2); + assert_false(a.type->equal(&a, val)); + mScriptValueDeref(val); + + val = lua->getGlobal(lua, "a"); + assert_non_null(val); + assert_false(a.type->equal(&a, val)); + mScriptValueDeref(val); + + assert_true(lua->setGlobal(lua, "b", &a)); + + val = lua->getGlobal(lua, "b"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + mScriptValueDeref(val); + + assert_true(lua->run(lua)); + + val = lua->getGlobal(lua, "a"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + mScriptValueDeref(val); + + lua->destroy(lua); + mScriptContextDeinit(&context); +} + M_TEST_DEFINE(callLuaFunc) { struct mScriptContext context; mScriptContextInit(&context); @@ -187,4 +246,5 @@ M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(loadBadSyntax), cmocka_unit_test(runNop), cmocka_unit_test(getGlobal), + cmocka_unit_test(setGlobal), cmocka_unit_test(callLuaFunc))