在单元测试中，我们需要为业务逻辑提供mock版本。当业务逻辑实现为C++的virtual function时，这很容易：只需写一个子类，重写（override）相应的virtual function即可。Google的gmock正是针对这种情况设计的。

可是，如果遗留代码中存在普通C函数、非virtual的类成员函数、模板函数或inline函数，又该如何为它们提供mock版本呢？

下面的代码用一个小技巧（trick），在运行时为上述函数实现了mock（inline函数除外，见下文的测试结果）。

原理是：在运行时修改目标函数入口处的机器码，将其改写为一段跳转（jmp）指令，使对目标函数的调用直接跳转到mock版本的函数中执行。由于代码所在的内存页通常没有写权限，改写之前需要先用mprotect为该页加上写权限，改完后再去掉。

实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <errno.h>
#include <stdint.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

#include <iostream>
#include <string>

#include "patch_elf.h"
using namespace std;


// Debug helper: print `addr`, then `leng` bytes starting at `addr` as
// space-separated hex values ("0x.."), to std::cout.
//
// Params:
//   addr - start of the bytes to dump (e.g. a function's entry point); must
//          point to at least `leng` readable bytes.
//   leng - number of bytes to dump; values <= 0 print no bytes.
// Returns: always 0.
//
// std::hex is a sticky stream flag, so the caller's std::cout formatting
// flags are saved first and restored at the end; otherwise every integer
// printed after this call would also come out in hex.
int print_op(void * addr,int leng){
    const unsigned char * op = static_cast<const unsigned char *>(addr);
    const std::ios_base::fmtflags saved_flags = std::cout.flags();

    std::cout << std::endl << "addr:" << addr << " code:" << std::endl;
    for (int i = 0; i < leng; ++i) {
        std::cout << "0x" << std::hex << static_cast<unsigned int>(op[i]) << " ";
    }
    std::cout << std::endl;

    std::cout.flags(saved_flags);   // undo the std::hex set above
    return 0;
}

// Copy `len` bytes of machine code to `dst`, temporarily making every page the
// write touches writable.  Returns 0 once the bytes are written, or -1 if the
// pages could not be made writable (in which case nothing is written).
static int patch_write_code(void * dst, const unsigned char * code, size_t len){
    const long page_size_raw = sysconf(_SC_PAGESIZE);
    if (page_size_raw <= 0) {
        std::cerr << "patch_func: sysconf(_SC_PAGESIZE) failed" << std::endl;
        return -1;
    }
    const uintptr_t page_size = static_cast<uintptr_t>(page_size_raw);
    const uintptr_t start = reinterpret_cast<uintptr_t>(dst);

    // Cover [start, start + len) completely: the patch may straddle a page
    // boundary, in which case two pages must be made writable, not one.
    const uintptr_t region_begin = start / page_size * page_size;
    const uintptr_t region_end = (start + len + page_size - 1) / page_size * page_size;
    void * region = reinterpret_cast<void *>(region_begin);
    const size_t region_len = static_cast<size_t>(region_end - region_begin);

    // PROT_EXEC is kept while writing because these pages may also contain
    // code that is running right now (possibly this very function).
    if (mprotect(region, region_len, PROT_READ | PROT_WRITE | PROT_EXEC) != 0) {
        const int err = errno;
        std::cerr << "mprotect failed! " << strerror(err) << std::endl;
        return -1;   // do not write into a page we could not unlock
    }

    memcpy(dst, code, len);

    // Drop write permission again.  The patch is already in place, so a
    // failure here is only reported.
    if (mprotect(region, region_len, PROT_READ | PROT_EXEC) != 0) {
        const int err = errno;
        std::cerr << "mprotect failed! " << strerror(err) << std::endl;
    }
    return 0;
}

/*
 * Redirect all future calls of `original` to `mock` by overwriting the first
 * bytes of `original`'s machine code with an absolute indirect jump:
 *
 *   x86-64 (12 bytes):  48 b8 <imm64>   movabs $mock,%rax
 *                       ff e0           jmpq   *%rax
 *   i386   (7 bytes):   b8 <imm32>      mov    $mock,%eax
 *                       ff e0           jmp    *%eax
 *
 * rax/eax holds the return value, so it is used as the scratch register.
 * NOTE(review): on the x86-64 SysV ABI %al carries the vector-register count
 * for variadic calls, so patching a variadic function this way is presumably
 * unsafe - confirm before doing so.
 *
 * Params:
 *   original - entry point of the function to patch.  Its code must be at
 *              least as long as the jump sequence above; otherwise whatever
 *              follows it in memory is overwritten too.
 *   mock     - entry point of the replacement, which should use a compatible
 *              signature and calling convention.
 * Returns:
 *   0  - the jump was written (if re-protecting the pages afterwards fails,
 *        that is only reported on stderr; the redirect is still active);
 *   -1 - nothing was written (null argument, unsupported CPU, or the code
 *        pages could not be made writable).
 *
 * The overwritten bytes are not saved, so a patch cannot be undone.  Not safe
 * while another thread may be executing `original`.
 */
int patch_func(void  * original,void * mock){
    if (original == NULL || mock == NULL) {
        std::cerr << "patch_func: null function address" << std::endl;
        return -1;
    }

    const uintptr_t target = reinterpret_cast<uintptr_t>(mock);
#if defined(__x86_64__)
    // Encode the full 64-bit address.  A 32-bit immediate only reaches the
    // mock when it lives below 4 GiB, which is not true for PIE executables.
    unsigned char inject_code[] = {0x48, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xe0};
    const uint64_t imm = static_cast<uint64_t>(target);
    memcpy(&inject_code[2], &imm, sizeof imm);   // x86 immediates are little-endian, like memory
    return patch_write_code(original, inject_code, sizeof inject_code);
#elif defined(__i386__)
    unsigned char inject_code[] = {0xb8, 0, 0, 0, 0, 0xff, 0xe0};
    const uint32_t imm = static_cast<uint32_t>(target);
    memcpy(&inject_code[1], &imm, sizeof imm);
    return patch_write_code(original, inject_code, sizeof inject_code);
#else
    // The injected bytes are x86 machine code; writing them anywhere else
    // would only corrupt the target.
    (void)target;
    std::cerr << "patch_func: only x86 and x86-64 are supported" << std::endl;
    return -1;
#endif
}

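测试了以下四种函数：

1. 一般函数；2. inline函数；3. 一般成员函数；4. 模板函数。

并分别在32位、64位平台上，以 -O2 和 -O0 编译参数进行了编译。

除了inline函数无法mock之外，其他几种情况都有效。这通常是因为inline函数的调用在编译期就已在调用处展开。运行时改写它的独立副本（即对其取地址得到的那份代码），并不会影响这些已经展开的调用。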

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#pragma once
#include <iostream>
#include <cstdio>
using namespace std;

// Plain data record used as the argument type throughout the mock demo.
// The constructor zeroes every scalar member and makes `c` an empty C string;
// only c[0] is written, and the other bytes of `c` stay uninitialized.
class ST1{
    public:
        uint32_t a;
        uint64_t b;
        char c[200];
        double d;
        ST1 * e;

        ST1()
            : a(0u),
              b(0u),
              d(0.0),
              e(NULL)
        {
            c[0] = '\0';
        }

        // Member functions: member_func is the patch target in the demo,
        // member_func_mock is the replacement it gets redirected to.
        int member_func();
        int member_func_mock();
};

// Free function that the demo patches ST1::member_func to jump to.
// NOTE(review): presumably relies on `this` being passed like a leading
// ST1* argument - confirm for the target ABI.
int member_func_extern(ST1 * st);


// Ordinary (non-member) functions.
// original_func is the function under test; mock_func has the same signature
// and replaces it in the demo.
int original_func(ST1 * para1,ST1 para2,void * para3);
int mock_func(ST1 * para1,ST1 para2,void * para3);
// ref_func (a forwarder to mock_func) is defined in the implementation file,
// but its declaration is commented out:
//int ref_func(ST1 * para1,ST1 para2,void * para3);

// Small class with a private uint32_t and an embedded ST1, plus a public
// uint32_t.  It is not referenced by the demo code shown here.
class Base{
    private:
        uint32_t b;
        ST1 st;   // initialized by ST1's default constructor
    public:
        uint32_t a;

        // Zero both integer fields; `st` is default-constructed.
        Base()
            : b(0u),
              a(0u)
        {
        }
};

// Inline functions.
// inline_func is the inline patch target of the demo: it prints its own name
// and the value a + b + 0x1111*a + b/0x1111, then returns that value.
inline int inline_func(int a,int b){
    const int scaled = 0x1111 * a;
    const int result = a + b + scaled + b / 0x1111;
    printf("%s %d\n", __func__, result);
    return result;
}

// Replacement for inline_func: prints its own name and a + b + 100, then
// returns that value.
inline int inline_func_mock(int a,int b){
    const int offset = 100;
    const int result = a + b + offset;
    printf("%s %d\n", __func__, result);
    return result;
}

// Template functions.
// get_member_a prints t.a (in t.a's own type) and returns it converted to
// uint32_t.  T needs a member `a` that can be streamed and converted to uint32_t.
template <typename T>
uint32_t get_member_a(T & t){
    std::ostream & sink = std::cout;
    sink << __func__ << " a=" << t.a << std::endl;
    return t.a;
}

// Prints t.b (in t.b's own type) and returns it converted to uint32_t.  The
// return type matches get_member_a, so a 64-bit `b` is truncated to its low
// 32 bits.  NOTE(review): presumably kept identical so the demo can patch
// get_member_a<ST1> to jump here - confirm the truncation is intended.
template <typename T>
uint32_t get_member_b(T & t){
    std::ostream & sink = std::cout;
    sink << __func__ << " b=" << t.b << std::endl;
    return t.b;
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <stdint.h>
#include <iostream>
#include <string>
#include <unistd.h>
#include "func.h"
using namespace std;

// Target of the "ordinary function" mock demo.  Prints one line summarizing
// all three arguments: the sums of the numeric fields, the two `c` strings
// back to back, and the three pointers (both `e` fields and para3) back to
// back.  Always returns 0.
int original_func(ST1 * para1,ST1 para2,void * para3){
    std::ostream & out = std::cout;
    out << __func__ << "\tcalled! ";
    out << " a+a " << para1->a + para2.a;
    out << " b+b " << para1->b + para2.b;
    out << " c+c " << para1->c << para2.c;
    out << " d+d " << para1->d + para2.d;
    out << " e+e " << para1->e << para2.e << para3;
    out << std::endl;
    return 0;
}


// Replacement for original_func in the mock demo.  It has the same signature,
// ignores every argument and only announces that it ran.  Always returns 0.
int mock_func(ST1 * para1,ST1 para2,void * para3){
    (void)para1;
    (void)para2;
    (void)para3;
    std::cout << __func__ << "\tcalled!" << std::endl;
    return 0;
}

// Thin forwarder: passes its arguments unchanged to mock_func and returns
// mock_func's result.
int ref_func(ST1 * para1,ST1 para2,void * para3){
    const int result = mock_func(para1, para2, para3);
    return result;
}

int ST1::member_func(){
    cout<<__func__<<" called! "
        <<" a="<< this->a
        <<" b="<< this->b
        <<" c="<< this->c
        <<" d="<< this->d
        <<endl;
    return 0;
}

// Replacement that ST1::member_func is redirected to in the demo; it only
// announces that it ran.  Always returns 0.
int ST1::member_func_mock(){
    static const char kMessage[] = " called! i do nothing.";
    std::cout << __func__ << kMessage << std::endl;
    return 0;
}

// Free-function replacement for ST1::member_func in the demo.  It ignores its
// argument and only announces that it ran.  Always returns 0.
int member_func_extern(ST1 * st){
    (void)st;
    std::cout << __func__ << " called! i am not member function." << std::endl;
    return 0;
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <stdint.h>
#include <iostream>
#include <string>
#include <unistd.h>
#include <string.h>
#include <sys/mman.h>

#include "func.h"
#include "patch_elf.h"

using namespace std;


int test(){
    ST1 s1,s2;
    char str[]="hello";
    s1.a=s1.b=s1.d=1;
    s1.e=NULL;

    s2.a=s2.b=s2.d=2;
    s1.e=NULL;

    cout<<"----------------------------------------------------------------------"<<endl;

    //mock original_func,替换成mock_func
    original_func(&s1,s2,&str[0]);

    patch_func((void*)&original_func, (void*)&mock_func);

    original_func(&s1,s2,&str[0]);

    cout<<"----------------------------------------------------------------------"<<endl;

    //mock inline 函数貌似不行
    int a=s1.a, b=s1.b;
    inline_func(a,b);
    patch_func( (void*) &inline_func, (void*) &inline_func_mock);
    inline_func(a,b);

    cout<<"----------------------------------------------------------------------"<<endl;

    s1.member_func();
    patch_func( (void*) &ST1::member_func, (void*) &ST1::member_func_mock);
    s1.member_func();
    patch_func( (void*) &ST1::member_func, (void*) &member_func_extern);
    s1.member_func();

    cout<<"----------------------------------------------------------------------"<<endl;

    get_member_a(s1);
    patch_func( (void*) & get_member_a<ST1> , (void*) & get_member_b<ST1>);
    get_member_a(s1);

    return 0;
}

// Entry point: run the mock demo once.  test()'s return value is not used;
// the process always exits with status 0.
int main(){
    (void) test();
    return 0;
}

可以参考