Scripting: Clean up refcounting

This commit is contained in:
Vicki Pfau 2023-02-08 00:39:02 -08:00
parent 5c0bd1b245
commit 282a033df2
4 changed files with 63 additions and 44 deletions

View File

@ -250,6 +250,7 @@ uint32_t mScriptContextAddCallback(struct mScriptContext* context, const char* c
HashTableIteratorLookup(&context->callbacks, &iter, callback);
info->callback = HashTableIteratorGetKey(&context->callbacks, &iter);
info->id = mScriptListSize(list->value.list);
mScriptValueRef(fn);
mScriptValueWrap(fn, mScriptListAppend(list->value.list));
while (true) {
uint32_t id = context->nextCallbackId;

View File

@ -36,8 +36,11 @@ static const char* _luaGetError(struct mScriptEngineContext*);
static bool _luaCall(struct mScriptFrame*, void* context);
static void _freeFrame(struct mScriptList* frame);
static void _autofreeFrame(struct mScriptContext* context, struct mScriptList* frame);
struct mScriptEngineContextLua;
static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptList*, bool internal);
static bool _luaPushFrame(struct mScriptEngineContextLua*, struct mScriptList*);
static bool _luaPopFrame(struct mScriptEngineContextLua*, struct mScriptList*);
static bool _luaInvoke(struct mScriptEngineContextLua*, struct mScriptFrame*);
@ -520,6 +523,8 @@ struct mScriptValue* _luaRootScope(struct mScriptEngineContext* ctx) {
lua_pop(luaContext->lua, 1);
key = _luaCoerce(luaContext, false);
mScriptValueWrap(key, mScriptListAppend(list->value.list));
mScriptValueRef(key);
mScriptContextFillPool(luaContext->d.context, key);
}
lua_pop(luaContext->lua, 1);
@ -544,7 +549,7 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext,
lua_pushnil(luaContext->lua);
void* tablePointer;
const void* tablePointer;
while (lua_next(luaContext->lua, -2) != 0) {
struct mScriptValue* value = NULL;
int type = lua_type(luaContext->lua, -1);
@ -559,11 +564,8 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext,
tablePointer = lua_topointer(luaContext->lua, -1);
// Ensure this table doesn't contain any cycles
if (!HashTableLookupBinary(markedObjects, &tablePointer, sizeof(tablePointer))) {
HashTableInsertBinary(markedObjects, &tablePointer, sizeof(tablePointer), tablePointer);
HashTableInsertBinary(markedObjects, &tablePointer, sizeof(tablePointer), (void*) tablePointer);
value = _luaCoerceTable(luaContext, markedObjects);
if (value) {
mScriptValueRef(value);
}
}
default:
break;
@ -595,18 +597,13 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext,
return false;
}
mScriptTableInsert(table, key, value);
if (key->type != mSCRIPT_TYPE_MS_STR) {
// Strings are added to the ref pool, so we need to keep it
// ref'd to prevent it from being collected prematurely
mScriptValueDeref(key);
}
mScriptValueDeref(key);
mScriptValueDeref(value);
}
lua_pop(luaContext->lua, 1);
size_t len = mScriptTableSize(table);
if (!isList || !len) {
mScriptContextFillPool(luaContext->d.context, table);
return table;
}
@ -616,18 +613,15 @@ struct mScriptValue* _luaCoerceTable(struct mScriptEngineContextLua* luaContext,
struct mScriptValue* value = mScriptTableLookup(table, &mSCRIPT_MAKE_S64(i));
if (!value) {
mScriptValueDeref(list);
mScriptContextFillPool(luaContext->d.context, table);
return table;
}
mScriptValueWrap(value, mScriptListAppend(list->value.list));
}
if (i != len + 1) {
mScriptValueDeref(list);
mScriptContextFillPool(luaContext->d.context, table);
return table;
}
mScriptValueDeref(table);
mScriptContextFillPool(luaContext->d.context, list);
return list;
}
@ -663,7 +657,6 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool
case LUA_TSTRING:
buffer = lua_tolstring(luaContext->lua, -1, &size);
value = mScriptStringCreateFromBytes(buffer, size);
mScriptContextFillPool(luaContext->d.context, value);
break;
case LUA_TFUNCTION:
// This function pops the value internally via luaL_ref
@ -692,6 +685,12 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool
lua_pop(luaContext->lua, 2);
value = lua_touserdata(luaContext->lua, -1);
value = mScriptContextAccessWeakref(luaContext->d.context, value);
if (value->type->base == mSCRIPT_TYPE_WRAPPER) {
value = mScriptValueUnwrap(value);
}
if (value) {
mScriptValueRef(value);
}
break;
}
if (pop) {
@ -713,6 +712,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
lua_pushnil(luaContext->lua);
return true;
}
mScriptContextFillPool(luaContext->d.context, value);
}
struct mScriptValue derefPtr;
if (value->type->base == mSCRIPT_TYPE_OPAQUE) {
@ -803,6 +803,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
if (needsWeakref) {
*newValue = mSCRIPT_MAKE(WEAKREF, weakref);
} else {
mScriptValueRef(value);
mScriptValueWrap(value, newValue);
}
lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTList");
@ -813,6 +814,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
if (needsWeakref) {
*newValue = mSCRIPT_MAKE(WEAKREF, weakref);
} else {
mScriptValueRef(value);
mScriptValueWrap(value, newValue);
}
lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTTable");
@ -824,7 +826,6 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
newValue->refs = mSCRIPT_VALUE_UNREF;
newValue->type->alloc(newValue);
lua_pushcclosure(luaContext->lua, _luaThunk, 1);
mScriptValueDeref(value);
break;
case mSCRIPT_TYPE_OBJECT:
if (!value->value.opaque) {
@ -835,6 +836,7 @@ bool _luaWrap(struct mScriptEngineContextLua* luaContext, struct mScriptValue* v
if (needsWeakref) {
*newValue = mSCRIPT_MAKE(WEAKREF, weakref);
} else {
mScriptValueRef(value);
mScriptValueWrap(value, newValue);
}
lua_getfield(luaContext->lua, LUA_REGISTRYINDEX, "mSTStruct");
@ -935,16 +937,12 @@ const char* _luaGetError(struct mScriptEngineContext* context) {
return luaContext->lastError;
}
bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList* frame, bool internal) {
bool _luaPushFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList* frame) {
bool ok = true;
if (frame) {
size_t i;
for (i = 0; i < mScriptListSize(frame); ++i) {
struct mScriptValue* value = mScriptListGetPointer(frame, i);
if (internal && value->type->base == mSCRIPT_TYPE_WRAPPER) {
value = mScriptValueUnwrap(value);
mScriptContextFillPool(luaContext->d.context, value);
}
if (!_luaWrap(luaContext, value)) {
ok = false;
break;
@ -968,8 +966,11 @@ bool _luaPopFrame(struct mScriptEngineContextLua* luaContext, struct mScriptList
ok = false;
break;
}
mScriptValueWrap(value, mScriptListAppend(frame));
mScriptValueDeref(value);
struct mScriptValue* tail = mScriptListAppend(frame);
mScriptValueWrap(value, tail);
if (tail->type == value->type) {
mScriptValueDeref(value);
}
}
if (count > i) {
lua_pop(luaContext->lua, count - i);
@ -996,6 +997,26 @@ bool _luaCall(struct mScriptFrame* frame, void* context) {
return true;
}
void _freeFrame(struct mScriptList* frame) {
size_t i;
for (i = 0; i < mScriptListSize(frame); ++i) {
struct mScriptValue* val = mScriptValueUnwrap(mScriptListGetPointer(frame, i));
if (val) {
mScriptValueDeref(val);
}
}
}
void _autofreeFrame(struct mScriptContext* context, struct mScriptList* frame) {
size_t i;
for (i = 0; i < mScriptListSize(frame); ++i) {
struct mScriptValue* val = mScriptValueUnwrap(mScriptListGetPointer(frame, i));
if (val) {
mScriptContextFillPool(context, val);
}
}
}
bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame* frame) {
int nargs = 0;
if (frame) {
@ -1007,7 +1028,7 @@ bool _luaInvoke(struct mScriptEngineContextLua* luaContext, struct mScriptFrame*
luaContext->lastError = NULL;
}
if (frame && !_luaPushFrame(luaContext, &frame->arguments, false)) {
if (frame && !_luaPushFrame(luaContext, &frame->arguments)) {
return false;
}
@ -1072,6 +1093,7 @@ int _luaThunk(lua_State* lua) {
struct mScriptFrame frame;
mScriptFrameInit(&frame);
if (!_luaPopFrame(luaContext, &frame.arguments)) {
_freeFrame(&frame.arguments);
mScriptContextDrainPool(luaContext->d.context);
mScriptFrameDeinit(&frame);
luaL_traceback(lua, lua, "Error calling function (translating arguments into runtime)", 1);
@ -1079,19 +1101,21 @@ int _luaThunk(lua_State* lua) {
}
struct mScriptValue* fn = lua_touserdata(lua, lua_upvalueindex(1));
_autofreeFrame(luaContext->d.context, &frame.arguments);
if (!fn || !mScriptInvoke(fn, &frame)) {
mScriptContextDrainPool(luaContext->d.context);
mScriptFrameDeinit(&frame);
luaL_traceback(lua, lua, "Error calling function (invoking failed)", 1);
return lua_error(lua);
}
if (!_luaPushFrame(luaContext, &frame.returnValues, true)) {
mScriptFrameDeinit(&frame);
bool ok = _luaPushFrame(luaContext, &frame.returnValues);
mScriptContextDrainPool(luaContext->d.context);
mScriptFrameDeinit(&frame);
if (!ok) {
luaL_traceback(lua, lua, "Error calling function (translating return values from runtime)", 1);
return lua_error(lua);
}
mScriptContextDrainPool(luaContext->d.context);
mScriptFrameDeinit(&frame);
return lua_gettop(luaContext->lua);
}
@ -1146,19 +1170,22 @@ int _luaSetObject(lua_State* lua) {
strlcpy(key, keyPtr, sizeof(key));
lua_pop(lua, 2);
obj = mScriptContextAccessWeakref(luaContext->d.context, obj);
if (!obj) {
luaL_traceback(lua, lua, "Invalid object", 1);
return lua_error(lua);
}
if (!val) {
luaL_traceback(lua, lua, "Error translating value to runtime", 1);
return lua_error(lua);
}
obj = mScriptContextAccessWeakref(luaContext->d.context, obj);
if (!obj) {
mScriptValueDeref(val);
mScriptContextDrainPool(luaContext->d.context);
luaL_traceback(lua, lua, "Invalid object", 1);
return lua_error(lua);
}
if (!mScriptObjectSet(obj, key, val)) {
mScriptValueDeref(val);
mScriptContextDrainPool(luaContext->d.context);
char error[MAX_KEY_SIZE + 16];
snprintf(error, sizeof(error), "Invalid key '%s'", key);
luaL_traceback(lua, lua, "Invalid key", 1);

View File

@ -24,7 +24,6 @@ static uint32_t _mScriptCallbackAdd(struct mScriptCallbackManager* adapter, stru
fn = mScriptValueUnwrap(fn);
}
uint32_t id = mScriptContextAddCallback(adapter->context, name->buffer, fn);
mScriptValueDeref(fn);
return id;
}

View File

@ -892,7 +892,6 @@ void mScriptValueWrap(struct mScriptValue* value, struct mScriptValue* out) {
out->type = mSCRIPT_TYPE_MS_WRAPPER;
out->value.opaque = value;
mScriptValueRef(value);
}
struct mScriptValue* mScriptValueUnwrap(struct mScriptValue* value) {
@ -1473,18 +1472,11 @@ bool mScriptObjectSet(struct mScriptValue* obj, const char* member, struct mScri
this->value.opaque = obj->value.opaque;
mSCRIPT_PUSH(&frame.arguments, CHARP, member);
mScriptValueWrap(val, mScriptListAppend(&frame.arguments));
bool needsDeref = mScriptListGetPointer(&frame.arguments, 2)->type->base == mSCRIPT_TYPE_WRAPPER;
if (!mScriptInvoke(&setMember, &frame) || mScriptListSize(&frame.returnValues) != 0) {
mScriptFrameDeinit(&frame);
if (needsDeref) {
mScriptValueDeref(val);
}
return false;
}
mScriptFrameDeinit(&frame);
if (needsDeref) {
mScriptValueDeref(val);
}
return true;
}