From af9f86984a2ce329cb44a97459592f0b191fe252 Mon Sep 17 00:00:00 2001 From: TheWover <17090738+TheWover@users.noreply.github.com> Date: Tue, 1 Dec 2020 14:53:16 -0600 Subject: [PATCH] Fixed #12 and #13 --- DInvoke/DInvoke/DynamicInvoke/Generic.cs | 83 ++++++++++++++++++++---- DInvoke/DInvoke/ManualMap/Map.cs | 9 ++- 2 files changed, 78 insertions(+), 14 deletions(-) diff --git a/DInvoke/DInvoke/DynamicInvoke/Generic.cs b/DInvoke/DInvoke/DynamicInvoke/Generic.cs index dbd8e80..5e8befc 100644 --- a/DInvoke/DInvoke/DynamicInvoke/Generic.cs +++ b/DInvoke/DInvoke/DynamicInvoke/Generic.cs @@ -77,7 +77,7 @@ public static IntPtr LoadModuleFromDisk(string DLLPath) /// Name of the exported procedure. /// Optional, indicates if the function can try to load the DLL from disk if it is not found in the loaded module list. /// IntPtr for the desired function. - public static IntPtr GetLibraryAddress(string DLLName, string FunctionName, bool CanLoadFromDisk = false) + public static IntPtr GetLibraryAddress(string DLLName, string FunctionName, bool CanLoadFromDisk = false, bool ResolveForwards = false) { IntPtr hModule = GetLoadedModuleAddress(DLLName); if (hModule == IntPtr.Zero && CanLoadFromDisk) @@ -93,7 +93,7 @@ public static IntPtr GetLibraryAddress(string DLLName, string FunctionName, bool throw new DllNotFoundException(DLLName + ", Dll was not found."); } - return GetExportAddress(hModule, FunctionName); + return GetExportAddress(hModule, FunctionName, ResolveForwards); } /// @@ -104,7 +104,7 @@ public static IntPtr GetLibraryAddress(string DLLName, string FunctionName, bool /// Ordinal of the exported procedure. /// Optional, indicates if the function can try to load the DLL from disk if it is not found in the loaded module list. /// IntPtr for the desired function. - public static IntPtr GetLibraryAddress(string DLLName, short Ordinal, bool CanLoadFromDisk = false) + public static IntPtr GetLibraryAddress(string DLLName, short Ordinal, bool CanLoadFromDisk = false, bool ResolveForwards = false) { IntPtr hModule = GetLoadedModuleAddress(DLLName); if (hModule == IntPtr.Zero && CanLoadFromDisk) @@ -120,7 +120,7 @@ public static IntPtr GetLibraryAddress(string DLLName, short Ordinal, bool CanLo throw new DllNotFoundException(DLLName + ", Dll was not found."); } - return GetExportAddress(hModule, Ordinal); + return GetExportAddress(hModule, Ordinal, ResolveForwards: ResolveForwards); } /// @@ -132,7 +132,7 @@ public static IntPtr GetLibraryAddress(string DLLName, short Ordinal, bool CanLo /// 64-bit integer to initialize the keyed hash object (e.g. 0xabc or 0x1122334455667788). /// Optional, indicates if the function can try to load the DLL from disk if it is not found in the loaded module list. /// IntPtr for the desired function. - public static IntPtr GetLibraryAddress(string DLLName, string FunctionHash, long Key, bool CanLoadFromDisk = false) + public static IntPtr GetLibraryAddress(string DLLName, string FunctionHash, long Key, bool CanLoadFromDisk = false, bool ResolveForwards = false) { IntPtr hModule = GetLoadedModuleAddress(DLLName); if (hModule == IntPtr.Zero && CanLoadFromDisk) @@ -148,7 +148,7 @@ public static IntPtr GetLibraryAddress(string DLLName, string FunctionHash, long throw new DllNotFoundException(DLLName + ", Dll was not found."); } - return GetExportAddress(hModule, FunctionHash, Key); + return GetExportAddress(hModule, FunctionHash, Key, ResolveForwards: ResolveForwards); } /// @@ -252,7 +252,7 @@ public static string GetAPIHash(string APIName, long Key) /// A pointer to the base address where the module is loaded in the current process. /// The name of the export to search for (e.g. "NtAlertResumeThread"). /// IntPtr for the desired function. - public static IntPtr GetExportAddress(IntPtr ModuleBase, string ExportName) + public static IntPtr GetExportAddress(IntPtr ModuleBase, string ExportName, bool ResolveForwards = false) { IntPtr FunctionPtr = IntPtr.Zero; try @@ -281,15 +281,26 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, string ExportName) Int32 NamesRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + ExportRVA + 0x20)); Int32 OrdinalsRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + ExportRVA + 0x24)); + // Get the VAs of the name table's beginning and end. + Int64 NamesBegin = ModuleBase.ToInt64() + Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + NamesRVA)); + Int64 NamesFinal = NamesBegin + NumberOfNames * 4; + // Loop the array of export name RVA's for (int i = 0; i < NumberOfNames; i++) { string FunctionName = Marshal.PtrToStringAnsi((IntPtr)(ModuleBase.ToInt64() + Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + NamesRVA + i * 4)))); + if (FunctionName.Equals(ExportName, StringComparison.OrdinalIgnoreCase)) { + Int32 FunctionOrdinal = Marshal.ReadInt16((IntPtr)(ModuleBase.ToInt64() + OrdinalsRVA + i * 2)) + OrdinalBase; Int32 FunctionRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + FunctionsRVA + (4 * (FunctionOrdinal - OrdinalBase)))); FunctionPtr = (IntPtr)((Int64)ModuleBase + FunctionRVA); + + if (ResolveForwards == true) + // If the export address points to a forward, get the address + FunctionPtr = GetForwardAddress(FunctionPtr); + break; } } @@ -315,7 +326,7 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, string ExportName) /// A pointer to the base address where the module is loaded in the current process. /// The ordinal number to search for (e.g. 0x136 -> ntdll!NtCreateThreadEx). /// IntPtr for the desired function. - public static IntPtr GetExportAddress(IntPtr ModuleBase, short Ordinal) + public static IntPtr GetExportAddress(IntPtr ModuleBase, short Ordinal, bool ResolveForwards = false) { IntPtr FunctionPtr = IntPtr.Zero; try @@ -352,6 +363,11 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, short Ordinal) { Int32 FunctionRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + FunctionsRVA + (4 * (FunctionOrdinal - OrdinalBase)))); FunctionPtr = (IntPtr)((Int64)ModuleBase + FunctionRVA); + + if (ResolveForwards == true) + // If the export address points to a forward, get the address + FunctionPtr = GetForwardAddress(FunctionPtr); + break; } } @@ -378,7 +394,7 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, short Ordinal) /// Hash of the exported procedure. /// 64-bit integer to initialize the keyed hash object (e.g. 0xabc or 0x1122334455667788). /// IntPtr for the desired function. - public static IntPtr GetExportAddress(IntPtr ModuleBase, string FunctionHash, long Key) + public static IntPtr GetExportAddress(IntPtr ModuleBase, string FunctionHash, long Key, bool ResolveForwards = false) { IntPtr FunctionPtr = IntPtr.Zero; try @@ -416,6 +432,11 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, string FunctionHash, lo Int32 FunctionOrdinal = Marshal.ReadInt16((IntPtr)(ModuleBase.ToInt64() + OrdinalsRVA + i * 2)) + OrdinalBase; Int32 FunctionRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + FunctionsRVA + (4 * (FunctionOrdinal - OrdinalBase)))); FunctionPtr = (IntPtr)((Int64)ModuleBase + FunctionRVA); + + if (ResolveForwards == true) + // If the export address points to a forward, get the address + FunctionPtr = GetForwardAddress(FunctionPtr); + break; } } @@ -434,6 +455,45 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, string FunctionHash, lo return FunctionPtr; } + /// + /// Check if an address to an exported function should be resolved to a forward. If so, return the address of the forward. + /// + /// The Wover (@TheRealWover) + /// Function of an exported address, found by parsing a PE file's export table. + /// IntPtr for the forward. If the function is not forwarded, return the original pointer. + public static IntPtr GetForwardAddress(IntPtr ExportAddress) + { + IntPtr FunctionPtr = ExportAddress; + try + { + // Assume it is a forward. If it is not, we will get an error + string ForwardNames = Marshal.PtrToStringAnsi(FunctionPtr); + string[] values = ForwardNames.Split('.'); + + string ForwardModuleName = values[0]; + string ForwardExportName = values[1]; + + // Check if it is an API Set mapping + Dictionary ApiSet = GetApiSetMapping(); + string LookupKey = ForwardModuleName.Substring(0, ForwardModuleName.Length - 2) + ".dll"; + if (ApiSet.ContainsKey(LookupKey)) + ForwardModuleName = ApiSet[LookupKey]; + else + ForwardModuleName = ForwardModuleName + ".dll"; + + IntPtr hModule = GetPebLdrModuleEntry(ForwardModuleName); + if (hModule != IntPtr.Zero) + { + FunctionPtr = GetExportAddress(hModule, ForwardExportName); + } + } + catch + { + // Do nothing, it was not a forward + } + return FunctionPtr; + } + /// /// Given a module base address, resolve the address of a function by calling LdrGetProcedureAddress. /// @@ -548,7 +608,8 @@ public static Dictionary GetApiSetMapping() { Data.PE.ApiSetNamespaceEntry SetEntry = new Data.PE.ApiSetNamespaceEntry(); SetEntry = (Data.PE.ApiSetNamespaceEntry)Marshal.PtrToStructure((IntPtr)((UInt64)pApiSetNamespace + (UInt64)Namespace.EntryOffset + (UInt64)(i * Marshal.SizeOf(SetEntry))), typeof(Data.PE.ApiSetNamespaceEntry)); - string ApiSetEntryName = Marshal.PtrToStringUni((IntPtr)((UInt64)pApiSetNamespace + (UInt64)SetEntry.NameOffset), SetEntry.NameLength/2) + ".dll"; + string ApiSetEntryName = Marshal.PtrToStringUni((IntPtr)((UInt64)pApiSetNamespace + (UInt64)SetEntry.NameOffset), SetEntry.NameLength/2); + string ApiSetEntryKey = ApiSetEntryName.Substring(0, ApiSetEntryName.Length - 2) + ".dll" ; // Remove the patch number and add .dll Data.PE.ApiSetValueEntry SetValue = new Data.PE.ApiSetValueEntry(); SetValue = (Data.PE.ApiSetValueEntry)Marshal.PtrToStructure((IntPtr)((UInt64)pApiSetNamespace + (UInt64)SetEntry.ValueOffset), typeof(Data.PE.ApiSetValueEntry)); @@ -559,7 +620,7 @@ public static Dictionary GetApiSetMapping() } // Add pair to dict - ApiSetDict.Add(ApiSetEntryName, ApiSetValue); + ApiSetDict.Add(ApiSetEntryKey, ApiSetValue); } // Return dict diff --git a/DInvoke/DInvoke/ManualMap/Map.cs b/DInvoke/DInvoke/ManualMap/Map.cs index 407bad6..49eff7d 100644 --- a/DInvoke/DInvoke/ManualMap/Map.cs +++ b/DInvoke/DInvoke/ManualMap/Map.cs @@ -235,12 +235,13 @@ public static void RewriteModuleIAT(Data.PE.PE_META_DATA PEINFO, IntPtr ModuleMe } else { - // API Set DLL? + string LookupKey = DllName.Substring(0, DllName.Length - 6) + ".dll"; + // API Set DLL? Ignore the patch number. if (OSVersion.MajorVersion >= 10 && (DllName.StartsWith("api-") || DllName.StartsWith("ext-")) && - ApiSetDict.ContainsKey(DllName) && ApiSetDict[DllName].Length > 0) + ApiSetDict.ContainsKey(LookupKey) && ApiSetDict[LookupKey].Length > 0) { // Not all API set DLL's have a registered host mapping - DllName = ApiSetDict[DllName]; + DllName = ApiSetDict[LookupKey]; } // Check and / or load DLL @@ -319,6 +320,8 @@ public static void RewriteModuleIAT(Data.PE.PE_META_DATA PEINFO, IntPtr ModuleMe } } } + + // Go to the next IID counter++; iid = (Data.Win32.Kernel32.IMAGE_IMPORT_DESCRIPTOR)Marshal.PtrToStructure( (IntPtr)((UInt64)pImportTable + (uint)(Marshal.SizeOf(iid) * counter)),