// matmult.cpp : Defines the entry point for the console application.
//
#include "stdafx.h"
#include <memory.h>
#include <iostream>
#include <windows.h>
#include <D3dx8math.h>
#include <stdio.h>
#include <assert.h>
using namespace std;
// Load an identity matrix
void IdentityMatrixAsm( float *pfM )
{
__asm
{
push ebx
mov eax, pfM
mov ebx, 00000000h // 0.0 in float
mov dword ptr[ eax + 4 ], ebx
mov dword ptr[ eax + 8 ], ebx
mov dword ptr[ eax + 12 ], ebx
mov dword ptr[ eax + 16 ], ebx
mov dword ptr[ eax + 24 ], ebx
mov dword ptr[ eax + 28 ], ebx
mov dword ptr[ eax + 32 ], ebx
mov dword ptr[ eax + 36 ], ebx
mov dword ptr[ eax + 44 ], ebx
mov dword ptr[ eax + 48 ], ebx
mov dword ptr[ eax + 52 ], ebx
mov dword ptr[ eax + 56 ], ebx
mov ebx, 3f800000h // 1.0 in float
mov dword ptr[ eax ], ebx
mov dword ptr[ eax + 20 ], ebx
mov dword ptr[ eax + 40 ], ebx
mov dword ptr[ eax + 60 ], ebx
pop ebx
}
}
// Loop-unrolled 4x4 matrix multiplication.
// Basically, I load a single row from Matrix 1, push it onto the floating point stack
// and then take the dot product with each column of Matrix 2 to get a row of results.
// The assembly is no longer "straightforward": I optimized by grouping reads so
// that their use is delayed as long as possible (prevents lost cycles).
// Since the initial row values are cached in the fp stack, it IS legal to do:
// MultMatrixAsm( M1, M1, M2 ); for matrix stack operations
// MultMatrixAsm( M2, M1, M2 ); is not possible though.
// Computes pfResults[ 4*i + j ] = sum over k of pfM1[ 4*i + k ] * pfM2[ 4*k + j ]
// for i, j in 0..3, i.e. the product of two 4x4 matrices stored as 16 floats each.
//   pfResults - receives the 16 result floats; must not be the same pointer as pfM2
//   pfM1      - left operand, read four floats (one row) at a time
//   pfM2      - right operand, re-read in full for every result row
// The body is the same 52-instruction sequence repeated four times, once per row.
// NOTE(review): the sequence pushes up to eight values on the x87 register stack,
// so it presumably relies on that stack being empty on entry - confirm for callers.
void MultMatrixAsm( float *pfResults, const float *pfM1, const float *pfM2 )
{
// assertions dont show up in release builds, so it incurs no performance loss.
// pfM2 is read again after each result row has been stored, so writing the
// results into pfM2 would corrupt later rows. pfResults == pfM1 is fine: the
// four floats of a pfM1 row are loaded before that row of results is stored.
assert( pfResults != pfM2 );
__asm
{
push ebx
push edx
// eax = pfM1, ebx = pfM2, edx = pfResults for the rest of the block.
mov eax, pfM1
mov ebx, pfM2
mov edx, pfResults
// ---- Result row 0 (pfM1 bytes 0..12, pfResults bytes 0..12) ----
// Cache the row a0..a3 of pfM1: st(0)=a3, st(1)=a2, st(2)=a1, st(3)=a0.
fld dword ptr[ eax ]
fld dword ptr[ eax + 4 ]
fld dword ptr[ eax + 8 ]
fld dword ptr[ eax + 12 ]
// Start the column 0 and column 1 accumulators with a0 * pfM2[0] and a0 * pfM2[1].
fld dword ptr[ ebx ]
fmul st(0), st(4)
fld dword ptr[ ebx + 4 ]
fmul st(0), st(5)
// Load pfM2[4] and pfM2[5] back to back (eight stack entries are live here),
// multiply each by a1 and fold it into its accumulator.
fld dword ptr[ ebx + 16 ]
fld dword ptr[ ebx + 20 ]
fmul st(0), st(6)
faddp st(2), st(0)
fmul st(0), st(5)
faddp st(2), st(0)
// Same for pfM2[8] and pfM2[9], scaled by a2.
fld dword ptr[ ebx + 32 ]
fld dword ptr[ ebx + 36 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmul st(0), st(4)
faddp st(2), st(0)
// Same for pfM2[12] and pfM2[13], scaled by a3.
fld dword ptr[ ebx + 48 ]
fld dword ptr[ ebx + 52 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmul st(0), st(3)
faddp st(2), st(0)
// st(0) holds the column 1 sum and st(1) the column 0 sum; store and pop both.
fstp dword ptr[ edx + 4 ]
fstp dword ptr[ edx ]
// Columns 2 and 3. The column 3 sum is accumulated on top of the stack, while
// each fmulp overwrites a cached row value with its column 2 product
// (a0 * pfM2[2], a1 * pfM2[6], a2 * pfM2[10], a3 * pfM2[14]) because that row
// value is not needed again.
fld dword ptr[ ebx + 12 ]
fmul st(0), st(4)
fld dword ptr[ ebx + 8 ]
fmulp st(5), st(0)
fld dword ptr[ ebx + 24 ]
fld dword ptr[ ebx + 28 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmulp st(4), st(0)
fld dword ptr[ ebx + 40 ]
fld dword ptr[ ebx + 44 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmulp st(3), st(0)
fld dword ptr[ ebx + 56 ]
fld dword ptr[ ebx + 60 ]
fmul st(0), st(3)
faddp st(2), st(0)
fmulp st(2), st(0)
// Store the column 3 sum, then add up the four column 2 products left on the
// stack and store that sum. This pops the last value pushed for this row.
fstp dword ptr[ edx + 12 ]
faddp st(1), st(0)
faddp st(1), st(0)
faddp st(1), st(0)
fstp dword ptr[ edx + 8 ]
// ---- Result row 1 (pfM1 bytes 16..28, pfResults bytes 16..28) ----
// Identical instruction sequence to row 0; only the eax/edx offsets differ.
fld dword ptr[ eax + 16 ]
fld dword ptr[ eax + 20 ]
fld dword ptr[ eax + 24 ]
fld dword ptr[ eax + 28 ]
fld dword ptr[ ebx ]
fmul st(0), st(4)
fld dword ptr[ ebx + 4 ]
fmul st(0), st(5)
fld dword ptr[ ebx + 16 ]
fld dword ptr[ ebx + 20 ]
fmul st(0), st(6)
faddp st(2), st(0)
fmul st(0), st(5)
faddp st(2), st(0)
fld dword ptr[ ebx + 32 ]
fld dword ptr[ ebx + 36 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmul st(0), st(4)
faddp st(2), st(0)
fld dword ptr[ ebx + 48 ]
fld dword ptr[ ebx + 52 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmul st(0), st(3)
faddp st(2), st(0)
// Columns 1 and 0 of row 1.
fstp dword ptr[ edx + 20 ]
fstp dword ptr[ edx + 16 ]
fld dword ptr[ ebx + 12 ]
fmul st(0), st(4)
fld dword ptr[ ebx + 8 ]
fmulp st(5), st(0)
fld dword ptr[ ebx + 24 ]
fld dword ptr[ ebx + 28 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmulp st(4), st(0)
fld dword ptr[ ebx + 40 ]
fld dword ptr[ ebx + 44 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmulp st(3), st(0)
fld dword ptr[ ebx + 56 ]
fld dword ptr[ ebx + 60 ]
fmul st(0), st(3)
faddp st(2), st(0)
fmulp st(2), st(0)
// Column 3 of row 1, then the summed column 2.
fstp dword ptr[ edx + 28 ]
faddp st(1), st(0)
faddp st(1), st(0)
faddp st(1), st(0)
fstp dword ptr[ edx + 24 ]
// ---- Result row 2 (pfM1 bytes 32..44, pfResults bytes 32..44) ----
fld dword ptr[ eax + 32 ]
fld dword ptr[ eax + 36 ]
fld dword ptr[ eax + 40 ]
fld dword ptr[ eax + 44 ]
fld dword ptr[ ebx ]
fmul st(0), st(4)
fld dword ptr[ ebx + 4 ]
fmul st(0), st(5)
fld dword ptr[ ebx + 16 ]
fld dword ptr[ ebx + 20 ]
fmul st(0), st(6)
faddp st(2), st(0)
fmul st(0), st(5)
faddp st(2), st(0)
fld dword ptr[ ebx + 32 ]
fld dword ptr[ ebx + 36 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmul st(0), st(4)
faddp st(2), st(0)
fld dword ptr[ ebx + 48 ]
fld dword ptr[ ebx + 52 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmul st(0), st(3)
faddp st(2), st(0)
// Columns 1 and 0 of row 2.
fstp dword ptr[ edx + 36 ]
fstp dword ptr[ edx + 32 ]
fld dword ptr[ ebx + 12 ]
fmul st(0), st(4)
fld dword ptr[ ebx + 8 ]
fmulp st(5), st(0)
fld dword ptr[ ebx + 24 ]
fld dword ptr[ ebx + 28 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmulp st(4), st(0)
fld dword ptr[ ebx + 40 ]
fld dword ptr[ ebx + 44 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmulp st(3), st(0)
fld dword ptr[ ebx + 56 ]
fld dword ptr[ ebx + 60 ]
fmul st(0), st(3)
faddp st(2), st(0)
fmulp st(2), st(0)
// Column 3 of row 2, then the summed column 2.
fstp dword ptr[ edx + 44 ]
faddp st(1), st(0)
faddp st(1), st(0)
faddp st(1), st(0)
fstp dword ptr[ edx + 40 ]
// ---- Result row 3 (pfM1 bytes 48..60, pfResults bytes 48..60) ----
fld dword ptr[ eax + 48 ]
fld dword ptr[ eax + 52 ]
fld dword ptr[ eax + 56 ]
fld dword ptr[ eax + 60 ]
fld dword ptr[ ebx ]
fmul st(0), st(4)
fld dword ptr[ ebx + 4 ]
fmul st(0), st(5)
fld dword ptr[ ebx + 16 ]
fld dword ptr[ ebx + 20 ]
fmul st(0), st(6)
faddp st(2), st(0)
fmul st(0), st(5)
faddp st(2), st(0)
fld dword ptr[ ebx + 32 ]
fld dword ptr[ ebx + 36 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmul st(0), st(4)
faddp st(2), st(0)
fld dword ptr[ ebx + 48 ]
fld dword ptr[ ebx + 52 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmul st(0), st(3)
faddp st(2), st(0)
// Columns 1 and 0 of row 3.
fstp dword ptr[ edx + 52 ]
fstp dword ptr[ edx + 48 ]
fld dword ptr[ ebx + 12 ]
fmul st(0), st(4)
fld dword ptr[ ebx + 8 ]
fmulp st(5), st(0)
fld dword ptr[ ebx + 24 ]
fld dword ptr[ ebx + 28 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmulp st(4), st(0)
fld dword ptr[ ebx + 40 ]
fld dword ptr[ ebx + 44 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmulp st(3), st(0)
fld dword ptr[ ebx + 56 ]
fld dword ptr[ ebx + 60 ]
fmul st(0), st(3)
faddp st(2), st(0)
fmulp st(2), st(0)
// Column 3 of row 3, then the summed column 2.
fstp dword ptr[ edx + 60 ]
faddp st(1), st(0)
faddp st(1), st(0)
faddp st(1), st(0)
fstp dword ptr[ edx + 56 ]
// Restore the registers saved at the top, in reverse order of the pushes.
pop edx
pop ebx
}
}
/*
// This is the loop version of the above...
__forceinline void MultMatrixAsm( float *pfResults, const float *pfM1, const float *pfM2 )
{
// assertions dont show up in release builds, so it incurs no performance loss.
assert( Results != M2 );
__asm
{
push ebx
push ecx
push edx
mov eax, pfM1
mov ebx, pfM2
mov edx, pfResults
mov ecx, 48
row:
fld dword ptr[ eax + ecx ]
fld dword ptr[ eax + ecx + 4 ]
fld dword ptr[ eax + ecx + 8 ]
fld dword ptr[ eax + ecx + 12 ]
fld dword ptr[ ebx ]
fmul st(0), st(4)
fld dword ptr[ ebx + 4 ]
fmul st(0), st(5)
fld dword ptr[ ebx + 16 ]
fld dword ptr[ ebx + 20 ]
fmul st(0), st(6)
faddp st(2), st(0)
fmul st(0), st(5)
faddp st(2), st(0)
fld dword ptr[ ebx + 32 ]
fld dword ptr[ ebx + 36 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmul st(0), st(4)
faddp st(2), st(0)
fld dword ptr[ ebx + 48 ]
fld dword ptr[ ebx + 52 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmul st(0), st(3)
faddp st(2), st(0)
fstp dword ptr[ edx + ecx + 4 ]
fstp dword ptr[ edx + ecx ]
fld dword ptr[ ebx + 12 ]
fmul st(0), st(4)
fld dword ptr[ ebx + 8 ]
fmulp st(5), st(0)
fld dword ptr[ ebx + 24 ]
fld dword ptr[ ebx + 28 ]
fmul st(0), st(5)
faddp st(2), st(0)
fmulp st(4), st(0)
fld dword ptr[ ebx + 40 ]
fld dword ptr[ ebx + 44 ]
fmul st(0), st(4)
faddp st(2), st(0)
fmulp st(3), st(0)
fld dword ptr[ ebx + 56 ]
fld dword ptr[ ebx + 60 ]
fmul st(0), st(3)
faddp st(2), st(0)
fmulp st(2), st(0)
fstp dword ptr[ edx + ecx + 12 ]
faddp st(1), st(0)
faddp st(1), st(0)
faddp st(1), st(0)
fstp dword ptr[ edx + ecx + 8 ]
sub ecx, 16
jge row
pop edx
pop ecx
pop ebx
}
}
*/
// C version. VC7 can loop-unroll this, but it is still not that much faster.
// Results must point to different storage than M1 and M2 (see the asserts);
// a version without that restriction could be written.
void CMultMatrix( float *pfResults, float *pfM1, float *pfM2 )
{
	// The output is written while both inputs are still being read, so it must
	// not share storage with either of them. (assert() vanishes in release
	// builds, so these checks cost nothing there.)
	assert( pfResults != pfM1 );
	assert( pfResults != pfM2 );

	// Each matrix is 16 floats: four rows of four consecutive values.
	for ( int row = 0; row < 4; ++row )
	{
		const float *lhsRow = pfM1 + row * 4;
		float *outRow = pfResults + row * 4;

		for ( int col = 0; col < 4; ++col )
		{
			// Walk down column `col` of pfM2 in steps of four floats.
			const float *rhsCol = pfM2 + col;
			outRow[ col ] = lhsRow[ 0 ] * rhsCol[ 0 ] +
			                lhsRow[ 1 ] * rhsCol[ 4 ] +
			                lhsRow[ 2 ] * rhsCol[ 8 ] +
			                lhsRow[ 3 ] * rhsCol[ 12 ];
		}
	}
}
// Writes the 16 floats at M to standard output as four lines of four values.
// A blank line is emitted first; every value is followed by a tab.
void PrintMatrix( float * M )
{
	std::cout << std::endl;

	const float *cell = M;
	for ( int row = 0; row < 4; ++row )
	{
		for ( int col = 0; col < 4; ++col, ++cell )
			std::cout << *cell << "\t";
		std::cout << std::endl;
	}
}
// Benchmark driver. Multiplies the same two 4x4 matrices a fixed number of
// times with CMultMatrix, D3DXMatrixMultiply and MultMatrixAsm, printing each
// implementation's final result matrix followed by the elapsed
// QueryPerformanceCounter ticks for its loop. The tick counts are raw counter
// units (not converted to seconds), so they are only meaningful relative to
// each other on the same machine. Waits for a key press before returning 0.
int __cdecl main(int argc, char* argv[])
{
	// The command line is not used.
	(void)argc;
	(void)argv;

	// Number of multiplications timed for each implementation.
	const int kIterations = 5000000;

	float M1[16];
	float M2[16];
	float Results[16];
	int i;
	// Test inputs: M1 counts down from 16, M2 counts up from 0.
	for( i = 0; i < 16; i++ )
	{
		M1[ i ] = 16.0f - i;
		M2[ i ] = (float)i;
	}
	LARGE_INTEGER CTestStart, CTestFinish;
	LARGE_INTEGER D3DTestStart, D3DTestFinish;
	LARGE_INTEGER AsmTestStart, AsmTestFinish;

	// Run CMultMatrix test
	memset( Results, 0, sizeof( float ) * 16 );
	QueryPerformanceCounter( &CTestStart );
	for ( i = 0; i < kIterations; i++ )
		CMultMatrix( Results, M1, M2 );
	QueryPerformanceCounter( &CTestFinish );
	PrintMatrix( Results );
	LARGE_INTEGER CTestTime;
	CTestTime.QuadPart = CTestFinish.QuadPart - CTestStart.QuadPart;
	// Print the full 64-bit tick count. Printing only LowPart behind an
	// assert( HighPart == 0 ) silently truncated the value in release builds,
	// where the assert is compiled out.
	cout << "C :\t" << CTestTime.QuadPart << endl;

	// Run D3DXMatrixMultiply test
	D3DXMATRIX D3DResults;
	D3DXMATRIX D3DM1(M1);
	D3DXMATRIX D3DM2(M2);
	memset( D3DResults, 0, sizeof( float ) * 16 );
	QueryPerformanceCounter( &D3DTestStart );
	for ( i = 0; i < kIterations; i++ )
		D3DXMatrixMultiply( &D3DResults, &D3DM1, &D3DM2 );
	QueryPerformanceCounter( &D3DTestFinish );
	PrintMatrix( D3DResults );
	LARGE_INTEGER D3DTestTime;
	D3DTestTime.QuadPart = D3DTestFinish.QuadPart - D3DTestStart.QuadPart;
	cout << "D3D:\t" << D3DTestTime.QuadPart << endl;

	// Run MultMatrixAsm test
	memset( Results, 0, sizeof( float ) * 16 );
	QueryPerformanceCounter( &AsmTestStart );
	for ( i = 0; i < kIterations; i++ )
		MultMatrixAsm( Results, M1, M2 );
	QueryPerformanceCounter( &AsmTestFinish );
	PrintMatrix( Results );
	LARGE_INTEGER AsmTestTime;
	AsmTestTime.QuadPart = AsmTestFinish.QuadPart - AsmTestStart.QuadPart;
	cout << "Asm:\t" << AsmTestTime.QuadPart << endl;

	// Keep the console window open until a key is pressed.
	getchar();
	return 0;
}