Skip to content

Commit

Permalink
Fixed #12 and #13
Browse files Browse the repository at this point in the history
  • Loading branch information
TheWover committed Dec 1, 2020
1 parent bed08aa commit af9f869
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 14 deletions.
83 changes: 72 additions & 11 deletions DInvoke/DInvoke/DynamicInvoke/Generic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public static IntPtr LoadModuleFromDisk(string DLLPath)
/// <param name="FunctionName">Name of the exported procedure.</param>
/// <param name="CanLoadFromDisk">Optional, indicates if the function can try to load the DLL from disk if it is not found in the loaded module list.</param>
/// <returns>IntPtr for the desired function.</returns>
public static IntPtr GetLibraryAddress(string DLLName, string FunctionName, bool CanLoadFromDisk = false)
public static IntPtr GetLibraryAddress(string DLLName, string FunctionName, bool CanLoadFromDisk = false, bool ResolveForwards = false)
{
IntPtr hModule = GetLoadedModuleAddress(DLLName);
if (hModule == IntPtr.Zero && CanLoadFromDisk)
Expand All @@ -93,7 +93,7 @@ public static IntPtr GetLibraryAddress(string DLLName, string FunctionName, bool
throw new DllNotFoundException(DLLName + ", Dll was not found.");
}

return GetExportAddress(hModule, FunctionName);
return GetExportAddress(hModule, FunctionName, ResolveForwards);
}

/// <summary>
Expand All @@ -104,7 +104,7 @@ public static IntPtr GetLibraryAddress(string DLLName, string FunctionName, bool
/// <param name="Ordinal">Ordinal of the exported procedure.</param>
/// <param name="CanLoadFromDisk">Optional, indicates if the function can try to load the DLL from disk if it is not found in the loaded module list.</param>
/// <returns>IntPtr for the desired function.</returns>
public static IntPtr GetLibraryAddress(string DLLName, short Ordinal, bool CanLoadFromDisk = false)
public static IntPtr GetLibraryAddress(string DLLName, short Ordinal, bool CanLoadFromDisk = false, bool ResolveForwards = false)
{
IntPtr hModule = GetLoadedModuleAddress(DLLName);
if (hModule == IntPtr.Zero && CanLoadFromDisk)
Expand All @@ -120,7 +120,7 @@ public static IntPtr GetLibraryAddress(string DLLName, short Ordinal, bool CanLo
throw new DllNotFoundException(DLLName + ", Dll was not found.");
}

return GetExportAddress(hModule, Ordinal);
return GetExportAddress(hModule, Ordinal, ResolveForwards: ResolveForwards);
}

/// <summary>
Expand All @@ -132,7 +132,7 @@ public static IntPtr GetLibraryAddress(string DLLName, short Ordinal, bool CanLo
/// <param name="Key">64-bit integer to initialize the keyed hash object (e.g. 0xabc or 0x1122334455667788).</param>
/// <param name="CanLoadFromDisk">Optional, indicates if the function can try to load the DLL from disk if it is not found in the loaded module list.</param>
/// <returns>IntPtr for the desired function.</returns>
public static IntPtr GetLibraryAddress(string DLLName, string FunctionHash, long Key, bool CanLoadFromDisk = false)
public static IntPtr GetLibraryAddress(string DLLName, string FunctionHash, long Key, bool CanLoadFromDisk = false, bool ResolveForwards = false)
{
IntPtr hModule = GetLoadedModuleAddress(DLLName);
if (hModule == IntPtr.Zero && CanLoadFromDisk)
Expand All @@ -148,7 +148,7 @@ public static IntPtr GetLibraryAddress(string DLLName, string FunctionHash, long
throw new DllNotFoundException(DLLName + ", Dll was not found.");
}

return GetExportAddress(hModule, FunctionHash, Key);
return GetExportAddress(hModule, FunctionHash, Key, ResolveForwards: ResolveForwards);
}

/// <summary>
Expand Down Expand Up @@ -252,7 +252,7 @@ public static string GetAPIHash(string APIName, long Key)
/// <param name="ModuleBase">A pointer to the base address where the module is loaded in the current process.</param>
/// <param name="ExportName">The name of the export to search for (e.g. "NtAlertResumeThread").</param>
/// <returns>IntPtr for the desired function.</returns>
public static IntPtr GetExportAddress(IntPtr ModuleBase, string ExportName)
public static IntPtr GetExportAddress(IntPtr ModuleBase, string ExportName, bool ResolveForwards = false)
{
IntPtr FunctionPtr = IntPtr.Zero;
try
Expand Down Expand Up @@ -281,15 +281,26 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, string ExportName)
Int32 NamesRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + ExportRVA + 0x20));
Int32 OrdinalsRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + ExportRVA + 0x24));

// Get the VAs of the name table's beginning and end.
Int64 NamesBegin = ModuleBase.ToInt64() + Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + NamesRVA));
Int64 NamesFinal = NamesBegin + NumberOfNames * 4;

// Loop the array of export name RVA's
for (int i = 0; i < NumberOfNames; i++)
{
string FunctionName = Marshal.PtrToStringAnsi((IntPtr)(ModuleBase.ToInt64() + Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + NamesRVA + i * 4))));

if (FunctionName.Equals(ExportName, StringComparison.OrdinalIgnoreCase))
{

Int32 FunctionOrdinal = Marshal.ReadInt16((IntPtr)(ModuleBase.ToInt64() + OrdinalsRVA + i * 2)) + OrdinalBase;
Int32 FunctionRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + FunctionsRVA + (4 * (FunctionOrdinal - OrdinalBase))));
FunctionPtr = (IntPtr)((Int64)ModuleBase + FunctionRVA);

if (ResolveForwards == true)
// If the export address points to a forward, get the address
FunctionPtr = GetForwardAddress(FunctionPtr);

break;
}
}
Expand All @@ -315,7 +326,7 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, string ExportName)
/// <param name="ModuleBase">A pointer to the base address where the module is loaded in the current process.</param>
/// <param name="Ordinal">The ordinal number to search for (e.g. 0x136 -> ntdll!NtCreateThreadEx).</param>
/// <returns>IntPtr for the desired function.</returns>
public static IntPtr GetExportAddress(IntPtr ModuleBase, short Ordinal)
public static IntPtr GetExportAddress(IntPtr ModuleBase, short Ordinal, bool ResolveForwards = false)
{
IntPtr FunctionPtr = IntPtr.Zero;
try
Expand Down Expand Up @@ -352,6 +363,11 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, short Ordinal)
{
Int32 FunctionRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + FunctionsRVA + (4 * (FunctionOrdinal - OrdinalBase))));
FunctionPtr = (IntPtr)((Int64)ModuleBase + FunctionRVA);

if (ResolveForwards == true)
// If the export address points to a forward, get the address
FunctionPtr = GetForwardAddress(FunctionPtr);

break;
}
}
Expand All @@ -378,7 +394,7 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, short Ordinal)
/// <param name="FunctionHash">Hash of the exported procedure.</param>
/// <param name="Key">64-bit integer to initialize the keyed hash object (e.g. 0xabc or 0x1122334455667788).</param>
/// <returns>IntPtr for the desired function.</returns>
public static IntPtr GetExportAddress(IntPtr ModuleBase, string FunctionHash, long Key)
public static IntPtr GetExportAddress(IntPtr ModuleBase, string FunctionHash, long Key, bool ResolveForwards = false)
{
IntPtr FunctionPtr = IntPtr.Zero;
try
Expand Down Expand Up @@ -416,6 +432,11 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, string FunctionHash, lo
Int32 FunctionOrdinal = Marshal.ReadInt16((IntPtr)(ModuleBase.ToInt64() + OrdinalsRVA + i * 2)) + OrdinalBase;
Int32 FunctionRVA = Marshal.ReadInt32((IntPtr)(ModuleBase.ToInt64() + FunctionsRVA + (4 * (FunctionOrdinal - OrdinalBase))));
FunctionPtr = (IntPtr)((Int64)ModuleBase + FunctionRVA);

if (ResolveForwards == true)
// If the export address points to a forward, get the address
FunctionPtr = GetForwardAddress(FunctionPtr);

break;
}
}
Expand All @@ -434,6 +455,45 @@ public static IntPtr GetExportAddress(IntPtr ModuleBase, string FunctionHash, lo
return FunctionPtr;
}

/// <summary>
/// Check if an address to an exported function should be resolved to a forward. If so, return the address of the forward.
/// </summary>
/// <author>The Wover (@TheRealWover)</author>
/// <param name="ExportAddress">Function of an exported address, found by parsing a PE file's export table.</param>
/// <returns>IntPtr for the forward. If the function is not forwarded, return the original pointer.</returns>
public static IntPtr GetForwardAddress(IntPtr ExportAddress)
{
IntPtr FunctionPtr = ExportAddress;
try
{
// Assume it is a forward. If it is not, we will get an error
string ForwardNames = Marshal.PtrToStringAnsi(FunctionPtr);
string[] values = ForwardNames.Split('.');

string ForwardModuleName = values[0];
string ForwardExportName = values[1];

// Check if it is an API Set mapping
Dictionary<string, string> ApiSet = GetApiSetMapping();
string LookupKey = ForwardModuleName.Substring(0, ForwardModuleName.Length - 2) + ".dll";
if (ApiSet.ContainsKey(LookupKey))
ForwardModuleName = ApiSet[LookupKey];
else
ForwardModuleName = ForwardModuleName + ".dll";

IntPtr hModule = GetPebLdrModuleEntry(ForwardModuleName);
if (hModule != IntPtr.Zero)
{
FunctionPtr = GetExportAddress(hModule, ForwardExportName);
}
}
catch
{
// Do nothing, it was not a forward
}
return FunctionPtr;
}

/// <summary>
/// Given a module base address, resolve the address of a function by calling LdrGetProcedureAddress.
/// </summary>
Expand Down Expand Up @@ -548,7 +608,8 @@ public static Dictionary<string, string> GetApiSetMapping()
{
Data.PE.ApiSetNamespaceEntry SetEntry = new Data.PE.ApiSetNamespaceEntry();
SetEntry = (Data.PE.ApiSetNamespaceEntry)Marshal.PtrToStructure((IntPtr)((UInt64)pApiSetNamespace + (UInt64)Namespace.EntryOffset + (UInt64)(i * Marshal.SizeOf(SetEntry))), typeof(Data.PE.ApiSetNamespaceEntry));
string ApiSetEntryName = Marshal.PtrToStringUni((IntPtr)((UInt64)pApiSetNamespace + (UInt64)SetEntry.NameOffset), SetEntry.NameLength/2) + ".dll";
string ApiSetEntryName = Marshal.PtrToStringUni((IntPtr)((UInt64)pApiSetNamespace + (UInt64)SetEntry.NameOffset), SetEntry.NameLength/2);
string ApiSetEntryKey = ApiSetEntryName.Substring(0, ApiSetEntryName.Length - 2) + ".dll" ; // Remove the patch number and add .dll

Data.PE.ApiSetValueEntry SetValue = new Data.PE.ApiSetValueEntry();
SetValue = (Data.PE.ApiSetValueEntry)Marshal.PtrToStructure((IntPtr)((UInt64)pApiSetNamespace + (UInt64)SetEntry.ValueOffset), typeof(Data.PE.ApiSetValueEntry));
Expand All @@ -559,7 +620,7 @@ public static Dictionary<string, string> GetApiSetMapping()
}

// Add pair to dict
ApiSetDict.Add(ApiSetEntryName, ApiSetValue);
ApiSetDict.Add(ApiSetEntryKey, ApiSetValue);
}

// Return dict
Expand Down
9 changes: 6 additions & 3 deletions DInvoke/DInvoke/ManualMap/Map.cs
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,13 @@ public static void RewriteModuleIAT(Data.PE.PE_META_DATA PEINFO, IntPtr ModuleMe
}
else
{
// API Set DLL?
string LookupKey = DllName.Substring(0, DllName.Length - 6) + ".dll";
// API Set DLL? Ignore the patch number.
if (OSVersion.MajorVersion >= 10 && (DllName.StartsWith("api-") || DllName.StartsWith("ext-")) &&
ApiSetDict.ContainsKey(DllName) && ApiSetDict[DllName].Length > 0)
ApiSetDict.ContainsKey(LookupKey) && ApiSetDict[LookupKey].Length > 0)
{
// Not all API set DLL's have a registered host mapping
DllName = ApiSetDict[DllName];
DllName = ApiSetDict[LookupKey];
}

// Check and / or load DLL
Expand Down Expand Up @@ -319,6 +320,8 @@ public static void RewriteModuleIAT(Data.PE.PE_META_DATA PEINFO, IntPtr ModuleMe
}
}
}

// Go to the next IID
counter++;
iid = (Data.Win32.Kernel32.IMAGE_IMPORT_DESCRIPTOR)Marshal.PtrToStructure(
(IntPtr)((UInt64)pImportTable + (uint)(Marshal.SizeOf(iid) * counter)),
Expand Down

0 comments on commit af9f869

Please sign in to comment.