/*
 * rum - version 0.9
 * LICENSE: beerware
 * author: stano@websupport.sk (http://websupport.sk/~stanojr/)
 * single-process TCP redirector with unix socket support, multiple listen ports and nice statistics! :)
 * Linux only (2.6 or later), because it uses the epoll syscalls
 *
 * usage:
 * ./rum -s tcp:host:port [-s tcp:host:port [-s sock:path]] -d tcp:host:port [-b] [-m tcp:host:port] [-l logfile]
 *         -s - where to listen: tcp:host:port or sock:path (host must be an IP address of a local interface, or 0.0.0.0 for all interfaces)
 *         -d - destination: tcp:host:port or sock:path
 *         -b - run in the background (daemonize)
 *         -m - where to serve statistics (tcp:host:port)
 *         -l - together with -b, send stdout/stderr to this log file
 *
 *
 * #./rum -s tcp:0.0.0.0:3306 -s sock:/tmp/mysql.sock -d tcp:1.2.3.4:3306 -m tcp:127.0.0.1:666 -b
 * # telnet localhost 666          
 * Trying 127.0.0.1...
 * Connected to localhost.
 * Escape character is '^]'.
 * [              source] [                  bytes] [         destination] [         connections]
 * [    tcp:0.0.0.0:3306] [-->                   0] [    tcp:1.2.3.4:3306] [                   0]
 *                        [<--                   0]                                              
 * 
 * [sock:/tmp/mysql.sock] [-->                3424] [    tcp:1.2.3.4:3306] [                   2]
 *                        [<--              814830]                                              
 *
 * ---
 * greetz 42448 :)
 */

/* Size of the relay buffer carried by every struct ed. */
enum { BUFLEN = 16384 };

/* Role of an epoll-registered descriptor (struct ed .type / listeners .type). */
enum {
	SOCK_SERVER   = 0,	/* proxy listener created by -s */
	SOCK_STATS    = 1,	/* statistics listener created by -m */
	SOCK_IN       = 2,	/* accepted client connection */
	SOCK_OUT      = 3,	/* established connection to the destination */
	SOCK_FLUSHEND = 4,	/* not referenced in the visible code */
	SOCK_CONNECT  = 5	/* outgoing socket whose connect() is still pending */
};

/* Maximum events fetched per epoll_wait(); also passed to epoll_create(). */
enum { RUM_EPOLL_EVENTS = 16 };

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <sys/un.h>
#include <sys/epoll.h>
#include <getopt.h>
#include <netdb.h>
#include <errno.h>
#include <signal.h>
#include <sys/ioctl.h>
#include <netinet/tcp.h>
#include <linux/sockios.h>

/*
 * One entry per listening socket given on the command line (-s or -m),
 * together with the counters printed on the statistics page.
 */
typedef struct {
	int fd; // listening socket returned by makeserver()
	int type; // SOCK_SERVER for -s listeners, SOCK_STATS for -m listeners
	char *s; // listener spec string (e.g. "tcp:host:port"), shown in the "source" column of the stats page

	unsigned int nr_allconn,nr_conn; // connections accepted in total / connections currently being proxied
	unsigned int input_bytes,output_bytes; // input_bytes: client->redir->target; output_bytes: target->redir->client
	// NOTE(review): these are unsigned int counters, so they wrap around on long-running busy instances
} listeners;

/*
 * All listening sockets, in command-line order.  nl is the number of
 * used entries in l.  The array has a fixed capacity of 16, so code that
 * appends an entry must check nl against that capacity first.
 */
struct core {
	int nl; // number of valid entries in l
	listeners l[16];
} core;

/*
 * The single destination every accepted client is forwarded to.  It is
 * filled in by prepareclient() from the -d option.  Being a global, it
 * starts zeroed: t == 0 and s == NULL until a destination is parsed.
 */
struct dst {
	int t;			// 't' = TCP (sin is valid), 's' = unix socket (sun is valid)
	struct sockaddr_in sin;
	struct sockaddr_un sun;
	socklen_t addrlen;	// size of the valid address; passed straight to connect()
	char *s;		// copy of the -d spec string, shown on the stats page
} dst;

/*
 * Per-descriptor state handed to epoll through epoll_event.data.ptr.
 * A proxied connection is a pair of these linked through ed2: the
 * accepted client side and the outgoing side towards the destination.
 * Listener sockets get one with ed2 == NULL.
 */
struct ed {
	int fd;
	int op;		// epoll event mask currently registered for fd (re-applied with EPOLL_CTL_MOD)
	int type;	// one of the SOCK_* values
	int z;		// for connection halves: index into core.l of the listener that accepted it
	void *ed2;	// the other half of the pair (a struct ed *), NULL for listeners
	char *bufptr;	// first byte in buf that has not been written to the peer yet
	int len;	// number of bytes read from fd still waiting to be written; <= 0 means nothing queued
	char buf[BUFLEN];	// data read from fd, to be written to ed2->fd
};

/*
 * Print the command-line help to stdout and terminate with exit status 255.
 *
 * _exit() bypasses stdio teardown, so stdout is flushed explicitly first.
 * Without that, the text is lost whenever stdout is a pipe or a file
 * (fully buffered) rather than a terminal.
 */
void usage(void) {
	fputs("./rum -s tcp:host:port [-s tcp:host:port [-s sock:path]] -d tcp:host:port [-b] [-m tcp:host:port] [-l logfile]\n"
	      "\t-s - where to listen: tcp:host:port or sock:path (host must be an IP address of a local interface, or 0.0.0.0 for all interfaces)\n"
	      "\t-d - destination: tcp:host:port or sock:path\n"
	      "\t-b - goto background\n"
	      "\t-m - statistics port (tcp:host:port)\n"
	      "\t-l - with -b, write stdout/stderr to this log file\n",
	      stdout);
	fflush(stdout);
	_exit(-1);
}

/*
 * Create a listening socket from a spec string:
 *
 *   tcp:host:port  - IPv4 listener.  host is resolved with gethostbyname()
 *                    and SO_REUSEADDR is set before bind().
 *   sock:path      - unix stream listener.  An existing file at path is
 *                    unlinked first and the socket file is chmod'ed 0777.
 *
 * Returns the listening fd (backlog 255).  Every failure is fatal: an error
 * is printed and the process exits via _exit(-1), or via usage() for a
 * malformed spec.  The argument string itself is not modified.
 */
int makeserver(char *wwtf) {
	int sock,sockopt;
	char *spec,*wtf;

	/* work on a private copy: the spec is split in place below */
	spec=strdup(wwtf);
	if (spec==NULL) {
		perror("strdup");
		_exit(-1);
	}
	wtf=spec;

	if (strncmp(wtf,"tcp:",4)==0) {
		char *host,*port,*tmp,*end;
		struct hostent *h;
		struct sockaddr_in s;
		long p;

		wtf+=4;
		tmp=strchr(wtf,':');
		if (tmp==NULL) {
			usage();
		}
		host=wtf;
		*tmp='\0';
		port=tmp+1;

		/* atoi() maps junk to 0 and overflows silently; reject both */
		errno=0;
		p=strtol(port,&end,10);
		if (end==port || *end!='\0' || errno!=0 || p<1 || p>65535) {
			fprintf(stderr,"invalid port: %s\n",port);
			_exit(-1);
		}

		if ((h=gethostbyname(host))==NULL) {
			herror("gethostbyname");
			fflush(stdout);
			fflush(stderr);
			_exit(-1);
		}
		/* the copy below assumes an IPv4 address of exactly in_addr size */
		if (h->h_addrtype!=AF_INET || h->h_length!=(int)sizeof(s.sin_addr)) {
			fprintf(stderr,"%s: not an IPv4 address\n",host);
			_exit(-1);
		}

		memset(&s,0,sizeof(s));
		memcpy(&s.sin_addr,h->h_addr_list[0],sizeof(s.sin_addr));
		s.sin_port=htons((unsigned short)p);
		s.sin_family=AF_INET;

		if ((sock=socket(PF_INET,SOCK_STREAM,0))==-1) {
			perror("socket");
			_exit(-1);
		}

		sockopt=1;
		setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,&sockopt,sizeof(sockopt));

		if (bind(sock,(struct sockaddr *)&s,sizeof(s))==-1) {
			perror("bind");
			_exit(-1);
		}

		if (listen(sock,255)==-1) {
			perror("listen");
			_exit(-1);
		}

		printf("listening on tcp:%s:%s\n",host,port);
		fflush(stdout);

		free(spec);
		return sock;

	} else if (strncmp(wtf,"sock:",5)==0) {
		char *sockf;
		struct sockaddr_un s;
		size_t plen;

		sockf=wtf+5;
		plen=strlen(sockf);

		/* sun_path is a fixed-size array: an unchecked copy would overflow
		 * the stack.  Keep room for the terminating NUL. */
		if (plen==0 || plen>=sizeof(s.sun_path)) {
			fprintf(stderr,"invalid socket path: %s\n",sockf);
			_exit(-1);
		}

		memset(&s,0,sizeof(s));
		s.sun_family=AF_UNIX;
		memcpy(s.sun_path,sockf,plen);

		/* a stale socket file from a previous run would make bind() fail */
		if (!access(sockf,F_OK)) {
			if (unlink(sockf)) {
				perror("unlink");
				_exit(-1);
			}
		}

		if ((sock=socket(PF_UNIX,SOCK_STREAM,0))==-1) {
			perror("socket");
			_exit(-1);
		}

		sockopt=1;
		setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,&sockopt,sizeof(sockopt));

		if (bind(sock,(struct sockaddr *)&s,sizeof(s))==-1) {
			perror("bind");
			_exit(-1);
		}

		if (listen(sock,255)==-1) {
			perror("listen");
			_exit(-1);
		}

		chmod(sockf,0777);

		printf("listening on sock:%s\n",sockf);
		fflush(stdout);

		free(spec);
		return sock;

	} else {
		usage();
	}

	_exit(-1);
}

/*
 * Parse the -d destination spec into the global `dst`, which
 * makeclientsock() and makeclientconnect() use for every outgoing
 * connection:
 *
 *   tcp:host:port  - host is resolved once, here, with gethostbyname()
 *   sock:path      - unix stream socket path
 *
 * dst.s keeps a copy of the whole spec for the statistics page.  A
 * malformed or unknown spec is fatal (usage() or _exit(-1)).  This
 * guarantees that dst.t is 't' or 's' whenever this function returns.
 */
void prepareclient(char *wwtf) {
	char *spec,*wtf;

	spec=strdup(wwtf);
	if (spec==NULL) {
		perror("strdup");
		_exit(-1);
	}
	wtf=spec;

	free(dst.s);	/* a repeated -d replaces the earlier destination */
	dst.s=strdup(wtf);
	if (dst.s==NULL) {
		perror("strdup");
		_exit(-1);
	}

	if (strncmp(wtf,"tcp:",4)==0) {
		char *host,*port,*tmp,*end;
		struct hostent *h;
		long p;

		wtf+=4;
		tmp=strchr(wtf,':');
		if (tmp==NULL) {
			usage();
		}
		host=wtf;
		*tmp='\0';
		port=tmp+1;

		/* atoi() maps junk to 0 and overflows silently; reject both */
		errno=0;
		p=strtol(port,&end,10);
		if (end==port || *end!='\0' || errno!=0 || p<1 || p>65535) {
			fprintf(stderr,"invalid port: %s\n",port);
			_exit(-1);
		}

		if ((h=gethostbyname(host))==NULL) {
			herror("gethostbyname");
			_exit(-1);
		}
		if (h->h_addrtype!=AF_INET || h->h_length!=(int)sizeof(dst.sin.sin_addr)) {
			fprintf(stderr,"%s: not an IPv4 address\n",host);
			_exit(-1);
		}

		memset(&dst.sin,0,sizeof(dst.sin));
		memcpy(&dst.sin.sin_addr,h->h_addr_list[0],sizeof(dst.sin.sin_addr));
		dst.sin.sin_port=htons((unsigned short)p);
		dst.sin.sin_family=AF_INET;
		dst.addrlen=sizeof(struct sockaddr_in);
		dst.t='t';
	} else if (strncmp(wtf,"sock:",5)==0) {
		char *sockf=wtf+5;
		size_t plen=strlen(sockf);

		/* sun_path is a fixed-size array; keep room for the NUL */
		if (plen==0 || plen>=sizeof(dst.sun.sun_path)) {
			fprintf(stderr,"invalid socket path: %s\n",sockf);
			_exit(-1);
		}

		memset(&dst.sun,0,sizeof(dst.sun));
		dst.sun.sun_family=AF_UNIX;
		memcpy(dst.sun.sun_path,sockf,plen);
		dst.addrlen=sizeof(struct sockaddr_un);
		dst.t='s';
	} else {
		/* an unknown scheme would leave dst.t unset and break every connect */
		usage();
	}

	free(spec);
}

/*
 * Create an unconnected, non-blocking stream socket whose family matches
 * the configured destination (dst.t: 't' = TCP, 's' = unix socket).
 *
 * Returns the fd, or -1 (after printing an error) if no destination is
 * configured or the socket cannot be created or made non-blocking.
 */
int makeclientsock() {
	int sock,domain,flags;

	if (dst.t=='t') {
		domain=PF_INET;
	} else if (dst.t=='s') {
		domain=PF_UNIX;
	} else {
		fprintf(stderr,"makeclientsock: no destination configured\n");
		return -1;
	}

	if ((sock=socket(domain,SOCK_STREAM,0))==-1) {
		perror("socket");
		return -1;
	}

	/* the event loop must never block in connect()/read()/write() */
	flags=fcntl(sock,F_GETFL,0);
	if (flags==-1 || fcntl(sock,F_SETFL,flags|O_NONBLOCK)==-1) {
		perror("fcntl");
		close(sock);
		return -1;
	}

	return sock;
}

/*
 * Start, or check on, a non-blocking connect() of `sock` to the address in
 * the global `dst`.  It is meant to be called again on the same socket each
 * time epoll reports it writable or hung up.
 *
 * Returns:
 *   0                      - the connection is established
 *   EINPROGRESS / EALREADY - the handshake is still running; call again later
 *   any other errno value  - the connection failed
 */
int makeclientconnect(int sock) {
	const struct sockaddr *sa;

	if (dst.t=='t') {
		sa=(const struct sockaddr *)&dst.sin;
	} else if (dst.t=='s') {
		sa=(const struct sockaddr *)&dst.sun;
	} else {
		return EINVAL;
	}

	if (connect(sock,sa,dst.addrlen)==0) {
		return 0;
	}

	switch (errno) {
	case EISCONN:
		/* an earlier call already completed the handshake */
		return 0;
	case EINTR:
		/* a non-blocking connect keeps going in the background */
		return EINPROGRESS;
	default:
		/* EINPROGRESS/EALREADY while pending, or the connection's own
		 * error once the attempt has failed */
		return errno;
	}
}

int main(int ac, char *av[]) {
	int csock,len,ret,i,z,y,ch,daemonize=0,log=0;
	struct sockaddr csa;
	int epfd;
	struct epoll_event ev,ev2,events[RUM_EPOLL_EVENTS];
	struct ed *sed,*sed2,*sed3,*ed,*ed2;
	char *logfile;

	signal(SIGPIPE,SIG_IGN);

	if (ac==1) {
		usage();
	}

	core.nl=0;

	while ((ch = getopt(ac, av, "bd:s:m:l:")) != -1) {
		switch (ch) {
			case 'b':
				daemonize=1;
			break;
			case 's':
				core.l[core.nl].s=strdup(optarg);
				core.l[core.nl].fd=makeserver(optarg);
				core.l[core.nl].type=SOCK_SERVER;
				core.l[core.nl].nr_conn=0;
				core.l[core.nl].nr_allconn=0;
				core.l[core.nl].input_bytes=0;
				core.l[core.nl].output_bytes=0;
				core.nl++;
			break;
			case 'm':
				core.l[core.nl].fd=makeserver(optarg);
				core.l[core.nl].type=SOCK_STATS;
				core.nl++;
			break;
			case 'd':
				prepareclient(optarg);
			break;
			case 'l':
				log=1;
				logfile=strdup(optarg);
			break;
		}
	}

	if (daemonize) {
		if (log) {
			daemon(0,1);
			close(0);
			close(1);
			close(2);
			ret=open(logfile,O_WRONLY|O_CREAT|O_APPEND,S_IRUSR|S_IWUSR);
			if (ret!=-1) {
				dup2(ret,1);
				dup2(ret,2);
			}
			
		} else {
			daemon(0,0);
		}
	}

	epfd=epoll_create(RUM_EPOLL_EVENTS);

	memset(&ev,0,sizeof(ev));
	for (i=0;i<core.nl;i++) {
		ed=malloc(sizeof(struct ed));
		ev.events=EPOLLIN;
		ev.data.ptr=ed;
		ed->op=EPOLLIN;
		ed->fd=core.l[i].fd;
		ed->ed2=NULL;
		if (core.l[i].type==SOCK_SERVER) {
			ed->type=SOCK_SERVER; // listener
		} else if (core.l[i].type==SOCK_STATS) {
			ed->type=SOCK_STATS; // statistics
		}

		epoll_ctl(epfd,EPOLL_CTL_ADD,core.l[i].fd,&ev);
	}

	while (1) {
		ret=epoll_wait(epfd,(struct epoll_event *)&events,RUM_EPOLL_EVENTS,-1);
		if (ret<0 && errno!=EINTR) {
			_exit(0);
		} else if (ret<0 && errno==EINTR) {
			continue;
		} else if (ret>0) {
			/* check for new connections */
			for (i=0;i<ret;i++) {
				if (events[i].data.ptr==NULL) {
					continue;
				}

				sed=events[i].data.ptr;
				sed2=sed->ed2;

				if ((events[i].events&EPOLLIN)) {
					if (sed->type==SOCK_SERVER) {
						// accept new connection
						for (z=0;z<core.nl;z++) {
							if (core.l[z].fd==sed->fd) {
								len=sizeof(csa);
								csock=accept(core.l[z].fd,&csa,&len);
								fcntl(csock,F_SETFL,O_NONBLOCK);
								memset(&ev,0,sizeof(ev));
								ed=malloc(sizeof(struct ed));
								ev.data.ptr=ed;
								ev.events=EPOLLIN;
								ed->op=ev.events;
								ed->fd=csock;
								ed->type=SOCK_IN; // in conn
								ed->z=z; // what listener
								ed->len=0;

								csock=makeclientsock();
								if (csock==-1) {
									close(ed->fd);
									break;
								}
								//fcntl(csock,F_SETFL,O_NONBLOCK);
								memset(&ev2,0,sizeof(ev2));
								ed2=malloc(sizeof(struct ed));
								ed->ed2=ed2;
								//epoll_ctl(epfd,EPOLL_CTL_ADD,ed->fd,&ev);

								ed2->fd=csock;
								ev2.events=EPOLLOUT;
								ed2->op=ev2.events;
								ed2->type=SOCK_CONNECT; // out conn
								ed2->z=z; // what listener
								ed2->ed2=ed;
								ev2.data.ptr=ed2;
								ed2->len=0;

								epoll_ctl(epfd,EPOLL_CTL_ADD,ed2->fd,&ev2);

								core.l[z].nr_conn++;
								core.l[z].nr_allconn++;

								break;
							}
						}

					} else if (sed->type==SOCK_STATS) {
						// accept, print statistics and close
						for (z=0;z<core.nl;z++) {
							if (core.l[z].fd==sed->fd) {
								FILE *fp;
								len=sizeof(csa);
								csock=accept(core.l[z].fd,&csa,&len);
								fp=fdopen(csock, "w");
								if (fp!=NULL) {
									fprintf(fp,"[%20s] [   %10s] [%20s] [%15s] [%18s]\n","source","bytes","destination","all connections","actual connections");
									for(y=0;y<core.nl;y++) {
										if (core.l[y].type==SOCK_STATS)
											break;
										fprintf(fp,"[%20s] [-->%10u] [%20s] [%15u] [%18u]\n", core.l[y].s, core.l[y].input_bytes, dst.s, core.l[y].nr_allconn,core.l[y].nr_conn);
										fprintf(fp," %20s  [<--%10u]  %20s   %15s   %18s\n\n", "", core.l[y].output_bytes,"","","");
									}
									fflush(fp);
								}
								close(csock);
							}
						}
					} else if (sed->type==SOCK_IN || sed->type==SOCK_OUT) {
						// copy data
						if (sed->len<=0) {
							sed->len=read(sed->fd,sed->buf,BUFLEN);
							sed->bufptr=sed->buf;
						}
						if (sed->len==0) {
							epoll_ctl(epfd,EPOLL_CTL_DEL,sed->fd,NULL);
							epoll_ctl(epfd,EPOLL_CTL_DEL,sed2->fd,NULL);
							for (z=i;z<ret;z++) {
								if (sed2==events[z].data.ptr) {
									events[z].data.ptr=NULL;
									break;
								}
							}
							core.l[sed->z].nr_conn--;
							close(sed->fd);
							close(sed2->fd);
							free(sed2);
							free(sed);
						} else if (sed->len==-1 && errno==EAGAIN) {
						} else if (sed->len==-1 && errno!=EAGAIN) {
							/* napr conn resed by peer*/
							epoll_ctl(epfd,EPOLL_CTL_DEL,sed->fd,NULL);
							epoll_ctl(epfd,EPOLL_CTL_DEL,sed2->fd,NULL);
							for (z=i;z<ret;z++) {
								if (sed2==events[z].data.ptr) {
									events[z].data.ptr=NULL;
									break;
								}
							}
							core.l[sed->z].nr_conn--;
							close(sed->fd);
							close(sed2->fd);
							free(sed2);
							free(sed);
							break;
						} else {
							if (sed->type==SOCK_IN) {
								core.l[sed->z].input_bytes+=len;
							} else if (sed->type==SOCK_OUT) {
								core.l[sed->z].output_bytes+=len;
							}
							len=write(sed2->fd,sed->bufptr,sed->len);
							if (len==-1 && errno!=EAGAIN){
								epoll_ctl(epfd,EPOLL_CTL_DEL,sed->fd,NULL);
								epoll_ctl(epfd,EPOLL_CTL_DEL,sed2->fd,NULL);
								for (z=i;z<ret;z++) {
									if (sed2==events[z].data.ptr) {
										events[z].data.ptr=NULL;
										break;
									}
								}
								core.l[sed->z].nr_conn--;
								close(sed->fd);
								close(sed2->fd);
								free(sed2);
								free(sed);
							} else if (len==-1 && errno==EAGAIN) {
								sed2->op|=EPOLLOUT;
								ev.events=sed2->op;
								ev.data.ptr=sed2;
								epoll_ctl(epfd,EPOLL_CTL_MOD,sed2->fd,&ev);
								sed->op^=EPOLLIN;
								ev.events=sed->op;
								ev.data.ptr=sed;
								epoll_ctl(epfd,EPOLL_CTL_MOD,sed->fd,&ev);
							} else if (len<sed->len) {
								sed->len-=len;
								sed->bufptr=sed->bufptr+len;
								sed2->op|=EPOLLOUT;
								ev.events=sed2->op;
								ev.data.ptr=sed2;
								epoll_ctl(epfd,EPOLL_CTL_MOD,sed2->fd,&ev);
								sed->op^=EPOLLIN;
								ev.events=sed->op;
								ev.data.ptr=sed;
								epoll_ctl(epfd,EPOLL_CTL_MOD,sed->fd,&ev);
							} else {
								sed->len=0;
							}
						}
					} else {
						printf("hmmmmm??? sed->type: %d\n",sed->type);
					}
				} else if ((events[i].events&EPOLLOUT) || (events[i].events&EPOLLHUP)) {
					if (sed->type==SOCK_IN || sed->type==SOCK_OUT) {
						if (sed2->len>0) {
							len=write(sed->fd,sed2->bufptr,sed2->len);
							if (len==-1 && errno!=EAGAIN){
								epoll_ctl(epfd,EPOLL_CTL_DEL,sed->fd,NULL);
								epoll_ctl(epfd,EPOLL_CTL_DEL,sed2->fd,NULL);
								for (z=i;z<ret;z++) {
									if (sed2==events[z].data.ptr) {
										events[z].data.ptr=NULL;
										break;
									}
								}
								core.l[sed->z].nr_conn--;
								close(sed->fd);
								close(sed2->fd);
								free(sed2);
								free(sed);
							} else if (len==-1 && errno==EAGAIN) {
							} else if (len<sed2->len) {
								sed2->len-=len;
								sed2->bufptr=sed2->bufptr+len;
							} else {
								sed2->len=0;

								sed->op^=EPOLLOUT;
								ev.events=sed->op;
								ev.data.ptr=sed;
								epoll_ctl(epfd,EPOLL_CTL_MOD,sed->fd,&ev);

								sed2->op|=EPOLLIN;
								ev.events=sed2->op;
								ev.data.ptr=sed2;
								epoll_ctl(epfd,EPOLL_CTL_MOD,sed2->fd,&ev);
							}

						} else {
							printf("uoaaaa, len is %d\n",sed2->len);
						}
					} else if (sed->type==SOCK_CONNECT) {
						errno=0;
						z=makeclientconnect(sed->fd);
						if (z==EINPROGRESS || z==EALREADY) {
							continue;
						} else if (z==0) {
							fcntl(sed->fd,F_SETFL,O_NONBLOCK);
							ev.events=EPOLLIN;
							ev.data.ptr=sed2;
							sed2->op=ev.events;
							sed2->type=SOCK_IN;
							if (epoll_ctl(epfd,EPOLL_CTL_ADD,sed2->fd,&ev)==-1) {
								epoll_ctl(epfd,EPOLL_CTL_DEL,sed->fd,NULL);
								close(sed->fd);
								close(sed2->fd);
								free(sed2);
								free(sed);
								continue;
							}
							ev.events=EPOLLIN;
							ev.data.ptr=sed;
							sed->op=ev.events;
							sed->type=SOCK_OUT;
							epoll_ctl(epfd,EPOLL_CTL_MOD,sed->fd,&ev);
						} else {
							printf("connect error: %s\n",strerror(z));
							fflush(stdout);
							epoll_ctl(epfd,EPOLL_CTL_DEL,sed->fd,NULL);
							close(sed->fd);
							close(sed2->fd);
							free(sed2);
							free(sed);
						}
					} else {
						printf("errrrrrr fd:%d\n",sed->fd);
					}
				}
			}
		}
	}
}

