Scripting: Clean up refcounting

This commit is contained in:
Vicki Pfau 2023-02-08 00:39:02 -08:00
parent 5c0bd1b245
commit 282a033df2
4 changed files with 63 additions and 44 deletions

View File

@ -250,6 +250,7 @@ uint32_t mScriptContextAddCallback(struct mScriptContext* context, const char* c
HashTableIteratorLookup(&context->callbacks, &iter, callback); HashTableIteratorLookup(&context->callbacks, &iter, callback);
info->callback = HashTableIteratorGetKey(&context->callbacks, &iter); info->callback = HashTableIteratorGetKey(&context->callbacks, &iter);
info->id = mScriptListSize(list->value.list); info->id = mScriptListSize(list->value.list);
mScriptValueRef(fn);
mScriptValueWrap(fn, mScriptListAppend(list->value.list)); mScriptValueWrap(fn, mScriptListAppend(list->value.list));
while (true) { while (true) {
uint32_t id = context->nextCallbackId; uint32_t id = context->nextCallbackId;

View File

@ -36,8 +36,11 @@ static const char* _luaGetError(struct mScriptEngineContext*);
static bool _luaCall(struct mScriptFrame*, void* context); static bool _luaCall(struct mScriptFrame*, void* context);
static void _freeFrame(struct mScriptList* frame);
static void _autofreeFrame(struct mScriptContext* context, struct mScriptList* frame);
struct mScriptEngineContextLua; struct mScriptEngineContextLua;
static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptList*, bool internal); static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptList*);
static bool _luaPopFrame(struct mScriptEngineContextLua*, struct mScriptList*); static bool _luaPopFrame(struct mScriptEngineContextLua*, struct mScriptList*);
static bool _luaInvoke(struct mScriptEngineContextLua*, struct mScriptFrame*); static bool _luaInvoke(struct mScriptEngineContextLua*, struct mScriptFrame*);
@ -520,6 +523,8 @@ struct mScriptValue* _luaRootScope(struct mScriptEngineContext* ctx) {
lua_pop(luaContext->lua, 1); lua_pop(luaContext->lua, 1);
key = _luaCoerce(luaContext, false); key = _luaCoerce(luaContext, false);
mScriptValueWrap(key, mScriptListAppend(list->value.list)); mScriptValueWrap(key, mScriptListAppend(list->value.list));
mScriptValueRef(key);
mScriptContextFillPool(luaContext->d.context, key);
} }
lua_pop(luaContext->lua, 1); lua_pop(luaContext->lua, 1);
@ -544,7 +549,7 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext,
lua_pushnil(luaContext->lua); lua_pushnil(luaContext->lua);
void* tablePointer; const void* tablePointer;
while (lua_next(luaContext->lua, -2) != 0) { while (lua_next(luaContext->lua, -2) != 0) {
struct mScriptValue* value = NULL; struct mScriptValue* value = NULL;
int type = lua_type(luaContext->lua, -1); int type = lua_type(luaContext->lua, -1);
@ -559,11 +564,8 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext,
tablePointer = lua_topointer(luaContext->lua, -1); tablePointer = lua_topointer(luaContext->lua, -1);
// Ensure this table doesn't contain any cycles // Ensure this table doesn't contain any cycles
if (!HashTableLookupBinary(markedObjects, &tablePointer, sizeof(tablePointer))) { if (!HashTableLookupBinary(markedObjects, &tablePointer, sizeof(tablePointer))) {
HashTableInsertBinary(markedObjects, &tablePointer, sizeof(tablePointer), tablePointer); HashTableInsertBinary(markedObjects, &tablePointer, sizeof(tablePointer), (void*) tablePointer);
value = _luaCoerceTable(luaContext, markedObjects); value = _luaCoerceTable(luaContext, markedObjects);
if (value) {
mScriptValueRef(value);
}
} }
default: default:
break; break;
@ -595,18 +597,13 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext,
return false; return false;
} }
mScriptTableInsert(table, key, value); mScriptTableInsert(table, key, value);
if (key->type != mSCRIPT_TYPE_MS_STR) { mScriptValueDeref(key);
// Strings are added to the ref pool, so we need to keep it
// ref'd to prevent it from being collected prematurely
mScriptValueDeref(key);
}
mScriptValueDeref(value); mScriptValueDeref(value);
} }
lua_pop(luaContext->lua, 1); lua_pop(luaContext->lua, 1);
size_t len = mScriptTableSize(table); size_t len = mScriptTableSize(table);
if (!isList || !len) { if (!isList || !len) {
mScriptContextFillPool(luaContext->d.context, table);
return table; return table;
} }
@ -616,18 +613,15 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext,
struct mScriptValue* value = mScriptTableLookup(table, &mSCRIPT_MAKE_S64(i)); struct mScriptValue* value = mScriptTableLookup(table, &mSCRIPT_MAKE_S64(i));
if (!value) { if (!value) {
mScriptValueDeref(list); mScriptValueDeref(list);
mScriptContextFillPool(luaContext->d.context, table);
return table; return table;
} }
mScriptValueWrap(value, mScriptListAppend(list->value.list)); mScriptValueWrap(value, mScriptListAppend(list->value.list));
} }
if (i != len + 1) { if (i != len + 1) {
mScriptValueDeref(list); mScriptValueDeref(list);
mScriptContextFillPool(luaContext->d.context, table);
return table; return table;
} }
mScriptValueDeref(table); mScriptValueDeref(table);
mScriptContextFillPool(luaContext->d.context, list);
return list; return list;
} }
@ -663,7 +657,6 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool
case LUA_TSTRING: case LUA_TSTRING:
buffer = lua_tolstring(luaContext->lua, -1, &size); buffer = lua_tolstring(luaContext->lua, -1, &size);
value = mScriptStringCreateFromBytes(buffer, size); value = mScriptStringCreateFromBytes(buffer, size);
mScriptContextFillPool(luaContext->d.context, value);
break; break;
case LUA_TFUNCTION: case LUA_TFUNCTION:
// This function pops the value internally via luaL_ref // This function pops the value internally via luaL_ref
@ -692,6 +685,12 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool
lua_pop(luaContext->lua, 2); lua_pop(luaContext->lua, 2);
value = lua_touserdata(luaContext->lua, -1); value = lua_touserdata(luaContext->lua, -1);
value = mScriptContextAccessWeakref(luaContext->d.context, value); value = mScriptContextAccessWeakref(luaContext->d.context, value);
if (value->type->base == mSCRIPT_TYPE_WRAPPER) {
value = mScriptValueUnwrap(value);
}
if (value) {
mScriptValueRef(value);
}
break; break;
} }
if (pop) { if (pop) {
@ -713,6 +712,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
lua_pushnil(luaContext->lua); lua_pushnil(luaContext->lua);
return true; return true;
} }
mScriptContextFillPool(luaContext->d.context, value);
} }
struct mScriptValue derefPtr; struct mScriptValue derefPtr;
if (value->type->base == mSCRIPT_TYPE_OPAQUE) { if (value->type->base == mSCRIPT_TYPE_OPAQUE) {
@ -803,6 +803,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
if (needsWeakref) { if (needsWeakref) {
*newValue = mSCRIPT_MAKE(WEAKREF, weakref); *newValue = mSCRIPT_MAKE(WEAKREF, weakref);
} else { } else {
mScriptValueRef(value);
mScriptValueWrap(value, newValue); mScriptValueWrap(value, newValue);
} }
lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTList"); lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTList");
@ -813,6 +814,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
if (needsWeakref) { if (needsWeakref) {
*newValue = mSCRIPT_MAKE(WEAKREF, weakref); *newValue = mSCRIPT_MAKE(WEAKREF, weakref);
} else { } else {
mScriptValueRef(value);
mScriptValueWrap(value, newValue); mScriptValueWrap(value, newValue);
} }
lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTTable"); lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTTable");
@ -824,7 +826,6 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
newValue->refs = mSCRIPT_VALUE_UNREF; newValue->refs = mSCRIPT_VALUE_UNREF;
newValue->type->alloc(newValue); newValue->type->alloc(newValue);
lua_pushcclosure(luaContext->lua, _luaThunk, 1); lua_pushcclosure(luaContext->lua, _luaThunk, 1);
mScriptValueDeref(value);
break; break;
case mSCRIPT_TYPE_OBJECT: case mSCRIPT_TYPE_OBJECT:
if (!value->value.opaque) { if (!value->value.opaque) {
@ -835,6 +836,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
if (needsWeakref) { if (needsWeakref) {
*newValue = mSCRIPT_MAKE(WEAKREF, weakref); *newValue = mSCRIPT_MAKE(WEAKREF, weakref);
} else { } else {
mScriptValueRef(value);
mScriptValueWrap(value, newValue); mScriptValueWrap(value, newValue);
} }
lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTStruct"); lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTStruct");
@ -935,16 +937,12 @@ const char* _luaGetError(struct mScriptEngineContext* context) {
return luaContext->lastError; return luaContext->lastError;
} }
bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList* frame, bool internal) { bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList* frame) {
bool ok = true; bool ok = true;
if (frame) { if (frame) {
size_t i; size_t i;
for (i = 0; i < mScriptListSize(frame); ++i) { for (i = 0; i < mScriptListSize(frame); ++i) {
struct mScriptValue* value = mScriptListGetPointer(frame, i); struct mScriptValue* value = mScriptListGetPointer(frame, i);
if (internal && value->type->base == mSCRIPT_TYPE_WRAPPER) {
value = mScriptValueUnwrap(value);
mScriptContextFillPool(luaContext->d.context, value);
}
if (!_luaWrap(luaContext, value)) { if (!_luaWrap(luaContext, value)) {
ok = false; ok = false;
break; break;
@ -968,8 +966,11 @@ bool _luaPopFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList
ok = false; ok = false;
break; break;
} }
mScriptValueWrap(value, mScriptListAppend(frame)); struct mScriptValue* tail = mScriptListAppend(frame);
mScriptValueDeref(value); mScriptValueWrap(value, tail);
if (tail->type == value->type) {
mScriptValueDeref(value);
}
} }
if (count > i) { if (count > i) {
lua_pop(luaContext->lua, count - i); lua_pop(luaContext->lua, count - i);
@ -996,6 +997,26 @@ bool _luaCall(struct mScriptFrame* frame, void* context) {
return true; return true;
} }
void _freeFrame(struct mScriptList* frame) {
size_t i;
for (i = 0; i < mScriptListSize(frame); ++i) {
struct mScriptValue* val = mScriptValueUnwrap(mScriptListGetPointer(frame, i));
if (val) {
mScriptValueDeref(val);
}
}
}
void _autofreeFrame(struct mScriptContext* context, struct mScriptList* frame) {
size_t i;
for (i = 0; i < mScriptListSize(frame); ++i) {
struct mScriptValue* val = mScriptValueUnwrap(mScriptListGetPointer(frame, i));
if (val) {
mScriptContextFillPool(context, val);
}
}
}
bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) { bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) {
int nargs = 0; int nargs = 0;
if (frame) { if (frame) {
@ -1007,7 +1028,7 @@ bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame*
luaContext->lastError = NULL; luaContext->lastError = NULL;
} }
if (frame && !_luaPushFrame(luaContext, &frame->arguments, false)) { if (frame && !_luaPushFrame(luaContext, &frame->arguments)) {
return false; return false;
} }
@ -1072,6 +1093,7 @@ int _luaThunk(lua_State* lua) {
struct mScriptFrame frame; struct mScriptFrame frame;
mScriptFrameInit(&frame); mScriptFrameInit(&frame);
if (!_luaPopFrame(luaContext, &frame.arguments)) { if (!_luaPopFrame(luaContext, &frame.arguments)) {
_freeFrame(&frame.arguments);
mScriptContextDrainPool(luaContext->d.context); mScriptContextDrainPool(luaContext->d.context);
mScriptFrameDeinit(&frame); mScriptFrameDeinit(&frame);
luaL_traceback(lua, lua, "Error calling function (translating arguments into runtime)", 1); luaL_traceback(lua, lua, "Error calling function (translating arguments into runtime)", 1);
@ -1079,19 +1101,21 @@ int _luaThunk(lua_State* lua) {
} }
struct mScriptValue* fn = lua_touserdata(lua, lua_upvalueindex(1)); struct mScriptValue* fn = lua_touserdata(lua, lua_upvalueindex(1));
_autofreeFrame(luaContext->d.context, &frame.arguments);
if (!fn || !mScriptInvoke(fn, &frame)) { if (!fn || !mScriptInvoke(fn, &frame)) {
mScriptContextDrainPool(luaContext->d.context);
mScriptFrameDeinit(&frame); mScriptFrameDeinit(&frame);
luaL_traceback(lua, lua, "Error calling function (invoking failed)", 1); luaL_traceback(lua, lua, "Error calling function (invoking failed)", 1);
return lua_error(lua); return lua_error(lua);
} }
if (!_luaPushFrame(luaContext, &frame.returnValues, true)) { bool ok = _luaPushFrame(luaContext, &frame.returnValues);
mScriptFrameDeinit(&frame); mScriptContextDrainPool(luaContext->d.context);
mScriptFrameDeinit(&frame);
if (!ok) {
luaL_traceback(lua, lua, "Error calling function (translating return values from runtime)", 1); luaL_traceback(lua, lua, "Error calling function (translating return values from runtime)", 1);
return lua_error(lua); return lua_error(lua);
} }
mScriptContextDrainPool(luaContext->d.context);
mScriptFrameDeinit(&frame);
return lua_gettop(luaContext->lua); return lua_gettop(luaContext->lua);
} }
@ -1146,19 +1170,22 @@ int _luaSetObject(lua_State* lua) {
strlcpy(key, keyPtr, sizeof(key)); strlcpy(key, keyPtr, sizeof(key));
lua_pop(lua, 2); lua_pop(lua, 2);
obj = mScriptContextAccessWeakref(luaContext->d.context, obj);
if (!obj) {
luaL_traceback(lua, lua, "Invalid object", 1);
return lua_error(lua);
}
if (!val) { if (!val) {
luaL_traceback(lua, lua, "Error translating value to runtime", 1); luaL_traceback(lua, lua, "Error translating value to runtime", 1);
return lua_error(lua); return lua_error(lua);
} }
obj = mScriptContextAccessWeakref(luaContext->d.context, obj);
if (!obj) {
mScriptValueDeref(val);
mScriptContextDrainPool(luaContext->d.context);
luaL_traceback(lua, lua, "Invalid object", 1);
return lua_error(lua);
}
if (!mScriptObjectSet(obj, key, val)) { if (!mScriptObjectSet(obj, key, val)) {
mScriptValueDeref(val); mScriptValueDeref(val);
mScriptContextDrainPool(luaContext->d.context);
char error[MAX_KEY_SIZE + 16]; char error[MAX_KEY_SIZE + 16];
snprintf(error, sizeof(error), "Invalid key '%s'", key); snprintf(error, sizeof(error), "Invalid key '%s'", key);
luaL_traceback(lua, lua, "Invalid key", 1); luaL_traceback(lua, lua, "Invalid key", 1);

View File

@ -24,7 +24,6 @@ static uint32_t _mScriptCallbackAdd(struct mScriptCallbackManager* adapter, stru
fn = mScriptValueUnwrap(fn); fn = mScriptValueUnwrap(fn);
} }
uint32_t id = mScriptContextAddCallback(adapter->context, name->buffer, fn); uint32_t id = mScriptContextAddCallback(adapter->context, name->buffer, fn);
mScriptValueDeref(fn);
return id; return id;
} }

View File

@ -892,7 +892,6 @@ void mScriptValueWrap(struct mScriptValue* value, struct mScriptValue* out) {
out->type = mSCRIPT_TYPE_MS_WRAPPER; out->type = mSCRIPT_TYPE_MS_WRAPPER;
out->value.opaque = value; out->value.opaque = value;
mScriptValueRef(value);
} }
struct mScriptValue* mScriptValueUnwrap(struct mScriptValue* value) { struct mScriptValue* mScriptValueUnwrap(struct mScriptValue* value) {
@ -1473,18 +1472,11 @@ bool mScriptObjectSet(struct mScriptValue* obj, const char* member, struct mScri
this->value.opaque = obj->value.opaque; this->value.opaque = obj->value.opaque;
mSCRIPT_PUSH(&frame.arguments, CHARP, member); mSCRIPT_PUSH(&frame.arguments, CHARP, member);
mScriptValueWrap(val, mScriptListAppend(&frame.arguments)); mScriptValueWrap(val, mScriptListAppend(&frame.arguments));
bool needsDeref = mScriptListGetPointer(&frame.arguments, 2)->type->base == mSCRIPT_TYPE_WRAPPER;
if (!mScriptInvoke(&setMember, &frame) || mScriptListSize(&frame.returnValues) != 0) { if (!mScriptInvoke(&setMember, &frame) || mScriptListSize(&frame.returnValues) != 0) {
mScriptFrameDeinit(&frame); mScriptFrameDeinit(&frame);
if (needsDeref) {
mScriptValueDeref(val);
}
return false; return false;
} }
mScriptFrameDeinit(&frame); mScriptFrameDeinit(&frame);
if (needsDeref) {
mScriptValueDeref(val);
}
return true; return true;
} }