Wednesday, August 23, 2006

Where did the function call originate from?

Just a minor update, and once again Microsoft-specific. This time, I'd like to determine where a call comes from. The technique is particularly handy if you are developing a DLL that is loaded by another module, and you've hooked some other APIs in the process. At that point, you may want to know who is calling a given function: the module itself, or another of its loaded DLLs.

Up until recently, I used inline assembly to fetch the return address off the stack. That's a sub-optimal solution, as you need to know, or figure out, how much has been pushed onto the stack since the function was called. If the code is compiled with Microsoft's C++ compiler, you can use the _ReturnAddress intrinsic instead. For caller info, the only other thing needed is a loop over the process's currently loaded modules, using the ToolHelp API, which returns the module whose address range contains the caller's address.
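Before the full example, here is a minimal sketch of just the return-address part. The names whoCalledMe, CALLER_ADDRESS and NO_INLINE are made up for illustration, and the non-Microsoft branch is an assumption that the compiler is GCC or Clang, which offer a comparable builtin; this post otherwise only concerns the Microsoft compiler.

```cpp
#include <cstdint>

#if defined(_MSC_VER)
#include <intrin.h>
#pragma intrinsic(_ReturnAddress)
#define CALLER_ADDRESS() _ReturnAddress()
#define NO_INLINE __declspec(noinline)
#else
// Assumption: GCC or Clang, where __builtin_return_address(0) plays the same role.
#define CALLER_ADDRESS() __builtin_return_address(0)
#define NO_INLINE __attribute__((noinline))
#endif

// Hypothetical helper: returns the address at which execution resumes
// in whoever called this function.
NO_INLINE std::uintptr_t whoCalledMe()
{
    return reinterpret_cast<std::uintptr_t>(CALLER_ADDRESS());
}
```

The function is marked as non-inlinable on purpose: once a function is inlined into its caller, it no longer has a return address of its own, and the intrinsic would report the caller's caller instead.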

Example code, with a function/helper pair which returns information about the calling module:

#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <TlHelp32.h>
#include <intrin.h>
#include <iostream>
#include <string>
#include <exception>
#include <crtdbg.h> // _ASSERT

void someMethod(void*);
#define getCallingModuleInfo(pModuleInfo) \
    _getCallingModuleInfo( \
        reinterpret_cast<DWORD_PTR>(_ReturnAddress()), \
        pModuleInfo)
BOOL _getCallingModuleInfo(DWORD_PTR dwCallerOffset, MODULEENTRY32* pModuleInfo);

int main(int argc, char* argv[])
{
    // Call the method from the test app itself
    someMethod(0);

    // Have kernel32 call the method.
    // Note: the cast hides a signature mismatch; a proper thread routine
    // is declared as DWORD WINAPI fn(LPVOID).
    HANDLE hThread = CreateThread(0,
                                  0,
                                  reinterpret_cast<LPTHREAD_START_ROUTINE>(someMethod),
                                  0,
                                  0,
                                  NULL);
    if(hThread != NULL)
    {
        WaitForSingleObject(hThread, INFINITE);
        CloseHandle(hThread);
    }

    std::cout << "Press enter to exit." << std::endl;
    std::cin.get();
    return 0;
}

/*******************************************
Expected output:

Call originates from: WhosCalling.exe
Call originates from: kernel32.dll
Press enter to exit.
*******************************************/

void someMethod(void*)
{
    MODULEENTRY32 moduleInfo;
    if(getCallingModuleInfo(&moduleInfo))
    {
        std::cout << "Call originates from: " << moduleInfo.szModule << std::endl;
    }
    else
    {
        std::cout << "Failed to locate the origin of this call" << std::endl;
    }
}

BOOL _getCallingModuleInfo(DWORD_PTR dwCallerOffset, MODULEENTRY32* pModuleInfo)
{
    _ASSERT(dwCallerOffset != 0);
    _ASSERT(pModuleInfo != NULL);

    // Create a snapshot of the loaded modules for the current process
    HANDLE hModuleSnap =
        CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
    if(hModuleSnap == INVALID_HANDLE_VALUE)
    {
        throw std::exception("Error: Could not create a process module snapshot.");
    }

    MODULEENTRY32 me32;
    me32.dwSize = sizeof(MODULEENTRY32);

    // Prepare to iterate the modules
    if(!Module32First(hModuleSnap, &me32))
    {
        // No modules found, cleanup and exit
        CloseHandle(hModuleSnap);
        return FALSE;
    }

    do
    {
        // Match the address of the calling method against
        // the memory range of the currently iterated module
        if(dwCallerOffset >= (DWORD_PTR)me32.modBaseAddr &&
           dwCallerOffset < ((DWORD_PTR)me32.modBaseAddr +
                             (DWORD_PTR)me32.modBaseSize))
        {
            // Copy the information of the matching module to the user supplied buffer
            memcpy_s(pModuleInfo, sizeof(MODULEENTRY32), &me32, sizeof(MODULEENTRY32));

            // Cleanup and exit
            CloseHandle(hModuleSnap);
            return TRUE;
        }
    } while(Module32Next(hModuleSnap, &me32)); // Continue iteration

    // Cleanup and exit
    CloseHandle(hModuleSnap);
    return FALSE;
}
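
The range test inside the loop can also be tried out separately from the ToolHelp API. The following is only a sketch of that idea: ModuleRange and findOwningModule are hypothetical names, with ModuleRange standing in for the modBaseAddr, modBaseSize and szModule fields of MODULEENTRY32, and the addresses used to exercise it would be made up rather than taken from a real process.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the MODULEENTRY32 fields that matter here.
struct ModuleRange
{
    std::string name;
    std::uintptr_t base; // first address of the module image
    std::size_t size;    // size of the image in bytes
};

// Returns the module whose half-open range [base, base + size)
// contains the address, or nullptr if no module matches.
const ModuleRange* findOwningModule(const std::vector<ModuleRange>& modules,
                                    std::uintptr_t address)
{
    for (const ModuleRange& m : modules)
    {
        // Written as a subtraction so that base + size cannot overflow.
        if (address >= m.base && address - m.base < m.size)
        {
            return &m;
        }
    }
    return nullptr;
}
```

The upper bound is exclusive, just like in the loop above: an address equal to base + size belongs to whatever is mapped after the module, not to the module itself.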