From 2c11c4806abfc9796bcadcc1dc9fba737e84c93d Mon Sep 17 00:00:00 2001 From: Vicki Pfau Date: Tue, 15 Feb 2022 22:09:35 -0800 Subject: [PATCH] Scripting: Add calling Lua functions --- src/script/engines/lua.c | 228 +++++++++++++++++++++++++++++++++++++++ src/script/test/lua.c | 129 +++++++++++++++++++++- 2 files changed, 355 insertions(+), 2 deletions(-) diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index 1782c43ad..cc22b4e14 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -9,13 +9,57 @@ static struct mScriptEngineContext* _luaCreate(struct mScriptEngine2*, struct mScriptContext*); static void _luaDestroy(struct mScriptEngineContext*); +static struct mScriptValue* _luaGetGlobal(struct mScriptEngineContext*, const char* name); static bool _luaLoad(struct mScriptEngineContext*, struct VFile*, const char** error); +static bool _luaRun(struct mScriptEngineContext*); + +static bool _luaCall(struct mScriptFrame*, void* context); + +struct mScriptEngineContextLua; +static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptFrame*); +static bool _luaPopFrame(struct mScriptEngineContextLua*, struct mScriptFrame*); +static bool _luaInvoke(struct mScriptEngineContextLua*, struct mScriptFrame*); + +static struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext); +static bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue*); + +static void _luaDeref(struct mScriptValue*); + +#if LUA_VERSION_NUM < 503 +#define lua_pushinteger lua_pushnumber +#endif + +const struct mScriptType mSTLuaFunc = { + .base = mSCRIPT_TYPE_FUNCTION, + .size = 0, + .name = "lua-" LUA_VERSION_ONLY "::function", + .details = { + .function = { + .parameters = { + .variable = true + }, + .returnType = { + .variable = true + } + } + }, + .alloc = NULL, + .free = _luaDeref, + .hash = NULL, + .equal = NULL, + .cast = NULL, +}; struct mScriptEngineContextLua { struct mScriptEngineContext d; lua_State* lua; }; +struct mScriptEngineContextLuaRef { + struct mScriptEngineContextLua* context; + int ref; +}; + static struct mScriptEngineLua { struct mScriptEngine2 d; } _engineLua = { @@ -35,7 +79,9 @@ struct mScriptEngineContext* _luaCreate(struct mScriptEngine2* engine, struct mS luaContext->d = (struct mScriptEngineContext) { .context = context, .destroy = _luaDestroy, + .getGlobal = _luaGetGlobal, .load = _luaLoad, + .run = _luaRun }; luaContext->lua = luaL_newstate(); return &luaContext->d; @@ -47,6 +93,96 @@ void _luaDestroy(struct mScriptEngineContext* ctx) { free(luaContext); } +struct mScriptValue* _luaGetGlobal(struct mScriptEngineContext* ctx, const char* name) { + struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) ctx; + lua_getglobal(luaContext->lua, name); + return _luaCoerce(luaContext); +} + +struct mScriptValue* _luaWrapFunction(struct mScriptEngineContextLua* luaContext) { + struct mScriptValue* value = mScriptValueAlloc(&mSTLuaFunc); + struct mScriptFunction* fn = calloc(1, sizeof(*fn)); + struct mScriptEngineContextLuaRef* ref = calloc(1, sizeof(*ref)); + fn->call = _luaCall; + fn->context = ref; + ref->context = luaContext; + ref->ref = luaL_ref(luaContext->lua, LUA_REGISTRYINDEX); + value->value.opaque = fn; + return value; +} + +struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext) { + if (lua_isnone(luaContext->lua, -1)) { + lua_pop(luaContext->lua, 1); + return NULL; + } + + struct mScriptValue* value = NULL; + switch (lua_type(luaContext->lua, -1)) { + case LUA_TNUMBER: +#if LUA_VERSION_NUM >= 503 + if (lua_isinteger(luaContext->lua, -1)) { + value = mScriptValueAlloc(mSCRIPT_TYPE_MS_S64); + value->value.s64 = lua_tointeger(luaContext->lua, -1); + break; + } +#endif + value = mScriptValueAlloc(mSCRIPT_TYPE_MS_F64); + value->value.f64 = lua_tonumber(luaContext->lua, -1); + break; + case LUA_TBOOLEAN: + break; + case LUA_TSTRING: + break; + case LUA_TFUNCTION: + return _luaWrapFunction(luaContext); + } + lua_pop(luaContext->lua, 1); + return value; +} + +bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* value) { + if (value->type == mSCRIPT_TYPE_MS_WRAPPER) { + value = mScriptValueUnwrap(value); + } + bool ok = true; + switch (value->type->base) { + case mSCRIPT_TYPE_SINT: + if (value->type->size == 4) { + lua_pushinteger(luaContext->lua, value->value.s32); + } else if (value->type->size == 8) { + lua_pushinteger(luaContext->lua, value->value.s64); + } else { + ok = false; + } + break; + case mSCRIPT_TYPE_UINT: + if (value->type->size == 4) { + lua_pushinteger(luaContext->lua, value->value.u32); + } else if (value->type->size == 8) { + lua_pushinteger(luaContext->lua, value->value.u64); + } else { + ok = false; + } + break; + case mSCRIPT_TYPE_FLOAT: + if (value->type->size == 4) { + lua_pushnumber(luaContext->lua, value->value.f32); + } else if (value->type->size == 8) { + lua_pushnumber(luaContext->lua, value->value.f64); + } else { + ok = false; + } + break; + default: + ok = false; + break; + } + + mScriptValueDeref(value); + return ok; +} + #define LUA_BLOCKSIZE 0x1000 struct mScriptEngineLuaReader { struct VFile* vf; @@ -89,3 +225,95 @@ bool _luaLoad(struct mScriptEngineContext* ctx, struct VFile* vf, const char** e } return false; } + +bool _luaRun(struct mScriptEngineContext* context) { + struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) context; + return _luaInvoke(luaContext, NULL); +} + +bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) { + bool ok = true; + if (frame) { + size_t i; + for (i = 0; i < mScriptListSize(&frame->arguments); ++i) { + if (!_luaWrap(luaContext, mScriptListGetPointer(&frame->arguments, i))) { + ok = false; + break; + } + } + } + if (!ok) { + lua_pop(luaContext->lua, lua_gettop(luaContext->lua)); + } + return ok; +} + +bool _luaPopFrame(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) { + int count = lua_gettop(luaContext->lua); + bool ok = true; + if (frame) { + int i; + for (i = 0; i < count; ++i) { + struct mScriptValue* value = _luaCoerce(luaContext); + if (!value) { + ok = false; + break; + } + lua_pop(luaContext->lua, 1); + mScriptValueWrap(value, mScriptListAppend(&frame->returnValues)); + mScriptValueDeref(value); + } + if (count > i) { + lua_pop(luaContext->lua, count - i); + } + } + return ok; +} + +bool _luaCall(struct mScriptFrame* frame, void* context) { + struct mScriptEngineContextLuaRef* ref = context; + lua_rawgeti(ref->context->lua, LUA_REGISTRYINDEX, ref->ref); + if (!_luaInvoke(ref->context, frame)) { + return false; + } + return true; +} + +bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) { + int nargs = 0; + if (frame) { + nargs = mScriptListSize(&frame->arguments); + } + + if (frame && !_luaPushFrame(luaContext, frame)) { + return false; + } + + int ret = lua_pcall(luaContext->lua, nargs, LUA_MULTRET, 0); + + if (ret == LUA_ERRRUN) { + lua_pop(luaContext->lua, 1); + } + if (ret) { + return false; + } + + if (!_luaPopFrame(luaContext, frame)) { + return false; + } + + return true; +} + +void _luaDeref(struct mScriptValue* value) { + struct mScriptEngineContextLuaRef* ref; + if (value->type->base == mSCRIPT_TYPE_FUNCTION) { + struct mScriptFunction* function = value->value.opaque; + ref = function->context; + free(function); + } else { + return; + } + luaL_unref(ref->context->lua, LUA_REGISTRYINDEX, ref->ref); + free(ref); +} diff --git a/src/script/test/lua.c b/src/script/test/lua.c index 695e0b115..b160877b1 100644 --- a/src/script/test/lua.c +++ b/src/script/test/lua.c @@ -44,7 +44,7 @@ M_TEST_DEFINE(loadGood) { mScriptContextDeinit(&context); } -M_TEST_DEFINE(loadSyntax) { +M_TEST_DEFINE(loadBadSyntax) { struct mScriptContext context; mScriptContextInit(&context); struct mScriptEngineContext* lua = mSCRIPT_ENGINE_LUA->create(mSCRIPT_ENGINE_LUA, &context); @@ -59,7 +59,132 @@ M_TEST_DEFINE(loadSyntax) { mScriptContextDeinit(&context); } +M_TEST_DEFINE(runNop) { + struct mScriptContext context; + mScriptContextInit(&context); + struct mScriptEngineContext* lua = mSCRIPT_ENGINE_LUA->create(mSCRIPT_ENGINE_LUA, &context); + + const char* program = "return"; + struct VFile* vf = VFileFromConstMemory(program, strlen(program)); + const char* error = NULL; + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + assert_true(lua->run(lua)); + + lua->destroy(lua); + mScriptContextDeinit(&context); +} + +M_TEST_DEFINE(getGlobal) { + struct mScriptContext context; + mScriptContextInit(&context); + struct mScriptEngineContext* lua = mSCRIPT_ENGINE_LUA->create(mSCRIPT_ENGINE_LUA, &context); + + struct mScriptValue a = mSCRIPT_MAKE_S32(1); + struct mScriptValue* val; + const char* program; + struct VFile* vf; + const char* error; + + program = "a = 1"; + vf = VFileFromConstMemory(program, strlen(program)); + error = NULL; + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + assert_true(lua->run(lua)); + + val = lua->getGlobal(lua, "a"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + mScriptValueDeref(val); + + program = "b = 1"; + vf = VFileFromConstMemory(program, strlen(program)); + error = NULL; + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + assert_true(lua->run(lua)); + + val = lua->getGlobal(lua, "a"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + mScriptValueDeref(val); + + val = lua->getGlobal(lua, "b"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + mScriptValueDeref(val); + + a = mSCRIPT_MAKE_S32(2); + program = "a = 2"; + vf = VFileFromConstMemory(program, strlen(program)); + error = NULL; + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + assert_true(lua->run(lua)); + + val = lua->getGlobal(lua, "a"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + mScriptValueDeref(val); + + a = mSCRIPT_MAKE_S32(3); + program = "b = a + b"; + vf = VFileFromConstMemory(program, strlen(program)); + error = NULL; + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + assert_true(lua->run(lua)); + + val = lua->getGlobal(lua, "b"); + assert_non_null(val); + assert_true(a.type->equal(&a, val)); + mScriptValueDeref(val); + + lua->destroy(lua); + mScriptContextDeinit(&context); +} + +M_TEST_DEFINE(callLuaFunc) { + struct mScriptContext context; + mScriptContextInit(&context); + struct mScriptEngineContext* lua = mSCRIPT_ENGINE_LUA->create(mSCRIPT_ENGINE_LUA, &context); + + struct mScriptValue* fn; + const char* program; + struct VFile* vf; + const char* error; + + program = "function a(b) return b + 1 end"; + vf = VFileFromConstMemory(program, strlen(program)); + error = NULL; + assert_true(lua->load(lua, vf, &error)); + assert_null(error); + assert_true(lua->run(lua)); + + fn = lua->getGlobal(lua, "a"); + assert_non_null(fn); + assert_int_equal(fn->type->base, mSCRIPT_TYPE_FUNCTION); + + struct mScriptFrame frame; + mScriptFrameInit(&frame); + mSCRIPT_PUSH(&frame.arguments, S32, 1); + assert_true(mScriptInvoke(fn, &frame)); + int64_t val; + assert_true(mScriptPopS64(&frame.returnValues, &val)); + assert_int_equal(val, 2); + + mScriptFrameDeinit(&frame); + mScriptValueDeref(fn); + + lua->destroy(lua); + mScriptContextDeinit(&context); +} + M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua, cmocka_unit_test(create), cmocka_unit_test(loadGood), - cmocka_unit_test(loadSyntax)) + cmocka_unit_test(loadBadSyntax), + cmocka_unit_test(runNop), + cmocka_unit_test(getGlobal), + cmocka_unit_test(callLuaFunc))