Scripting: Add Lua setGlobal, make sure calling run twice works

This commit is contained in:
Vicki Pfau 2022-03-11 15:36:42 -08:00
parent 2c11c4806a
commit 36efaf6330
2 changed files with 78 additions and 0 deletions

View File

@ -10,6 +10,7 @@ static struct mScriptEngineContext* _luaCreate(struct mScriptEngine2*, struct mS
static void _luaDestroy(struct mScriptEngineContext*); static void _luaDestroy(struct mScriptEngineContext*);
static struct mScriptValue* _luaGetGlobal(struct mScriptEngineContext*, const char* name); static struct mScriptValue* _luaGetGlobal(struct mScriptEngineContext*, const char* name);
static bool _luaSetGlobal(struct mScriptEngineContext*, const char* name, struct mScriptValue*);
static bool _luaLoad(struct mScriptEngineContext*, struct VFile*, const char** error); static bool _luaLoad(struct mScriptEngineContext*, struct VFile*, const char** error);
static bool _luaRun(struct mScriptEngineContext*); static bool _luaRun(struct mScriptEngineContext*);
@ -53,6 +54,7 @@ const struct mScriptType mSTLuaFunc = {
struct mScriptEngineContextLua { struct mScriptEngineContextLua {
struct mScriptEngineContext d; struct mScriptEngineContext d;
lua_State* lua; lua_State* lua;
int func;
}; };
struct mScriptEngineContextLuaRef { struct mScriptEngineContextLuaRef {
@ -80,15 +82,20 @@ struct mScriptEngineContext* _luaCreate(struct mScriptEngine2* engine, struct mS
.context = context, .context = context,
.destroy = _luaDestroy, .destroy = _luaDestroy,
.getGlobal = _luaGetGlobal, .getGlobal = _luaGetGlobal,
.setGlobal = _luaSetGlobal,
.load = _luaLoad, .load = _luaLoad,
.run = _luaRun .run = _luaRun
}; };
luaContext->lua = luaL_newstate(); luaContext->lua = luaL_newstate();
luaContext->func = -1;
return &luaContext->d; return &luaContext->d;
} }
void _luaDestroy(struct mScriptEngineContext* ctx) { void _luaDestroy(struct mScriptEngineContext* ctx) {
struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) ctx; struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) ctx;
if (luaContext->func > 0) {
luaL_unref(luaContext->lua, LUA_REGISTRYINDEX, luaContext->func);
}
lua_close(luaContext->lua); lua_close(luaContext->lua);
free(luaContext); free(luaContext);
} }
@ -99,6 +106,15 @@ struct mScriptValue* _luaGetGlobal(struct mScriptEngineContext* ctx, const char*
return _luaCoerce(luaContext); return _luaCoerce(luaContext);
} }
bool _luaSetGlobal(struct mScriptEngineContext* ctx, const char* name, struct mScriptValue* value) {
struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) ctx;
if (!_luaWrap(luaContext, value)) {
return false;
}
lua_setglobal(luaContext->lua, name);
return true;
}
struct mScriptValue* _luaWrapFunction(struct mScriptEngineContextLua* luaContext) { struct mScriptValue* _luaWrapFunction(struct mScriptEngineContextLua* luaContext) {
struct mScriptValue* value = mScriptValueAlloc(&mSTLuaFunc); struct mScriptValue* value = mScriptValueAlloc(&mSTLuaFunc);
struct mScriptFunction* fn = calloc(1, sizeof(*fn)); struct mScriptFunction* fn = calloc(1, sizeof(*fn));
@ -211,6 +227,7 @@ bool _luaLoad(struct mScriptEngineContext* ctx, struct VFile* vf, const char** e
if (error) { if (error) {
*error = NULL; *error = NULL;
} }
luaContext->func = luaL_ref(luaContext->lua, LUA_REGISTRYINDEX);
return true; return true;
case LUA_ERRSYNTAX: case LUA_ERRSYNTAX:
if (error) { if (error) {
@ -228,6 +245,7 @@ bool _luaLoad(struct mScriptEngineContext* ctx, struct VFile* vf, const char** e
bool _luaRun(struct mScriptEngineContext* context) { bool _luaRun(struct mScriptEngineContext* context) {
struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) context; struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) context;
lua_rawgeti(luaContext->lua, LUA_REGISTRYINDEX, luaContext->func);
return _luaInvoke(luaContext, NULL); return _luaInvoke(luaContext, NULL);
} }

View File

@ -71,6 +71,9 @@ M_TEST_DEFINE(runNop) {
assert_null(error); assert_null(error);
assert_true(lua->run(lua)); assert_true(lua->run(lua));
// Make sure we can run it twice
assert_true(lua->run(lua));
lua->destroy(lua); lua->destroy(lua);
mScriptContextDeinit(&context); mScriptContextDeinit(&context);
} }
@ -145,6 +148,62 @@ M_TEST_DEFINE(getGlobal) {
mScriptContextDeinit(&context); mScriptContextDeinit(&context);
} }
M_TEST_DEFINE(setGlobal) {
struct mScriptContext context;
mScriptContextInit(&context);
struct mScriptEngineContext* lua = mSCRIPT_ENGINE_LUA->create(mSCRIPT_ENGINE_LUA, &context);
struct mScriptValue a = mSCRIPT_MAKE_S32(1);
struct mScriptValue* val;
const char* program;
struct VFile* vf;
const char* error;
program = "a = b";
vf = VFileFromConstMemory(program, strlen(program));
error = NULL;
assert_true(lua->load(lua, vf, &error));
assert_null(error);
assert_true(lua->setGlobal(lua, "b", &a));
val = lua->getGlobal(lua, "b");
assert_non_null(val);
assert_true(a.type->equal(&a, val));
mScriptValueDeref(val);
assert_true(lua->run(lua));
val = lua->getGlobal(lua, "a");
assert_non_null(val);
assert_true(a.type->equal(&a, val));
a = mSCRIPT_MAKE_S32(2);
assert_false(a.type->equal(&a, val));
mScriptValueDeref(val);
val = lua->getGlobal(lua, "a");
assert_non_null(val);
assert_false(a.type->equal(&a, val));
mScriptValueDeref(val);
assert_true(lua->setGlobal(lua, "b", &a));
val = lua->getGlobal(lua, "b");
assert_non_null(val);
assert_true(a.type->equal(&a, val));
mScriptValueDeref(val);
assert_true(lua->run(lua));
val = lua->getGlobal(lua, "a");
assert_non_null(val);
assert_true(a.type->equal(&a, val));
mScriptValueDeref(val);
lua->destroy(lua);
mScriptContextDeinit(&context);
}
M_TEST_DEFINE(callLuaFunc) { M_TEST_DEFINE(callLuaFunc) {
struct mScriptContext context; struct mScriptContext context;
mScriptContextInit(&context); mScriptContextInit(&context);
@ -187,4 +246,5 @@ M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua,
cmocka_unit_test(loadBadSyntax), cmocka_unit_test(loadBadSyntax),
cmocka_unit_test(runNop), cmocka_unit_test(runNop),
cmocka_unit_test(getGlobal), cmocka_unit_test(getGlobal),
cmocka_unit_test(setGlobal),
cmocka_unit_test(callLuaFunc)) cmocka_unit_test(callLuaFunc))