A keygen challenge (z3+brute)

In one of the previous CTFs I played, we were given a keygen-me binary; you can download it from here. We'll solve it in two ways: the easy way, which is brute force, and the longer way, which uses z3.

Let's start by understanding what the binary does, beginning with main. At the beginning of main, it initializes a bunch of function pointers.

Those functions are called in sequence to check the serial number. Each one takes a char * and returns either 1 or 0. If all of them pass, the serial number is valid; otherwise it is not.

Let's take a look at each function.

  • The first 3 functions are simple: all they do is check that the serial number is in the required format, which is 19 characters long, consists of digits, and has a "-" at positions 4, 9, and 14 (counting from 0).

  • verify_1

The decompilation of the function is:

This is a simple check that does the following:

There are 4 parts; each part is the sum of all the digits of the serial number up to the end of that group, divided by the part number. E.g.:

  • part1 is the sum of the first 4 digits divided by 1
  • part2 is the sum of the first 8 digits divided by 2
  • and so on.

The final check is that every part equals the total sum divided by 4 (all of these divisions are integer divisions).

  • verify_2

This function checks, position by position, that the digits at positions

  • 0,1,2,3 are not equal to those at 5,6,7,8
  • 5,6,7,8 are not equal to those at 10,11,12,13
  • 10,11,12,13 are not equal to those at 15,16,17,18

Using z3, we can generate all possible solutions by adding these constraints to a solver and letting it solve them.

#!/usr/bin/env python2

import z3

def keygen():
    """Print every serial accepted by verify_1 and verify_2, one per line.

    The serial's 19 characters are modelled as 32-bit z3 bit-vectors and
    constrained to mirror the binary's checks.  Each time the solver finds a
    model, that serial is printed and then excluded via
    remove_ser_from_model, so the loop ends once no unseen serial remains.
    """
    s = z3.Solver()

    # Create the serial, which has the format 0000-0000-0000-0000
    ser = [z3.BitVec("char_%d" % i, 32) for i in range(19)]

    # The serial has a dash (ASCII 45) after every 4 digits
    dash_positions = (4, 9, 14)
    for i in dash_positions:
        s.add(ser[i] == 45)

    # Every other character is an ASCII digit ('0' == 48 .. '9' == 57)
    for i in range(19):
        if i not in dash_positions:
            s.add(ser[i] >= 48, ser[i] <= 48 + 9)

    # Digit values (char - '0'), computed exactly as the binary does
    digit = [c - 48 for c in ser]

    # First check (verify_1): the running digit sum after group n (1-based),
    # divided by n, must equal the total digit sum divided by 4.
    # Note: `/` on z3 bit-vectors is signed division; every sum here is
    # non-negative, so it matches C's integer division.
    groups = [range(start, start + 4) for start in (0, 5, 10, 15)]
    group_sums = [sum([digit[i] for i in g]) for g in groups]
    total = sum(group_sums)
    running = 0
    for n, group_sum in enumerate(group_sums, 1):
        running = running + group_sum
        s.add(running / n == total / 4)

    # Second check (verify_2): within each of the 4 columns, neighbouring
    # groups must not repeat a digit.
    for i in range(4):
        a, b, c, d = [digit[i + off] for off in (0, 5, 10, 15)]
        s.add(a != b, b != c, c != d)

    # Generate all possible solutions until it's unsatisfiable
    while s.check() == z3.sat:
        model = s.model()
        # model_completion=True always yields a concrete value, whereas
        # model[x] returns None for a variable the model leaves unassigned.
        solution = ''.join(
            chr(model.eval(c, model_completion=True).as_long()) for c in ser)
        # print(...) with a single argument behaves the same on Python 2 and 3
        print(solution)
        remove_ser_from_model(ser, model, s)

# Removes a previous solution
def remove_ser_from_model(s, m, solver):
    """Exclude the serial found in model *m* from future solutions.

    Adds a clause to *solver* requiring at least one non-dash character of
    the serial *s* to differ from its value in *m*.  Positions 4, 9 and 14
    are skipped: in keygen they are pinned to '-' and can never differ.

    Args:
        s: the serial's per-character z3 variables.
        m: a model returned by ``solver.model()``.
        solver: the z3 solver that receives the blocking clause.
    """
    # m.eval(..., model_completion=True) always returns a concrete value;
    # m[x] would return None if the model left x unassigned.
    solver.add(z3.Or([ch != m.eval(ch, model_completion=True)
                      for i, ch in enumerate(s)
                      if i not in (4, 9, 14)]))

def main():
    """Command-line entry point: print every valid serial via keygen()."""
    keygen()


if __name__ == "__main__":
    main()

This will generate all possible solutions. It was running while I was writing this, and so far it has generated:

[~/serial]$ cat keys.txt|sort -u|wc -l                                             
41984  

That's 41K unique valid keys :p. The quick and easy way of solving this is to just decompile the check functions and run them on serials whose parts are generated randomly from a digit charset.

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <unistd.h>
#include <string.h>
#include <time.h>

char *randstring(int length);  
int verify_1(char *serial);  
int verify_2(char *serial);

/*
 * Return a newly malloc'd, NUL-terminated string of `length` random decimal
 * digits (a length below 1 is treated as 1). The caller must free() it.
 * Exits the process with status 1 if allocation fails.
 *
 * The PRNG is seeded only once, on the first call. Re-seeding on every call
 * (previously srand(time(NULL) * length + ++seed)) paid for a full
 * generator re-initialisation per string. Also, with the constant length of
 * 4 used by main, a run started D seconds later reused the seeds of an
 * earlier run shifted by 4*D calls, and so re-tested the same candidates.
 */
char *randstring(int length) {
    static const char charset[] = "0123456789";
    const size_t slen = sizeof(charset) - 1;   /* exclude the trailing NUL */
    static int seeded = 0;
    char *rstring;

    if (!seeded) {
        srand((unsigned) time(NULL));
        seeded = 1;
    }

    if (length < 1) length = 1;

    rstring = malloc((size_t) length + 1);
    if (rstring == NULL) {
        perror("malloc");
        exit(1);
    }

    for (int n = 0; n < length; n++)
        rstring[n] = charset[rand() % slen];
    rstring[length] = '\0';
    return rstring;
}

/*
 * Re-implementation of the binary's verify_1 check.
 *
 * Walks the first 19 characters of `serial` (it does not stop at a NUL, so
 * at least 19 characters must be readable). Digits are added to a running
 * sum of digit values (char - '0'). Each '-' closes a group j and records
 * parts[j] = running sum / (j + 1); the group after the last dash is closed
 * at the end. The serial passes when every part equals total / 4, where all
 * divisions are integer divisions.
 *
 * Returns 1 on success and 0 otherwise. Serials without exactly three '-'
 * in their first 19 characters are rejected. The previous version wrote past
 * parts[] on a fourth dash and read uninitialised parts with fewer than
 * three.
 */
int verify_1(char *serial) {
    int numbers = 0, j = 0, good = 0;
    int parts[4];

    for (int i = 0; i <= 18; ++i) {
        if (serial[i] == '-') {
            if (j >= 3)          /* a 4th dash would overflow parts[] */
                return 0;
            parts[j] = numbers / (j + 1);
            ++j;
        } else {
            numbers += serial[i] - '0';
        }
    }
    if (j != 3)                  /* parts[j+1..3] would be uninitialised */
        return 0;
    parts[j] = numbers / (j + 1);

    for (int part_n = 0; part_n <= 3; ++part_n) {
        if (parts[part_n] == numbers / 4) ++good;
    }
    return good == 4;
}

/*
 * Re-implementation of the binary's verify_2 check.
 *
 * Groups start at indices 0, 5, 10 and 15. For each column (offset 0..3
 * inside a group), the digit in group 1 must differ from group 2, group 2
 * from group 3, and group 3 from group 4. Returns 1 when all four columns
 * pass, 0 as soon as one fails.
 */
int verify_2(char *serial) {
    for (int col = 0; col < 4; col++) {
        int g1 = serial[col] - '0';
        int g2 = serial[col + 5] - '0';
        int g3 = serial[col + 10] - '0';
        int g4 = serial[col + 15] - '0';

        if (g1 == g2 || g2 == g3 || g3 == g4)
            return 0;
    }
    return 1;
}

/*
 * Brute-force driver: forever build random serials of the form
 * DDDD-DDDD-DDDD-DDDD and print the ones accepted by both verify_1 and
 * verify_2, one per line. Never returns; stop it with a signal.
 */
int main() {
    /* Line-buffer stdout so each key found is written out immediately. When
       the output is redirected to a file, stdout is fully buffered, and an
       interrupted run would otherwise lose the last unflushed block. */
    setvbuf(stdout, NULL, _IOLBF, BUFSIZ);

    while (1) {
        char serial[21] = {0};   /* 19 chars + NUL fit comfortably */
        char *r1 = randstring(4);
        char *r2 = randstring(4);
        char *r3 = randstring(4);
        char *r4 = randstring(4);
        sprintf(serial, "%s-%s-%s-%s", r1, r2, r3, r4);
        if (verify_1(serial) && verify_2(serial)) {
            printf("%s\n", serial);
        }
        free(r1); free(r2);
        free(r3); free(r4);
    }
}