From 282a033df265596335800f34a2e9e5111ff5ad00 Mon Sep 17 00:00:00 2001 From: Vicki Pfau Date: Wed, 8 Feb 2023 00:39:02 -0800 Subject: [PATCH] Scripting: Clean up refcounting --- src/script/context.c | 1 + src/script/engines/lua.c | 97 +++++++++++++++++++++++++--------------- src/script/stdlib.c | 1 - src/script/types.c | 8 ---- 4 files changed, 63 insertions(+), 44 deletions(-) diff --git a/src/script/context.c b/src/script/context.c index bf5fcfa54..418050e74 100644 --- a/src/script/context.c +++ b/src/script/context.c @@ -250,6 +250,7 @@ uint32_t mScriptContextAddCallback(struct mScriptContext* context, const char* c HashTableIteratorLookup(&context->callbacks, &iter, callback); info->callback = HashTableIteratorGetKey(&context->callbacks, &iter); info->id = mScriptListSize(list->value.list); + mScriptValueRef(fn); mScriptValueWrap(fn, mScriptListAppend(list->value.list)); while (true) { uint32_t id = context->nextCallbackId; diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index 1e09bcde1..9d652eb4f 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -36,8 +36,11 @@ static const char* _luaGetError(struct mScriptEngineContext*); static bool _luaCall(struct mScriptFrame*, void* context); +static void _freeFrame(struct mScriptList* frame); +static void _autofreeFrame(struct mScriptContext* context, struct mScriptList* frame); + struct mScriptEngineContextLua; -static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptList*, bool internal); +static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptList*); static bool _luaPopFrame(struct mScriptEngineContextLua*, struct mScriptList*); static bool _luaInvoke(struct mScriptEngineContextLua*, struct mScriptFrame*); @@ -520,6 +523,8 @@ struct mScriptValue* _luaRootScope(struct mScriptEngineContext* ctx) { lua_pop(luaContext->lua, 1); key = _luaCoerce(luaContext, false); mScriptValueWrap(key, mScriptListAppend(list->value.list)); + mScriptValueRef(key); + mScriptContextFillPool(luaContext->d.context, key); } lua_pop(luaContext->lua, 1); @@ -544,7 +549,7 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext, lua_pushnil(luaContext->lua); - void* tablePointer; + const void* tablePointer; while (lua_next(luaContext->lua, -2) != 0) { struct mScriptValue* value = NULL; int type = lua_type(luaContext->lua, -1); @@ -559,11 +564,8 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext, tablePointer = lua_topointer(luaContext->lua, -1); // Ensure this table doesn't contain any cycles if (!HashTableLookupBinary(markedObjects, &tablePointer, sizeof(tablePointer))) { - HashTableInsertBinary(markedObjects, &tablePointer, sizeof(tablePointer), tablePointer); + HashTableInsertBinary(markedObjects, &tablePointer, sizeof(tablePointer), (void*) tablePointer); value = _luaCoerceTable(luaContext, markedObjects); - if (value) { - mScriptValueRef(value); - } } default: break; @@ -595,18 +597,13 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext, return false; } mScriptTableInsert(table, key, value); - if (key->type != mSCRIPT_TYPE_MS_STR) { - // Strings are added to the ref pool, so we need to keep it - // ref'd to prevent it from being collected prematurely - mScriptValueDeref(key); - } + mScriptValueDeref(key); mScriptValueDeref(value); } lua_pop(luaContext->lua, 1); size_t len = mScriptTableSize(table); if (!isList || !len) { - mScriptContextFillPool(luaContext->d.context, table); return table; } @@ -616,18 +613,15 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext, struct mScriptValue* value = mScriptTableLookup(table, &mSCRIPT_MAKE_S64(i)); if (!value) { mScriptValueDeref(list); - mScriptContextFillPool(luaContext->d.context, table); return table; } mScriptValueWrap(value, mScriptListAppend(list->value.list)); } if (i != len + 1) { mScriptValueDeref(list); - mScriptContextFillPool(luaContext->d.context, table); return table; } mScriptValueDeref(table); - mScriptContextFillPool(luaContext->d.context, list); return list; } @@ -663,7 +657,6 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool case LUA_TSTRING: buffer = lua_tolstring(luaContext->lua, -1, &size); value = mScriptStringCreateFromBytes(buffer, size); - mScriptContextFillPool(luaContext->d.context, value); break; case LUA_TFUNCTION: // This function pops the value internally via luaL_ref @@ -692,6 +685,12 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool lua_pop(luaContext->lua, 2); value = lua_touserdata(luaContext->lua, -1); value = mScriptContextAccessWeakref(luaContext->d.context, value); + if (value->type->base == mSCRIPT_TYPE_WRAPPER) { + value = mScriptValueUnwrap(value); + } + if (value) { + mScriptValueRef(value); + } break; } if (pop) { @@ -713,6 +712,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v lua_pushnil(luaContext->lua); return true; } + mScriptContextFillPool(luaContext->d.context, value); } struct mScriptValue derefPtr; if (value->type->base == mSCRIPT_TYPE_OPAQUE) { @@ -803,6 +803,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v if (needsWeakref) { *newValue = mSCRIPT_MAKE(WEAKREF, weakref); } else { + mScriptValueRef(value); mScriptValueWrap(value, newValue); } lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTList"); @@ -813,6 +814,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v if (needsWeakref) { *newValue = mSCRIPT_MAKE(WEAKREF, weakref); } else { + mScriptValueRef(value); mScriptValueWrap(value, newValue); } lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTTable"); @@ -824,7 +826,6 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v newValue->refs = mSCRIPT_VALUE_UNREF; newValue->type->alloc(newValue); lua_pushcclosure(luaContext->lua, _luaThunk, 1); - mScriptValueDeref(value); break; case mSCRIPT_TYPE_OBJECT: if (!value->value.opaque) { @@ -835,6 +836,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v if (needsWeakref) { *newValue = mSCRIPT_MAKE(WEAKREF, weakref); } else { + mScriptValueRef(value); mScriptValueWrap(value, newValue); } lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTStruct"); @@ -935,16 +937,12 @@ const char* _luaGetError(struct mScriptEngineContext* context) { return luaContext->lastError; } -bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList* frame, bool internal) { +bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList* frame) { bool ok = true; if (frame) { size_t i; for (i = 0; i < mScriptListSize(frame); ++i) { struct mScriptValue* value = mScriptListGetPointer(frame, i); - if (internal && value->type->base == mSCRIPT_TYPE_WRAPPER) { - value = mScriptValueUnwrap(value); - mScriptContextFillPool(luaContext->d.context, value); - } if (!_luaWrap(luaContext, value)) { ok = false; break; @@ -968,8 +966,11 @@ bool _luaPopFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList ok = false; break; } - mScriptValueWrap(value, mScriptListAppend(frame)); - mScriptValueDeref(value); + struct mScriptValue* tail = mScriptListAppend(frame); + mScriptValueWrap(value, tail); + if (tail->type == value->type) { + mScriptValueDeref(value); + } } if (count > i) { lua_pop(luaContext->lua, count - i); @@ -996,6 +997,26 @@ bool _luaCall(struct mScriptFrame* frame, void* context) { return true; } +void _freeFrame(struct mScriptList* frame) { + size_t i; + for (i = 0; i < mScriptListSize(frame); ++i) { + struct mScriptValue* val = mScriptValueUnwrap(mScriptListGetPointer(frame, i)); + if (val) { + mScriptValueDeref(val); + } + } +} + +void _autofreeFrame(struct mScriptContext* context, struct mScriptList* frame) { + size_t i; + for (i = 0; i < mScriptListSize(frame); ++i) { + struct mScriptValue* val = mScriptValueUnwrap(mScriptListGetPointer(frame, i)); + if (val) { + mScriptContextFillPool(context, val); + } + } +} + bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) { int nargs = 0; if (frame) { @@ -1007,7 +1028,7 @@ bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* luaContext->lastError = NULL; } - if (frame && !_luaPushFrame(luaContext, &frame->arguments, false)) { + if (frame && !_luaPushFrame(luaContext, &frame->arguments)) { return false; } @@ -1072,6 +1093,7 @@ int _luaThunk(lua_State* lua) { struct mScriptFrame frame; mScriptFrameInit(&frame); if (!_luaPopFrame(luaContext, &frame.arguments)) { + _freeFrame(&frame.arguments); mScriptContextDrainPool(luaContext->d.context); mScriptFrameDeinit(&frame); luaL_traceback(lua, lua, "Error calling function (translating arguments into runtime)", 1); @@ -1079,19 +1101,21 @@ int _luaThunk(lua_State* lua) { } struct mScriptValue* fn = lua_touserdata(lua, lua_upvalueindex(1)); + _autofreeFrame(luaContext->d.context, &frame.arguments); if (!fn || !mScriptInvoke(fn, &frame)) { + mScriptContextDrainPool(luaContext->d.context); mScriptFrameDeinit(&frame); luaL_traceback(lua, lua, "Error calling function (invoking failed)", 1); return lua_error(lua); } - if (!_luaPushFrame(luaContext, &frame.returnValues, true)) { - mScriptFrameDeinit(&frame); + bool ok = _luaPushFrame(luaContext, &frame.returnValues); + mScriptContextDrainPool(luaContext->d.context); + mScriptFrameDeinit(&frame); + if (!ok) { luaL_traceback(lua, lua, "Error calling function (translating return values from runtime)", 1); return lua_error(lua); } - mScriptContextDrainPool(luaContext->d.context); - mScriptFrameDeinit(&frame); return lua_gettop(luaContext->lua); } @@ -1146,19 +1170,22 @@ int _luaSetObject(lua_State* lua) { strlcpy(key, keyPtr, sizeof(key)); lua_pop(lua, 2); - obj = mScriptContextAccessWeakref(luaContext->d.context, obj); - if (!obj) { - luaL_traceback(lua, lua, "Invalid object", 1); - return lua_error(lua); - } - if (!val) { luaL_traceback(lua, lua, "Error translating value to runtime", 1); return lua_error(lua); } + obj = mScriptContextAccessWeakref(luaContext->d.context, obj); + if (!obj) { + mScriptValueDeref(val); + mScriptContextDrainPool(luaContext->d.context); + luaL_traceback(lua, lua, "Invalid object", 1); + return lua_error(lua); + } + if (!mScriptObjectSet(obj, key, val)) { mScriptValueDeref(val); + mScriptContextDrainPool(luaContext->d.context); char error[MAX_KEY_SIZE + 16]; snprintf(error, sizeof(error), "Invalid key '%s'", key); luaL_traceback(lua, lua, "Invalid key", 1); diff --git a/src/script/stdlib.c b/src/script/stdlib.c index 839673cdb..3163a0d98 100644 --- a/src/script/stdlib.c +++ b/src/script/stdlib.c @@ -24,7 +24,6 @@ static uint32_t _mScriptCallbackAdd(struct mScriptCallbackManager* adapter, stru fn = mScriptValueUnwrap(fn); } uint32_t id = mScriptContextAddCallback(adapter->context, name->buffer, fn); - mScriptValueDeref(fn); return id; } diff --git a/src/script/types.c b/src/script/types.c index a05f6d252..dd50f7968 100644 --- a/src/script/types.c +++ b/src/script/types.c @@ -892,7 +892,6 @@ void mScriptValueWrap(struct mScriptValue* value, struct mScriptValue* out) { out->type = mSCRIPT_TYPE_MS_WRAPPER; out->value.opaque = value; - mScriptValueRef(value); } struct mScriptValue* mScriptValueUnwrap(struct mScriptValue* value) { @@ -1473,18 +1472,11 @@ bool mScriptObjectSet(struct mScriptValue* obj, const char* member, struct mScri this->value.opaque = obj->value.opaque; mSCRIPT_PUSH(&frame.arguments, CHARP, member); mScriptValueWrap(val, mScriptListAppend(&frame.arguments)); - bool needsDeref = mScriptListGetPointer(&frame.arguments, 2)->type->base == mSCRIPT_TYPE_WRAPPER; if (!mScriptInvoke(&setMember, &frame) || mScriptListSize(&frame.returnValues) != 0) { mScriptFrameDeinit(&frame); - if (needsDeref) { - mScriptValueDeref(val); - } return false; } mScriptFrameDeinit(&frame); - if (needsDeref) { - mScriptValueDeref(val); - } return true; }