diff --git a/src/script/engines/lua.c b/src/script/engines/lua.c index 096532b8d..343b07cb2 100644 --- a/src/script/engines/lua.c +++ b/src/script/engines/lua.c @@ -49,6 +49,8 @@ static int _luaPairsTable(lua_State* lua); static int _luaGetList(lua_State* lua); static int _luaLenList(lua_State* lua); +static int _luaRequireShim(lua_State* lua); + #if LUA_VERSION_NUM < 503 #define lua_pushinteger lua_pushnumber #endif @@ -78,6 +80,8 @@ struct mScriptEngineContextLua { struct mScriptEngineContext d; lua_State* lua; int func; + char lastDirectory[PATH_MAX]; + int require; char* lastError; }; @@ -163,6 +167,9 @@ struct mScriptEngineContext* _luaCreate(struct mScriptEngine2* engine, struct mS #endif lua_pop(luaContext->lua, 1); + lua_getglobal(luaContext->lua, "require"); + luaContext->require = luaL_ref(luaContext->lua, LUA_REGISTRYINDEX); + return &luaContext->d; } @@ -330,7 +337,7 @@ struct mScriptValue* _luaCoerce(struct mScriptEngineContextLua* luaContext, bool } return _luaCoerceFunction(luaContext); case LUA_TTABLE: - // This function pops the value internally via luaL_ref + // This function pops the value internally if (!pop) { break; } @@ -481,12 +488,11 @@ bool _luaLoad(struct mScriptEngineContext* ctx, const char* filename, struct VFi luaContext->lastError = NULL; } char name[PATH_MAX + 1]; + char dirname[PATH_MAX] = {0}; if (filename) { if (*filename == '*') { snprintf(name, sizeof(name), "=%s", filename + 1); } else { -#ifdef _WIN32 - wchar_t dirname[PATH_MAX] = {0}; const char* lastSlash = strrchr(filename, '/'); const char* lastBackslash = strrchr(filename, '\\'); if (lastSlash && lastBackslash) { @@ -497,10 +503,8 @@ bool _luaLoad(struct mScriptEngineContext* ctx, const char* filename, struct VFi lastSlash = lastBackslash; } if (lastSlash) { - MultiByteToWideChar(CP_UTF8, 0, filename, lastSlash - filename, dirname, PATH_MAX); - AddDllDirectory(dirname); + strncpy(dirname, filename, lastSlash - filename); } -#endif snprintf(name, sizeof(name), "@%s", filename); } filename = name; @@ -509,6 +513,11 @@ bool _luaLoad(struct mScriptEngineContext* ctx, const char* filename, struct VFi switch (ret) { case LUA_OK: luaContext->func = luaL_ref(luaContext->lua, LUA_REGISTRYINDEX); + if (dirname[0]) { + strncpy(luaContext->lastDirectory, dirname, sizeof(luaContext->lastDirectory)); + } else { + memset(luaContext->lastDirectory, 0, sizeof(luaContext->lastDirectory)); + } return true; case LUA_ERRSYNTAX: luaContext->lastError = strdup(lua_tostring(luaContext->lua, -1)); @@ -522,8 +531,21 @@ bool _luaLoad(struct mScriptEngineContext* ctx, const char* filename, struct VFi bool _luaRun(struct mScriptEngineContext* context) { struct mScriptEngineContextLua* luaContext = (struct mScriptEngineContextLua*) context; + + if (luaContext->lastDirectory[0]) { + // Shim require to look in the previous location + lua_pushstring(luaContext->lua, luaContext->lastDirectory); + lua_pushcclosure(luaContext->lua, _luaRequireShim, 1); + lua_setglobal(luaContext->lua, "require"); + } + lua_rawgeti(luaContext->lua, LUA_REGISTRYINDEX, luaContext->func); - return _luaInvoke(luaContext, NULL); + bool ret = _luaInvoke(luaContext, NULL); + + // Restore previous value of require + lua_rawgeti(luaContext->lua, LUA_REGISTRYINDEX, luaContext->require); + lua_setglobal(luaContext->lua, "require"); + return ret; } const char* _luaGetError(struct mScriptEngineContext* context) { @@ -962,3 +984,68 @@ static int _luaLenList(lua_State* lua) { lua_pushinteger(lua, mScriptListSize(list)); return 1; } + +static int _luaRequireShim(lua_State* lua) { + struct mScriptEngineContextLua* luaContext = _luaGetContext(lua); + + const char* path = lua_tostring(lua, lua_upvalueindex(1)); + const char* oldpath; + const char* oldcpath; + + lua_getglobal(luaContext->lua, "package"); + + lua_pushliteral(luaContext->lua, "path"); + lua_pushliteral(luaContext->lua, "path"); + lua_gettable(luaContext->lua, -3); + oldpath = strdup(lua_tostring(luaContext->lua, -1)); + lua_pushliteral(luaContext->lua, ";"); + lua_pushstring(luaContext->lua, path); + lua_pushliteral(luaContext->lua, "/?.lua;"); + lua_pushstring(luaContext->lua, path); + lua_pushliteral(luaContext->lua, "/?/init.lua"); + lua_concat(luaContext->lua, 6); + lua_settable(luaContext->lua, -3); + +#ifdef _WIN32 +#define DLL "dll" +#elif defined(__APPLE__) +#define DLL "dylib" +#else +#define DLL "so" +#endif + lua_pushliteral(luaContext->lua, "cpath"); + lua_pushliteral(luaContext->lua, "cpath"); + lua_gettable(luaContext->lua, -3); + oldcpath = strdup(lua_tostring(luaContext->lua, -1)); + lua_pushliteral(luaContext->lua, ";"); + lua_pushstring(luaContext->lua, path); + lua_pushliteral(luaContext->lua, "/?." DLL ";"); + lua_pushstring(luaContext->lua, path); + lua_pushliteral(luaContext->lua, "/?/init." DLL); + lua_concat(luaContext->lua, 6); + lua_settable(luaContext->lua, -3); + + lua_pop(luaContext->lua, 1); + + lua_rawgeti(luaContext->lua, LUA_REGISTRYINDEX, luaContext->require); + lua_rotate(luaContext->lua, -2, 1); + int ret = lua_pcall(luaContext->lua, 1, 0, 0); + + lua_getglobal(luaContext->lua, "package"); + + lua_pushliteral(luaContext->lua, "path"); + lua_pushstring(luaContext->lua, oldpath); + lua_settable(luaContext->lua, -3); + + lua_pushliteral(luaContext->lua, "cpath"); + lua_pushstring(luaContext->lua, oldcpath); + lua_settable(luaContext->lua, -3); + + lua_pop(luaContext->lua, 1); + + if (ret) { + lua_error(luaContext->lua); + } + + return 0; +}