Scripting: Add Lua function call thunk

This commit is contained in:
Vicki Pfau 2022-03-11 17:51:45 -08:00
parent 9d92c185c6
commit 7fb7d53c5d
2 changed files with 129 additions and 13 deletions

View File

@ -17,8 +17,8 @@ static bool _luaRun(struct mScriptEngineContext*);
static bool _luaCall(struct mScriptFrame*, void* context); static bool _luaCall(struct mScriptFrame*, void* context);
struct mScriptEngineContextLua; struct mScriptEngineContextLua;
static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptFrame*); static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptList*);
static bool _luaPopFrame(struct mScriptEngineContextLua*, struct mScriptFrame*); static bool _luaPopFrame(struct mScriptEngineContextLua*, struct mScriptList*);
static bool _luaInvoke(struct mScriptEngineContextLua*, struct mScriptFrame*); static bool _luaInvoke(struct mScriptEngineContextLua*, struct mScriptFrame*);
static struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext); static struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext);
@ -26,6 +26,8 @@ static bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptV
static void _luaDeref(struct mScriptValue*); static void _luaDeref(struct mScriptValue*);
static int _luaThunk(lua_State* lua);
#if LUA_VERSION_NUM < 503 #if LUA_VERSION_NUM < 503
#define lua_pushinteger lua_pushnumber #define lua_pushinteger lua_pushnumber
#endif #endif
@ -190,12 +192,18 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
ok = false; ok = false;
} }
break; break;
case mSCRIPT_TYPE_FUNCTION:
lua_pushlightuserdata(luaContext->lua, value);
lua_pushcclosure(luaContext->lua, _luaThunk, 1);
break;
default: default:
ok = false; ok = false;
break; break;
} }
mScriptValueDeref(value); if (ok) {
mScriptValueDeref(value);
}
return ok; return ok;
} }
@ -249,12 +257,12 @@ bool _luaRun(struct mScriptEngineContext* context) {
return _luaInvoke(luaContext, NULL); return _luaInvoke(luaContext, NULL);
} }
bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) { bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList* frame) {
bool ok = true; bool ok = true;
if (frame) { if (frame) {
size_t i; size_t i;
for (i = 0; i < mScriptListSize(&frame->arguments); ++i) { for (i = 0; i < mScriptListSize(frame); ++i) {
if (!_luaWrap(luaContext, mScriptListGetPointer(&frame->arguments, i))) { if (!_luaWrap(luaContext, mScriptListGetPointer(frame, i))) {
ok = false; ok = false;
break; break;
} }
@ -266,7 +274,7 @@ bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptFra
return ok; return ok;
} }
bool _luaPopFrame(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) { bool _luaPopFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList* frame) {
int count = lua_gettop(luaContext->lua); int count = lua_gettop(luaContext->lua);
bool ok = true; bool ok = true;
if (frame) { if (frame) {
@ -277,8 +285,7 @@ bool _luaPopFrame(struct mScriptEngineContextLua* luaContext, struct mScriptFram
ok = false; ok = false;
break; break;
} }
lua_pop(luaContext->lua, 1); mScriptValueWrap(value, mScriptListAppend(frame));
mScriptValueWrap(value, mScriptListAppend(&frame->returnValues));
mScriptValueDeref(value); mScriptValueDeref(value);
} }
if (count > i) { if (count > i) {
@ -303,11 +310,17 @@ bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame*
nargs = mScriptListSize(&frame->arguments); nargs = mScriptListSize(&frame->arguments);
} }
if (frame && !_luaPushFrame(luaContext, frame)) { if (frame && !_luaPushFrame(luaContext, &frame->arguments)) {
return false; return false;
} }
lua_pushliteral(luaContext->lua, "mCtx");
lua_pushlightuserdata(luaContext->lua, luaContext);
lua_rawset(luaContext->lua, LUA_REGISTRYINDEX);
int ret = lua_pcall(luaContext->lua, nargs, LUA_MULTRET, 0); int ret = lua_pcall(luaContext->lua, nargs, LUA_MULTRET, 0);
lua_pushliteral(luaContext->lua, "mCtx");
lua_pushnil(luaContext->lua);
lua_rawset(luaContext->lua, LUA_REGISTRYINDEX);
if (ret == LUA_ERRRUN) { if (ret == LUA_ERRRUN) {
lua_pop(luaContext->lua, 1); lua_pop(luaContext->lua, 1);
@ -316,7 +329,7 @@ bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame*
return false; return false;
} }
if (!_luaPopFrame(luaContext, frame)) { if (!_luaPopFrame(luaContext, &frame->returnValues)) {
return false; return false;
} }
@ -335,3 +348,44 @@ void _luaDeref(struct mScriptValue* value) {
luaL_unref(ref->context->lua, LUA_REGISTRYINDEX, ref->ref); luaL_unref(ref->context->lua, LUA_REGISTRYINDEX, ref->ref);
free(ref); free(ref);
} }
int _luaThunk(lua_State* lua) {
lua_pushliteral(lua, "mCtx");
int type = lua_rawget(lua, LUA_REGISTRYINDEX);
if (type != LUA_TLIGHTUSERDATA) {
lua_pop(lua, 1);
lua_pushliteral(lua, "Function called from invalid context");
lua_error(lua);
}
struct mScriptEngineContextLua* luaContext = lua_touserdata(lua, -1);
lua_pop(lua, 1);
if (luaContext->lua != lua) {
lua_pushliteral(lua, "Function called from invalid context");
lua_error(lua);
}
struct mScriptFrame frame;
mScriptFrameInit(&frame);
if (!_luaPopFrame(luaContext, &frame.arguments)) {
mScriptFrameDeinit(&frame);
lua_pushliteral(lua, "Error calling function (setting arguments)");
lua_error(lua);
}
struct mScriptValue* fn = lua_touserdata(lua, lua_upvalueindex(1));
if (!fn || !mScriptInvoke(fn, &frame)) {
mScriptFrameDeinit(&frame);
lua_pushliteral(lua, "Error calling function (invoking)");
lua_error(lua);
}
if (!_luaPushFrame(luaContext, &frame.returnValues)) {
mScriptFrameDeinit(&frame);
lua_pushliteral(lua, "Error calling function (getting return values)");
lua_error(lua);
}
mScriptFrameDeinit(&frame);
return lua_gettop(luaContext->lua);
}

View File

@ -7,6 +7,17 @@
#include <mgba/internal/script/lua.h> #include <mgba/internal/script/lua.h>
static int identityInt(int in) {
return in;
}
static int addInts(int a, int b) {
return a + b;
}
mSCRIPT_BIND_FUNCTION(boundIdentityInt, S32, identityInt, 1, S32);
mSCRIPT_BIND_FUNCTION(boundAddInts, S32, addInts, 2, S32, S32);
M_TEST_SUITE_SETUP(mScriptLua) { M_TEST_SUITE_SETUP(mScriptLua) {
if (mSCRIPT_ENGINE_LUA->init) { if (mSCRIPT_ENGINE_LUA->init) {
mSCRIPT_ENGINE_LUA->init(mSCRIPT_ENGINE_LUA); mSCRIPT_ENGINE_LUA->init(mSCRIPT_ENGINE_LUA);
@ -214,7 +225,7 @@ M_TEST_DEFINE(callLuaFunc) {
struct VFile* vf; struct VFile* vf;
const char* error; const char* error;
program = "function a(b) return b + 1 end"; program = "function a(b) return b + 1 end; function c(d, e) return d + e end";
vf = VFileFromConstMemory(program, strlen(program)); vf = VFileFromConstMemory(program, strlen(program));
error = NULL; error = NULL;
assert_true(lua->load(lua, vf, &error)); assert_true(lua->load(lua, vf, &error));
@ -236,6 +247,56 @@ M_TEST_DEFINE(callLuaFunc) {
mScriptFrameDeinit(&frame); mScriptFrameDeinit(&frame);
mScriptValueDeref(fn); mScriptValueDeref(fn);
fn = lua->getGlobal(lua, "c");
assert_non_null(fn);
assert_int_equal(fn->type->base, mSCRIPT_TYPE_FUNCTION);
mScriptFrameInit(&frame);
mSCRIPT_PUSH(&frame.arguments, S32, 1);
mSCRIPT_PUSH(&frame.arguments, S32, 2);
assert_true(mScriptInvoke(fn, &frame));
assert_true(mScriptPopS64(&frame.returnValues, &val));
assert_int_equal(val, 3);
mScriptFrameDeinit(&frame);
mScriptValueDeref(fn);
lua->destroy(lua);
mScriptContextDeinit(&context);
}
M_TEST_DEFINE(callCFunc) {
struct mScriptContext context;
mScriptContextInit(&context);
struct mScriptEngineContext* lua = mSCRIPT_ENGINE_LUA->create(mSCRIPT_ENGINE_LUA, &context);
struct mScriptValue a = mSCRIPT_MAKE_S32(1);
struct mScriptValue* val;
const char* program;
struct VFile* vf;
const char* error;
program = "a = b(1); c = d(1, 2)";
vf = VFileFromConstMemory(program, strlen(program));
error = NULL;
assert_true(lua->load(lua, vf, &error));
assert_null(error);
assert_true(lua->setGlobal(lua, "b", &boundIdentityInt));
assert_true(lua->setGlobal(lua, "d", &boundAddInts));
assert_true(lua->run(lua));
val = lua->getGlobal(lua, "a");
assert_non_null(val);
assert_true(a.type->equal(&a, val));
mScriptValueDeref(val);
a = mSCRIPT_MAKE_S32(3);
val = lua->getGlobal(lua, "c");
assert_non_null(val);
assert_true(a.type->equal(&a, val));
mScriptValueDeref(val);
lua->destroy(lua); lua->destroy(lua);
mScriptContextDeinit(&context); mScriptContextDeinit(&context);
} }
@ -247,4 +308,5 @@ M_TEST_SUITE_DEFINE_SETUP_TEARDOWN(mScriptLua,
cmocka_unit_test(runNop), cmocka_unit_test(runNop),
cmocka_unit_test(getGlobal), cmocka_unit_test(getGlobal),
cmocka_unit_test(setGlobal), cmocka_unit_test(setGlobal),
cmocka_unit_test(callLuaFunc)) cmocka_unit_test(callLuaFunc),
cmocka_unit_test(callCFunc))