/*****************************************************************************
* File:   doord.c
* Date:   2006-Apr-29 3:14:27 PM EST
* Original Authors: Pat Wilbur and Jason Galens
* Last Modified:2006-Apr-30 11:23:05 PM EST
*
* Description:
*
* Port knocking daemon that authenticates clients and adds rules to iptables.
*
*
* Notice:
*
* To be used with knock.c.  Copying and distributing this file is permitted as
* long as it is packaged with knock.c, and all this information here at the top
* of the file, including authors' names, authors' e-mail addresses, the 
* description, the notice, and the thank you messages are all intact.  In other
* words, do not touch the area between the stars.  If you wish to contribute,
* please e-mail us!
*
* libpcap implementation based on the BEST EVER libpcap packet capture tutorial
* by Martin Casado.  That tutorial may be found at:
*   http://www.cet.nau.edu/~mc8/Socket/Tutorials/section1.html
*
* Thank you, Mr. Casado.
*
******************************************************************************/


#include <time.h>
#include <pcap.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>  /* strcat, memcmp, memcpy */
#include <errno.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netinet/if_ether.h> 
#include <fcntl.h>


#define FILTER "udp"   /* libpcap filter expression: only UDP packets can be knocks */

/* Network device to sniff and to name in the iptables rules.  Override at
   build time (e.g. -DDEVICE='"eth0"') instead of editing this file. */
#ifndef DEVICE
#define DEVICE "lo"
#endif


#define KEY_LEN 16     /* bytes read from server.key and used as the XOR key */
#define BUF_LEN 256    /* generic scratch-buffer size */


/* Passed to pcap_open_live().  Leave at 0: promiscuous capture is untested,
   unused by the knock logic, and generally a very bad idea. */
#define PROMISCUOUS 0

/* Seconds an unauthenticated, incomplete knock may linger before it is
   discarded.  CHANGING IS UNTESTED!! */
#define TIMEOUTSECS 40

/*****************************************************************************/
/*****************************************************************************/

/* One port a client has opened through iptables; chained per client. */
struct PortList_t {
	unsigned short port;          /* port opened (TCP and UDP rules) */
	struct PortList_t * next;     /* next opened port, or NULL */
};
typedef struct PortList_t PortList;

/* Knock-tracking state for one client address; entries form a singly
   linked list hanging off the sentinel IPs below. */
struct IPList_t {
	int authenticated;            /* becomes 1 once a knock opened a port */
	int byteArrived[16];          /* 1 when the matching message[] byte is in */
	unsigned char message[16];    /* encrypted knock message, one byte per knock */
	time_t firstPacketArrived;    /* creation time; used for the knock timeout */
	PortList * plist;             /* ports opened for this client */
	unsigned char ip[4];          /* client IPv4 address */
	struct IPList_t * next;       /* next client entry, or NULL */
};
typedef struct IPList_t IPList;



/* Sentinel list head: real client entries start at IPs.next. */
IPList IPs;
/* XOR key loaded from server.key at startup. */
unsigned char * key;
/* Scratch buffers for building iptables command lines. */
char command[100];
char tempport[10];


/* Makeshift cipher -- PLEASE replace before using in production.
 *
 * Writes buf_size bytes to ret, each being buffer[n] XOR the key byte at
 * position n modulo key_len (the key repeats).  Because XOR is its own
 * inverse, the same call both encrypts and decrypts.  ret must hold at least
 * buf_size bytes.
 *
 * NOTE(review): this name collides with POSIX crypt(3); avoid pulling that
 * declaration into this file (e.g. <unistd.h> with XSI features enabled).
 */
void crypt(const unsigned char* buffer, int buf_size, const unsigned char* key, int key_len, unsigned char* ret)
{
	int done = 0;
	int keyPos = 0;

	while (done < buf_size)
	{
		ret[done] = buffer[done] ^ key[keyPos];
		++done;
		/* wrap to the start of the key once it has been used up */
		if (++keyPos >= key_len)
			keyPos = 0;
	}
}


/* flushes list of IPs that have timed out, returns ptr to where  */
IPList * flushAndFindLoc(unsigned char * ip)
{
	IPList * current;
	IPList * temp;
	IPList * foundLoc;
	foundLoc = NULL;
	current = &IPs;
	while(current->next != NULL)
	{
		if( ((current->next)->authenticated == 0)
		  && (time(NULL) - (current->next)->firstPacketArrived > TIMEOUTSECS) )
		{
			temp = (current->next)->next;
			free(current->next);
			current->next = temp;
		}
		else
		{
			if (((current->next)->ip[0] == ip[0])
				&& ((current->next)->ip[1] == ip[1]) 
				&& ((current->next)->ip[2] == ip[2]) 
			        && ((current->next)->ip[3] == ip[3]) )
			{
				foundLoc = current;
			}
		}
		current = current->next;
	}
	if (foundLoc == NULL) return current;
	return foundLoc;
}


/* Returns 1 if port is already recorded in entry's list of opened ports,
   0 otherwise. */
int isPortUsed(IPList * entry, unsigned short port)
{
	const PortList * node;

	for (node = entry->plist; node != NULL; node = node->next)
	{
		if (node->port == port)
		{
			return 1;
		}
	}
	return 0;
}


/* validates messages, makes appropriate iptables decisions, adjusts list */
void validate(unsigned char * ip, unsigned char * message)
{
/*	fprintf(stdout, "validating...\n");
	fflush(stdout); */
	
	int i, x;
	IPList * curr;
	IPList * temp;
	unsigned char plaintext[BUF_LEN];
	
	curr = flushAndFindLoc(ip);  /* will never be NULL */
	
	/* first check to see if all digits are in. */	
	x = 0;
	for (i = 0; i < 16; ++i)  x += (curr->next)->byteArrived[i];
		
	if ( x < 16 )
	{
		return;
	}

	crypt(message, BUF_LEN, key, KEY_LEN, plaintext);

	unsigned short port;
	((unsigned char *)(&port))[0] = plaintext[4];
	((unsigned char *)(&port))[1] = plaintext[5];

	fprintf(stdout, "PORT: %d\n",port);
	fflush(stdout);

	if ( (x == 16) && ((curr->next)->authenticated == 0) 
	     && isPortUsed((curr->next), port ) == 0 )
	{
		for (i = 0; i < 16; ++i) (curr->next)->byteArrived[i] = 0;		
		if (plaintext[0] == ip[0]
			&& plaintext[1] == ip[1]
			&& plaintext[2] == ip[2]
			&& plaintext[3] == ip[3] )
		{
			fprintf(stdout, "Added.\n");
			fflush(stdout);

			tempport[2] = '\0';

			sprintf(tempport,"%d", port);

			command[0] = '\0';			
			strcat(command, "iptables -A INPUT -p TCP -i ");
			strcat(command, DEVICE);
			strcat(command, " --dport ");
			strcat(command, tempport);
			strcat(command, " -j ACCEPT");
			fflush(stdout);
			system(command); 

			command[0] = '\0';			
			strcat(command, "iptables -A INPUT -p UDP -i ");
			strcat(command, DEVICE);
			strcat(command, " --dport ");
			strcat(command, tempport);
			strcat(command, " -j ACCEPT");
			fflush(stdout);
			system(command); 
			
			(curr->next)->authenticated = 1;
			PortList * temp;
			temp = (curr->next)->plist;
			if (temp != NULL)
			{
			  while (temp->next != NULL)
			  {
			    temp = temp->next;		     
			  }
			  temp->next = (PortList *) malloc(sizeof(PortList));
			  if (temp->next == NULL)
			  {
			    fprintf(stderr, "ERROR!  malloc failed.\n");
			    fflush(stderr);
			    exit(1);
			  }
			  temp->next->port = port;
			  temp->next->next = NULL;
			}
			else
			{
			  curr->next->plist = (PortList *) malloc(sizeof(PortList));
			  if (curr->next->plist == NULL)
			  {
			    fprintf(stderr, "ERROR!  malloc failed.\n");
			    fflush(stderr);
			    exit(1);
			  }
			  (curr->next)->plist->port = port;
			  (curr->next)->plist->next = NULL;
			}


			return;
		}
		else
		{
			/* the IP and claimed IP did not match! */
			/* leave handling for the reaper, below */		
		}
	}
	else if ( (x == 16) && ((curr->next)->authenticated == 1) 
	     && isPortUsed((curr->next), port) == 1 )
	{
		if (plaintext[0] == ip[0]
			&& plaintext[1] == ip[1]
			&& plaintext[2] == ip[2]
			&& plaintext[3] == ip[3] )
		{
			fprintf(stdout, "Removed.\n");
			fflush(stdout);

			tempport[2] = '\0';

			sprintf(tempport,"%d", port);

			command[0] = '\0';			
			strcat(command, "iptables -D INPUT -p TCP -i ");
			strcat(command, DEVICE);
			strcat(command, " --dport ");
			strcat(command, tempport);
			strcat(command, " -j ACCEPT");
			fflush(stdout);
			system(command); 

			command[0] = '\0';			
			strcat(command, "iptables -D INPUT -p UDP -i ");
			strcat(command, DEVICE);
			strcat(command, " --dport ");
			strcat(command, tempport);
			strcat(command, " -j ACCEPT");
			fflush(stdout);
			system(command); 
			
			PortList * temp;
			temp = curr->next->plist;
			while(temp->port != port)
			{
			  temp = temp->next;
			}
			free(temp);
			
			if (curr->next->plist == NULL)
			{
			  /* say hello to the reaper */
			}
			else
			  return;
		}
		else
		{
			/* the IP and claimed IP did not match! */
			return;
		}		
	}
	
	
	/*  * THE REAPER *  */
	/* remove ip from this list */
	temp = (curr->next)->next;
	free(curr->next);
	curr->next = temp;
}



/* the function that handles what could be a knock */
void knockknock(unsigned char * ip, unsigned int port)
{
	unsigned int x;
	
	int r;
	
	IPList * end;
	IPList * temp;
	
	/* every so often we need to have incomplete port knocks (or what we thought
	   might be port knocks) timeout */
	/* flushes our list of IPs that timed out, returns ptr to entry just before
	   where entry for ip is/belongs */

	end = flushAndFindLoc(ip); 

	if (end->next == NULL)
	{
		end->next = (IPList*) malloc(sizeof(IPList));
		if (end->next == NULL)
		{
			fprintf(stderr, "ERRROR!  malloc failed.\n");
			exit(1);
		}
		end->next->next = NULL;
	
		(end->next)->firstPacketArrived = time(NULL);

		(end->next)->authenticated = 0;

		for (x=0; x < 16; ++x) (end->next)->byteArrived[x] = 0;

		(end->next)->ip[0] = ip[0];
		(end->next)->ip[1] = ip[1];
		(end->next)->ip[2] = ip[2];
		(end->next)->ip[3] = ip[3];
	}


	x = port / 256;  /* determine offset of digit */
	
	if (x < 17)
	{
		(end->next)->byteArrived[x]	= 1;
		(end->next)->message[x] = (char) (port % 256);  /* determine digit */
		validate(ip, (end->next)->message);
	}
}


/* libpcap callback, invoked once for each packet that passes FILTER.
 *
 * Extracts the IPv4 source address and the UDP destination port and hands
 * them to knockknock().  Offsets assume Ethernet framing (14-byte header).
 * Truncated packets, non-IPv4 frames (the "udp" filter also admits IPv6),
 * packets with a malformed header length and non-first fragments (which carry
 * no UDP header) are ignored.
 */
void my_callback(u_char *useless,const struct pcap_pkthdr* pkthdr,const u_char*
        packet)
{
	enum {
		ETH_HDR_LEN = 14,     /* dst MAC, src MAC, EtherType */
		IP_MIN_HDR_LEN = 20,  /* IPv4 header without options */
		IP_SRC_OFF = 12,      /* source address offset in the IPv4 header */
		UDP_DPORT_OFF = 2     /* destination port offset in the UDP header */
	};
	const u_char *iph;
	unsigned int ipHdrLen;
	unsigned int udpOff;
	unsigned int port;
	unsigned char ip[4];

	(void) useless;

	if (pkthdr->caplen < ETH_HDR_LEN + IP_MIN_HDR_LEN)
		return;
	/* EtherType 0x0800 = IPv4 */
	if (packet[12] != 0x08 || packet[13] != 0x00)
		return;

	iph = packet + ETH_HDR_LEN;
	if ((iph[0] >> 4) != 4)
		return;
	ipHdrLen = (unsigned int) (iph[0] & 0x0F) * 4;  /* IHL counts 32-bit words */
	if (ipHdrLen < IP_MIN_HDR_LEN)
		return;
	/* Non-zero fragment offset: no UDP header in this packet. */
	if ((iph[6] & 0x1F) != 0 || iph[7] != 0)
		return;

	udpOff = ETH_HDR_LEN + ipHdrLen;
	if (pkthdr->caplen < udpOff + UDP_DPORT_OFF + 2)
		return;

	/* UDP ports are big-endian on the wire. */
	port = ((unsigned int) packet[udpOff + UDP_DPORT_OFF] << 8)
	     | (unsigned int) packet[udpOff + UDP_DPORT_OFF + 1];

	/* The knock is attributed to the packet's source address. */
	memcpy(ip, iph + IP_SRC_OFF, sizeof ip);

	knockknock(ip, port);
}

int main(int argc,char **argv)
{ 
	IPs.next = NULL;

    int i;
    char *dev; 
    char errbuf[PCAP_ERRBUF_SIZE];
    pcap_t* descr;
    const u_char *packet;
    struct pcap_pkthdr hdr;     /* pcap.h                    */
    struct ether_header *eptr;  /* net/ethernet.h            */
    struct bpf_program fp;      /* hold compiled program     */
    bpf_u_int32 maskp;          /* subnet mask               */
    bpf_u_int32 netp;           /* ip                        */


	key = (unsigned char *) malloc(KEY_LEN * sizeof(char));
    /* open server's public key file and obtain key */
    int handle = open("server.key", O_RDONLY);
    if (handle == -1) {
        fprintf(stderr, "Cannot find server key.");
        return 1;
    }

     read(handle, key, KEY_LEN);
     close(handle);


    /* grab a device to peak into... */
            /*dev = pcap_lookupdev(errbuf);*/
    dev = DEVICE;
    if(dev == NULL)
    { fprintf(stderr,"%s\n",errbuf); exit(1); }

    /* ask pcap for the network address and mask of the device */
    pcap_lookupnet(dev,&netp,&maskp,errbuf);

    /* open device for reading */
    descr = pcap_open_live(dev,BUFSIZ,PROMISCUOUS,-1,errbuf);
    if(descr == NULL)
    { printf("pcap_open_live(): %s\n",errbuf); exit(1); }

    /* Lets try and compile the program.. non-optimized */
    if(pcap_compile(descr,&fp,FILTER,0,netp) == -1)
    { fprintf(stderr,"Error calling pcap_compile\n"); exit(1); }

    /* set the compiled program as the filter */
    if(pcap_setfilter(descr,&fp) == -1)
    { fprintf(stderr,"Error setting filter\n"); exit(1); }

    /* loop, indefinitely */ 
    pcap_loop(descr,-1,my_callback,NULL);

    return 0;
}
