diff --git a/include/mgba/script/context.h b/include/mgba/script/context.h index 435277323..0bd550030 100644 --- a/include/mgba/script/context.h +++ b/include/mgba/script/context.h @@ -36,6 +36,7 @@ struct mScriptContext { uint32_t nextCallbackId; struct mScriptValue* constants; struct Table docstrings; + int threadDepth; }; struct mScriptEngine2 { @@ -110,6 +111,12 @@ struct VFile; bool mScriptContextLoadVF(struct mScriptContext*, const char* name, struct VFile* vf); bool mScriptContextLoadFile(struct mScriptContext*, const char* path); +struct mScriptContext* mScriptActiveContext(void); +bool mScriptContextActivate(struct mScriptContext*); +void mScriptContextDeactivate(struct mScriptContext*); + +bool mScriptContextInvoke(struct mScriptContext*, const struct mScriptValue* fn, struct mScriptFrame* frame); + bool mScriptInvoke(const struct mScriptValue* fn, struct mScriptFrame* frame); CXX_GUARD_END diff --git a/src/core/scripting.c b/src/core/scripting.c index ff7f7f255..a90be776a 100644 --- a/src/core/scripting.c +++ b/src/core/scripting.c @@ -761,7 +761,7 @@ static int64_t _addCallbackToBreakpoint(struct mScriptDebugger* debugger, struct return cbid; } -static void _runCallbacks(struct mScriptBreakpoint* point) { +static void _runCallbacks(struct mScriptDebugger* debugger, struct mScriptBreakpoint* point) { struct TableIterator iter; if (!HashTableIteratorStart(&point->callbacks, &iter)) { return; @@ -770,7 +770,7 @@ static void _runCallbacks(struct mScriptBreakpoint* point) { struct mScriptValue* fn = HashTableIteratorGetValue(&point->callbacks, &iter); struct mScriptFrame frame; mScriptFrameInit(&frame); - mScriptInvoke(fn, &frame); + mScriptContextInvoke(debugger->p->context, fn, &frame); mScriptFrameDeinit(&frame); } while (HashTableIteratorNext(&point->callbacks, &iter)); } @@ -812,7 +812,7 @@ static void _scriptDebuggerEntered(struct mDebuggerModule* debugger, enum mDebug default: return; } - _runCallbacks(point); + _runCallbacks(scriptDebugger, point); debugger->isPaused = false; } @@ -1085,7 +1085,7 @@ static bool _callRotationCb(struct mScriptCoreAdapter* adapter, const char* cbNa if (context) { mScriptValueWrap(context, mScriptListAppend(&frame.arguments)); } - bool ok = mScriptInvoke(cb, &frame); + bool ok = mScriptContextInvoke(adapter->context, cb, &frame); if (ok && out && mScriptListSize(&frame.returnValues) == 1) { if (!mScriptCast(mSCRIPT_TYPE_MS_F32, mScriptListGetPointer(&frame.returnValues, 0), out)) { ok = false; @@ -1154,7 +1154,7 @@ static uint8_t _readLuminance(struct GBALuminanceSource* luminance) { if (adapter->luminanceCb) { struct mScriptFrame frame; mScriptFrameInit(&frame); - bool ok = mScriptInvoke(adapter->luminanceCb, &frame); + bool ok = mScriptContextInvoke(adapter->context, adapter->luminanceCb, &frame); struct mScriptValue out = {0}; if (ok && mScriptListSize(&frame.returnValues) == 1) { if (!mScriptCast(mSCRIPT_TYPE_MS_U8, mScriptListGetPointer(&frame.returnValues, 0), &out)) { diff --git a/src/script/context.c b/src/script/context.c index 85c3de485..090daf66d 100644 --- a/src/script/context.c +++ b/src/script/context.c @@ -7,6 +7,27 @@ #ifdef USE_LUA #include #endif +#include + +static ThreadLocal _threadContext; + +#ifdef USE_PTHREADS +static pthread_once_t _contextOnce = PTHREAD_ONCE_INIT; + +static void _createTLS(void) { + ThreadLocalInitKey(&_threadContext); +} +#elif _WIN32 +static INIT_ONCE _contextOnce = INIT_ONCE_STATIC_INIT; + +static BOOL CALLBACK _createTLS(PINIT_ONCE once, PVOID param, PVOID* context) { + UNUSED(once); + UNUSED(param); + UNUSED(context); + ThreadLocalInitKey(&_threadContext); + return TRUE; +} +#endif #define KEY_NAME_MAX 128 @@ -74,6 +95,11 @@ static void _freeTable(void* data) { } void mScriptContextInit(struct mScriptContext* context) { +#ifdef USE_PTHREADS + pthread_once(&_contextOnce, _createTLS); +#elif _WIN32 + InitOnceExecuteOnce(&_contextOnce, _createTLS, NULL, 0); +#endif HashTableInit(&context->rootScope, 0, (void (*)(void*)) mScriptValueDeref); HashTableInit(&context->engines, 0, _engineContextDestroy); mScriptListInit(&context->refPool, 0); @@ -84,6 +110,7 @@ void mScriptContextInit(struct mScriptContext* context) { context->nextCallbackId = 1; context->constants = NULL; HashTableInit(&context->docstrings, 0, NULL); + context->threadDepth = 0; } void mScriptContextDeinit(struct mScriptContext* context) { @@ -241,7 +268,7 @@ void mScriptContextTriggerCallback(struct mScriptContext* context, const char* c if (args) { mScriptListCopy(&frame.arguments, args); } - mScriptInvoke(fn, &frame); + mScriptContextInvoke(context, fn, &frame); mScriptFrameDeinit(&frame); } @@ -407,6 +434,48 @@ bool mScriptContextLoadFile(struct mScriptContext* context, const char* path) { return ret; } +struct mScriptContext* mScriptActiveContext(void) { + return ThreadLocalGetValue(_threadContext); +} + +bool mScriptContextActivate(struct mScriptContext* context) { + struct mScriptContext* threadContext = ThreadLocalGetValue(_threadContext); + if (threadContext && threadContext != context) { + return false; + } + if (!threadContext && context->threadDepth) { + return false; + } + ++context->threadDepth; + if (!threadContext) { + ThreadLocalSetKey(_threadContext, context); + } + return true; +} + +void mScriptContextDeactivate(struct mScriptContext* context) { +#ifndef NDEBUG + struct mScriptContext* threadContext = ThreadLocalGetValue(_threadContext); + if (threadContext != context) { + abort(); + } +#endif + + --context->threadDepth; + if (!context->threadDepth) { + ThreadLocalSetKey(_threadContext, NULL); + } +} + +bool mScriptContextInvoke(struct mScriptContext* context, const struct mScriptValue* fn, struct mScriptFrame* frame) { + if (!mScriptContextActivate(context)) { + return false; + } + bool res = mScriptInvoke(fn, frame); + mScriptContextDeactivate(context); + return res; +} + bool mScriptInvoke(const struct mScriptValue* val, struct mScriptFrame* frame) { if (val->type->base != mSCRIPT_TYPE_FUNCTION) { return false; diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index 9a5278846..5c56c7df3 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -1098,7 +1098,12 @@ bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* luaContext->lastError = NULL; } + if (!mScriptContextActivate(luaContext->d.context)) { + return false; + } + if (frame && !_luaPushFrame(luaContext, &frame->arguments)) { + mScriptContextDeactivate(luaContext->d.context); return false; } @@ -1114,6 +1119,7 @@ bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* luaContext->lastError = strdup(lua_tostring(luaContext->lua, -1)); lua_pop(luaContext->lua, 1); } + mScriptContextDeactivate(luaContext->d.context); if (ret) { return false; } @@ -1172,7 +1178,7 @@ int _luaThunk(lua_State* lua) { struct mScriptValue* fn = lua_touserdata(lua, lua_upvalueindex(1)); _autofreeFrame(luaContext->d.context, &frame.arguments); - if (!fn || !mScriptInvoke(fn, &frame)) { + if (!fn || !mScriptContextInvoke(luaContext->d.context, fn, &frame)) { mScriptContextDrainPool(luaContext->d.context); mScriptFrameDeinit(&frame); luaL_traceback(lua, lua, "Error calling function (invoking failed)", 1);