Scripting: Lua requires should be relative to the file they were run from

This commit is contained in:
Vicki Pfau 2022-06-08 21:44:01 -07:00
parent b42b997f00
commit 3e4f1fcb2e
1 changed files with 94 additions and 7 deletions

View File

@ -49,6 +49,8 @@ static int _luaPairsTable(lua_State* lua);
static int _luaGetList(lua_State* lua);
static int _luaLenList(lua_State* lua);
static int _luaRequireShim(lua_State* lua);
#if LUA_VERSION_NUM < 503
#define lua_pushinteger lua_pushnumber
#endif
@ -78,6 +80,8 @@ struct mScriptEngineContextLua {
struct mScriptEngineContext d;
lua_State* lua;
int func;
char lastDirectory[PATH_MAX];
int require;
char* lastError;
};
@ -163,6 +167,9 @@ struct mScriptEngineContext* _luaCreate(struct mScriptEngine2* engine, struct mS
#endif
lua_pop(luaContext->lua, 1);
lua_getglobal(luaContext->lua, "require");
luaContext->require = luaL_ref(luaContext->lua, LUA_REGISTRYINDEX);
return &luaContext->d;
}
@ -330,7 +337,7 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool
}
return _luaCoerceFunction(luaContext);
case LUA_TTABLE:
// This function pops the value internally via luaL_ref
// This function pops the value internally
if (!pop) {
break;
}
@ -481,12 +488,11 @@ bool _luaLoad(struct mScriptEngineContext* ctx, const char* filename, struct VFi
luaContext->lastError = NULL;
}
char name[PATH_MAX + 1];
char dirname[PATH_MAX] = {0};
if (filename) {
if (*filename == '*') {
snprintf(name, sizeof(name), "=%s", filename + 1);
} else {
#ifdef _WIN32
wchar_t dirname[PATH_MAX] = {0};
const char* lastSlash = strrchr(filename, '/');
const char* lastBackslash = strrchr(filename, '\\');
if (lastSlash && lastBackslash) {
@ -497,10 +503,8 @@ bool _luaLoad(struct mScriptEngineContext* ctx, const char* filename, struct VFi
lastSlash = lastBackslash;
}
if (lastSlash) {
MultiByteToWideChar(CP_UTF8, 0, filename, lastSlash - filename, dirname, PATH_MAX);
AddDllDirectory(dirname);
strncpy(dirname, filename, lastSlash - filename);
}
#endif
snprintf(name, sizeof(name), "@%s", filename);
}
filename = name;
@ -509,6 +513,11 @@ bool _luaLoad(struct mScriptEngineContext* ctx, const char* filename, struct VFi
switch (ret) {
case LUA_OK:
luaContext->func = luaL_ref(luaContext->lua, LUA_REGISTRYINDEX);
if (dirname[0]) {
strncpy(luaContext->lastDirectory, dirname, sizeof(luaContext->lastDirectory));
} else {
memset(luaContext->lastDirectory, 0, sizeof(luaContext->lastDirectory));
}
return true;
case LUA_ERRSYNTAX:
luaContext->lastError = strdup(lua_tostring(luaContext->lua, -1));
@ -522,8 +531,21 @@ bool _luaLoad(struct mScriptEngineContext* ctx, const char* filename, struct VFi
bool _luaRun(struct mScriptEngineContext* context) {
struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) context;
if (luaContext->lastDirectory[0]) {
// Shim require to look in the previous location
lua_pushstring(luaContext->lua, luaContext->lastDirectory);
lua_pushcclosure(luaContext->lua, _luaRequireShim, 1);
lua_setglobal(luaContext->lua, "require");
}
lua_rawgeti(luaContext->lua, LUA_REGISTRYINDEX, luaContext->func);
return _luaInvoke(luaContext, NULL);
bool ret = _luaInvoke(luaContext, NULL);
// Restore previous value of require
lua_rawgeti(luaContext->lua, LUA_REGISTRYINDEX, luaContext->require);
lua_setglobal(luaContext->lua, "require");
return ret;
}
const char* _luaGetError(struct mScriptEngineContext* context) {
@ -962,3 +984,68 @@ static int _luaLenList(lua_State* lua) {
lua_pushinteger(lua, mScriptListSize(list));
return 1;
}
static int _luaRequireShim(lua_State* lua) {
struct mScriptEngineContextLua* luaContext = _luaGetContext(lua);
const char* path = lua_tostring(lua, lua_upvalueindex(1));
const char* oldpath;
const char* oldcpath;
lua_getglobal(luaContext->lua, "package");
lua_pushliteral(luaContext->lua, "path");
lua_pushliteral(luaContext->lua, "path");
lua_gettable(luaContext->lua, -3);
oldpath = strdup(lua_tostring(luaContext->lua, -1));
lua_pushliteral(luaContext->lua, ";");
lua_pushstring(luaContext->lua, path);
lua_pushliteral(luaContext->lua, "/?.lua;");
lua_pushstring(luaContext->lua, path);
lua_pushliteral(luaContext->lua, "/?/init.lua");
lua_concat(luaContext->lua, 6);
lua_settable(luaContext->lua, -3);
#ifdef _WIN32
#define DLL "dll"
#elif defined(__APPLE__)
#define DLL "dylib"
#else
#define DLL "so"
#endif
lua_pushliteral(luaContext->lua, "cpath");
lua_pushliteral(luaContext->lua, "cpath");
lua_gettable(luaContext->lua, -3);
oldcpath = strdup(lua_tostring(luaContext->lua, -1));
lua_pushliteral(luaContext->lua, ";");
lua_pushstring(luaContext->lua, path);
lua_pushliteral(luaContext->lua, "/?." DLL ";");
lua_pushstring(luaContext->lua, path);
lua_pushliteral(luaContext->lua, "/?/init." DLL);
lua_concat(luaContext->lua, 6);
lua_settable(luaContext->lua, -3);
lua_pop(luaContext->lua, 1);
lua_rawgeti(luaContext->lua, LUA_REGISTRYINDEX, luaContext->require);
lua_rotate(luaContext->lua, -2, 1);
int ret = lua_pcall(luaContext->lua, 1, 0, 0);
lua_getglobal(luaContext->lua, "package");
lua_pushliteral(luaContext->lua, "path");
lua_pushstring(luaContext->lua, oldpath);
lua_settable(luaContext->lua, -3);
lua_pushliteral(luaContext->lua, "cpath");
lua_pushstring(luaContext->lua, oldcpath);
lua_settable(luaContext->lua, -3);
lua_pop(luaContext->lua, 1);
if (ret) {
lua_error(luaContext->lua);
}
return 0;
}