// "server" means the real IRC server
// "client" means bouncer clients

// uses socket demo code from https://beej.us/guide/bgnet/html/single/bgnet.html
// and getstdin() uses getLine() from https://stackoverflow.com/questions/4023895/

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <netdb.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <sys/select.h>

// Integer tunables. Kept as enum constants (plain int, like the literals they
// replace) so they are typed and visible in a debugger.
enum {
  MAXDATASIZE = 513, // max number of bytes we can get at once (RFC2812 says 512, plus one for null terminator)
  STDIN       = 0,   // stdin is fd 0
  BACKLOG     = 10,  // maximum length to which the queue of pending connections for sockfd may grow
  MAXCLIENTS  = 32   // maximum number of clients that can connect to the bouncer at a time
};

// A string constant cannot be an enumerator, so the listen port stays a macro.
#define BOUNCERLISTENPORT "1234" // TODO: change this to a config option!

// getstdin() return codes
enum {
  OK       = 0, // the line fit in the buffer (any trailing newline removed)
  NO_INPUT = 1, // fgets() returned NULL (EOF or read error)
  TOO_LONG = 2  // the line did not fit; the excess was read and discarded
};

// Non-zero enables debug() output; main() sets it when given the -d flag.
int debugmode = 0;

// Print "DEBUG: <string>" followed by a newline on stdout, but only when
// debugmode is non-zero. The newline is added here, so callers should not
// include one. string is only read, hence const (literals need no cast).
void debug(const char *string) {
  if (!debugmode) {
    return;
  }
  printf("DEBUG: %s\n", string);
}

// Get stdin line with buffer overrun protection.
// Reads one line from stdin into buff, storing at most sz - 1 characters.
// prompt: optional text printed (and flushed) before reading; NULL for none.
// Returns:
//   NO_INPUT - fgets() returned NULL (EOF or read error);
//   TOO_LONG - the line did not fit: buff holds its first sz - 1 characters
//              and the rest of the line was read and discarded;
//   OK       - otherwise, with any trailing newline removed from buff.
static int getstdin(char *prompt, char *buff, size_t sz) {
  int ch, extra;
  size_t len;

  // Print optional prompt
  if (prompt != NULL) {
    printf("%s", prompt);
    fflush(stdout);
  }

  // Get the input from stdin
  if (fgets(buff, sz, stdin) == NULL) {
    return NO_INPUT;
  }

  len = strlen(buff); // length of what was actually read, not the buffer size

  // fgets() can hand back an empty string (sz == 1, or a line starting with a
  // NUL byte). There is no last character to inspect, and we cannot tell
  // whether the newline was already consumed, so return it as-is instead of
  // indexing buff[-1] or swallowing the following line.
  if (len == 0) {
    return OK;
  }

  // Normal case: remove newline and give string back to caller.
  if (buff[len - 1] == '\n') {
    buff[len - 1] = '\0';
    return OK;
  }

  // If it was too long, there'll be no newline. In that case, we flush
  // to end of line so that excess doesn't affect the next call.
  extra = 0;
  while (((ch = getchar()) != '\n') && (ch != EOF)) {
    extra = 1;
  }
  return (extra == 1) ? TOO_LONG : OK;
}

// Append CR-LF to the end of a string, after first stripping any existing
// trailing CR and/or LF characters (so the result never has a doubled line
// ending). An empty string becomes "\r\n".
// The buffer must have room for the stripped string plus 3 bytes (CR, LF and
// the NUL terminator); this function cannot check that, so the caller must.
void appendcrlf(char *string) {
  size_t len = strlen(string);

  // Strip trailing CR *and* LF. The len > 0 guard prevents reading
  // string[-1] when the string is (or becomes) empty.
  while (len > 0 && (string[len - 1] == '\r' || string[len - 1] == '\n')) {
    len--;
  }

  string[len] = '\r';
  string[len + 1] = '\n';
  string[len + 2] = '\0';
}

// get sockaddr, IPv4 or IPv6:
// Returns a pointer to the address bytes stored inside *sa: the sin_addr of a
// struct sockaddr_in when sa_family is AF_INET, otherwise the sin6_addr of a
// struct sockaddr_in6. The result is what inet_ntop() expects as its source.
void *get_in_addr(struct sockaddr *sa) {
  if (sa->sa_family != AF_INET) {
    struct sockaddr_in6 *v6 = (struct sockaddr_in6 *)sa;
    return &v6->sin6_addr;
  }

  struct sockaddr_in *v4 = (struct sockaddr_in *)sa;
  return &v4->sin_addr;
}

// Create socket to connect to real IRC server.
// host, port: resolved with getaddrinfo() (any address family, TCP); each
//   returned address is tried in order until one connects.
// Returns the connected socket fd. On failure the process exits (status 1 if
// the lookup fails, 2 if no address could be connected), the same way
// createclientsocket() handles its failures. Error codes are deliberately not
// returned: 1 and 2 are also the stdout/stderr fds, so a caller could not tell
// them apart from a real socket.
int createserversocket(char *host, char *port) {
  int sockfd = -1;
  struct addrinfo hints, *servinfo, *p;
  int rv; // return value for getaddrinfo (for error message)
  char s[INET6_ADDRSTRLEN];

  memset(&hints, 0, sizeof hints);
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;

  if ((rv = getaddrinfo(host, port, &hints, &servinfo)) != 0) {
    fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(rv));
    exit(1);
  }

  // loop through all the results and connect to the first we can
  for (p = servinfo; p != NULL; p = p->ai_next) {
    if ((sockfd = socket(p->ai_family, p->ai_socktype, p->ai_protocol)) == -1) {
      perror("bouncer-server: socket");
      continue;
    }

    if (connect(sockfd, p->ai_addr, p->ai_addrlen) == -1) {
      // report before close(), which may overwrite errno
      perror("bouncer-server: connect");
      close(sockfd);
      continue;
    }

    break;
  }

  if (p == NULL) {
    fprintf(stderr, "bouncer-server: failed to connect\n");
    freeaddrinfo(servinfo);
    exit(2);
  }

  // inet_ntop() fails for families other than AF_INET/AF_INET6; never print an
  // uninitialised buffer in that case.
  if (inet_ntop(p->ai_family, get_in_addr(p->ai_addr), s, sizeof s) == NULL) {
    snprintf(s, sizeof s, "(unknown address)");
  }
  printf("bouncer-server: connecting to %s\n", s);

  freeaddrinfo(servinfo); // all done with this structure (p points into it, so free after last use)

  return sockfd;
}

// Create listening socket to listen for bouncer client connections.
// listenport: port/service to listen on, on all local addresses (AI_PASSIVE);
//   NULL falls back to BOUNCERLISTENPORT.
// IPv6 addresses are tried first, then IPv4. Within each family every address
// getaddrinfo() returned is tried until one can be created and bound, so a
// failed IPv6 bind still falls back to IPv4.
// Returns the listening fd, already in listen() state with a BACKLOG-long queue.
// Exits the process on failure: 1 if getaddrinfo() or listen() fails, 2 if no
// address could be bound.
int createclientsocket(char *listenport) {
  static const int families[] = { AF_INET6, AF_INET }; // preference order
  struct addrinfo hints, *ai, *p;
  int listener = -1; // listening socket descriptor (-1 until one is bound)
  int rv;            // return value for getaddrinfo (for error message)
  int yes = 1;       // for enabling socket options with setsockopt

  if (listenport == NULL) {
    listenport = BOUNCERLISTENPORT;
  }

  // get us a socket and bind it
  memset(&hints, 0, sizeof hints);
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_PASSIVE;

  if ((rv = getaddrinfo(NULL, listenport, &hints, &ai)) != 0) {
    fprintf(stderr, "bouncer-client: %s\n", gai_strerror(rv));
    exit(1);
  }

  for (size_t f = 0; f < sizeof families / sizeof families[0] && listener == -1; f++) {
    for (p = ai; p != NULL; p = p->ai_next) {
      if (p->ai_family != families[f]) {
        continue;
      }

      listener = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
      if (listener == -1) {
        perror("bouncer-client: socket");
        continue;
      }

      // allow address re-use so a restarted bouncer can re-bind straight away
      if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes) == -1) {
        perror("bouncer-client: setsockopt");
      }

      if (bind(listener, p->ai_addr, p->ai_addrlen) == -1) {
        // this address failed; report it and try the next candidate
        perror("bouncer-client: bind");
        close(listener);
        listener = -1;
        continue;
      }

      if (p->ai_family == AF_INET6) {
        printf("success, got IPv6!  ai_family: %d\n", p->ai_family);
      } else {
        printf("moderate success, got IPv4!  ai_family: %d\n", p->ai_family);
      }
      break;
    }
  }

  freeaddrinfo(ai); // all done with this

  if (listener == -1) {
    fprintf(stderr, "bouncer-client: failed to bind\n");
    exit(2);
  }

  // listen
  if (listen(listener, BACKLOG) == -1) {
    perror("listen");
    close(listener);
    exit(1);
  }

  return listener;
}

// Write all len bytes of buf to fd. send() may accept fewer bytes than asked
// (or be interrupted by a signal), so keep going until everything is written.
// Returns 0 on success, -1 on error (errno set by send()).
static int sendallbytes(int fd, const char *buf, size_t len) {
#ifdef MSG_NOSIGNAL
  // A client that has already disconnected would otherwise raise SIGPIPE,
  // whose default action terminates the whole bouncer.
  const int flags = MSG_NOSIGNAL;
#else
  const int flags = 0;
#endif

  while (len > 0) {
    ssize_t sent = send(fd, buf, len, flags);
    if (sent == -1) {
      if (errno == EINTR) {
        continue;
      }
      return -1;
    }
    buf += sent;
    len -= (size_t)sent;
  }
  return 0;
}

// Relay/send message to all clients (optionally except one).
// clientsockfd, fdmax: not needed any more (connected clients are read straight
//   from arr_clients) but kept so existing callers compile unchanged.
// arr_clients: MAXCLIENTS slots; 0 = free, > 0 = fd of a connected client.
// str, str_len: the message and its length in bytes; str must also be
//   NUL-terminated because it is logged with %s.
// except: fd NOT to send to (e.g. the client the message came from);
//   0 sends to every client (fd 0 is stdin, never a client).
// Returns 0. Per-client send failures are reported with perror() and skipped.
int sendtoallclients(int *clientsockfd, int fdmax, int arr_clients[], char *str, int str_len, int except) {
  const char *sendertype = except ? "bouncer-client" : "bouncer-server";

  (void)clientsockfd;
  (void)fdmax;

  if (str_len <= 0) {
    return 0; // nothing to send
  }

  for (int j = 0; j < MAXCLIENTS; j++) {
    int fd = arr_clients[j];

    // Skip free slots, and the sender itself (no need to send back to itself)
    if (fd <= 0 || fd == except) {
      continue;
    }

    printf("%s: sending %s to client with fd %d.\n", sendertype, str, fd);
    if (sendallbytes(fd, str, (size_t)str_len) == -1) {
      perror("send");
    }
  }

  return 0;
}

// Accept one pending connection on the bouncer's listening socket (listenfd)
// and record it in the first free arr_clients slot, updating *num_clients and
// *fdmax. A connection that cannot be tracked (no free slot, or an fd too large
// for select()) is closed straight away instead of taking the bouncer down.
static void acceptnewclient(int listenfd, int arr_clients[], int *num_clients, int *fdmax) {
  struct sockaddr_storage remoteaddr;    // client address
  socklen_t addrlen = sizeof remoteaddr; // client remote address size
  char remoteIP[INET6_ADDRSTRLEN];       // remote IP (assume up to IPv6 size)
  const char *ipstr;
  int slot = -1;
  int newfd;

  newfd = accept(listenfd, (struct sockaddr *)&remoteaddr, &addrlen);
  if (newfd == -1) {
    // something went wrong when accept()ing
    perror("accept");
    return;
  }

  // FD_SET() with an fd >= FD_SETSIZE is undefined, so select() could never watch it.
  if (newfd >= FD_SETSIZE) {
    fprintf(stderr, "bouncer-client: socket %d is too large for select(), rejecting\n", newfd);
    close(newfd);
    return;
  }

  // Find a free element in the clients array
  for (int j = 0; j < MAXCLIENTS; j++) {
    if (arr_clients[j] == 0) {
      slot = j;
      break;
    }
  }
  if (slot == -1) {
    fprintf(stderr, "too many clients! rejecting connection on socket %d\n", newfd);
    close(newfd);
    return;
  }

  arr_clients[slot] = newfd;
  if (newfd > *fdmax) { // keep track of the max for select()
    *fdmax = newfd;
  }
  (*num_clients)++; // Track total number of clients

  ipstr = inet_ntop(remoteaddr.ss_family, get_in_addr((struct sockaddr *)&remoteaddr), remoteIP, sizeof remoteIP);
  if (ipstr == NULL) {
    ipstr = "(unknown address)"; // never pass NULL to %s
  }
  printf("bouncer-client: new connection from %s on socket %d\n", ipstr, newfd);
  printf("bouncer-client: total client connections: %d\n", *num_clients);
}

// Handle readable data on connected client socket fd.
// On hang-up or recv() error: close fd, clear its arr_clients slot and
// decrement *num_clients. Otherwise strip trailing CR/LF, re-terminate the
// data with CR-LF, send it to the real IRC server (serversockfd) and relay it
// to every *other* connected client.
static void relayclientdata(int fd, int serversockfd, int *clientsockfd, int fdmax, int arr_clients[], int *num_clients) {
  char clientbuf[MAXDATASIZE]; // buffer for receiving data on this client socket
  ssize_t clientnumbytes;
  size_t len;

  // Read at most sizeof - 3 bytes: one byte is needed for the '\0' terminator
  // and two for the CR-LF appendcrlf() adds, so neither runs off the buffer.
  clientnumbytes = recv(fd, clientbuf, sizeof clientbuf - 3, 0);
  if (clientnumbytes <= 0) {
    // got error or connection closed by client
    if (clientnumbytes == 0) {
      printf("bouncer-client: socket %d hung up\n", fd);
    } else {
      perror("recv");
    }
    close(fd); // bye!

    // Remove the client from the clients array
    for (int j = 0; j < MAXCLIENTS; j++) {
      if (arr_clients[j] == fd) {
        printf("found and clearing fd %d from arr_clients[%d]\n", fd, j);
        arr_clients[j] = 0;
        (*num_clients)--; // only count clients we were actually tracking
        break;
      }
    }
    printf("bouncer-client: total client connections: %d\n", *num_clients);
    return;
  }

  // null terminate that baby
  clientbuf[clientnumbytes] = '\0';

  // clear up any trailing newlines (len > 0: the data may be nothing but line endings)
  len = strlen(clientbuf);
  while (len > 0 && (clientbuf[len - 1] == '\n' || clientbuf[len - 1] == '\r')) {
    clientbuf[--len] = '\0';
  }
  printf("BOUNCER-CLIENT RECEIVED: '%s'\n", clientbuf);
  if (len == 0) {
    return; // nothing left worth relaying
  }

  printf("bouncer-client: sending it to the server...\n");
  appendcrlf(clientbuf);
  if (send(serversockfd, clientbuf, strlen(clientbuf), 0) == -1) { // 0 means no flags
    printf("send error, exiting...\n");
    perror("send");
  }

  // send the same thing to all *other* clients (all except for this fd)
  sendtoallclients(clientsockfd, fdmax, arr_clients, clientbuf, (int)strlen(clientbuf), fd);
}

// Where the big bouncing loop is. Never returns.
// Uses select() to multiplex stdin, the real IRC server socket (*serversockfd),
// the bouncer's listening socket (*clientsockfd) and every connected client:
//  - data from the server is relayed to all clients;
//  - a line typed on stdin is sent to the server, CR-LF terminated;
//  - new connections on the listener are accepted (up to MAXCLIENTS);
//  - data from a client is sent to the server and to every other client.
// Exits the process if select() fails (other than EINTR) or if the server
// connection errors or closes.
void dochat(int *serversockfd, int *clientsockfd) {
  char serverbuf[MAXDATASIZE];       // buffer for receiving data on server socket
  int servernumbytes;                // Number of bytes received from remote server
  char outgoingmsg[MAXDATASIZE];     // String to send to server
  int outgoingmsgrc;                 // Return code from getstdin() for outgoing message
  int arr_clients[MAXCLIENTS] = {0}; // 0 = free slot, > 0 = fd of a connected client
  int num_clients = 0;               // Current number of clients
  int stdin_open = 1;                // cleared at EOF so a closed stdin can't make select() spin
  int fdmax;                         // highest numbered fd handed to select()
  fd_set rfds;                       // set of read fds to monitor with select()

  // select() needs the highest fd it watches; don't assume which socket was created last.
  fdmax = (*serversockfd > *clientsockfd) ? *serversockfd : *clientsockfd;

  while (1) {
    printf("top of loop, fdmax %d.\n", fdmax);
    FD_ZERO(&rfds); // rebuilt every time round: select() overwrites it

    if (stdin_open) {
      FD_SET(STDIN, &rfds);
    }
    FD_SET(*serversockfd, &rfds); // the real IRC server
    FD_SET(*clientsockfd, &rfds); // the listener, for new bouncer clients

    // Add all connected clients to monitor
    // TODO - handle *serversockfd changing if the server disconnects/reconnects
    for (int i = 0; i < MAXCLIENTS; i++) {
      if (arr_clients[i] > 0) {
        printf("monitoring fd %d.\n", arr_clients[i]);
        FD_SET(arr_clients[i], &rfds);
      }
    }

    printf("select()ing...\n");
    // wait until one of the fds in the set is readable (no writes, no exceptions, no timeout)
    if (select(fdmax + 1, &rfds, NULL, NULL, NULL) < 0) {
      int selecterr = errno;
      if (selecterr == EINTR) {
        continue; // interrupted by a signal: rfds does not describe readiness, so rebuild and retry
      }
      printf("receive error, exiting!?\n");
      errno = selecterr;
      perror("select");
      exit(1);
    }

    // see if there's anything on the server side from the real IRCd
    if (FD_ISSET(*serversockfd, &rfds)) {
      printf("reading server socket!\n");

      servernumbytes = recv(*serversockfd, serverbuf, MAXDATASIZE - 1, 0); // leave room for '\0'
      if (servernumbytes == -1) {
        printf("receive error (-1), exiting...\n");
        perror("recv");
        exit(1);
      } else if (servernumbytes == 0) {
        // orderly shutdown by the server; errno is not set, so no perror() here
        printf("socket closed (or no data received) (0), exiting...\n");
        exit(1);
      }
      serverbuf[servernumbytes] = '\0';

      printf("BOUNCER-SERVER RECEIVED: %s\n", serverbuf);
      printf("bouncer-server: sending it to all clients...\n");

      // Relay/send to all clients ("except" = 0 because this should send to all clients)
      sendtoallclients(clientsockfd, fdmax, arr_clients, serverbuf, servernumbytes, 0);
    }

    // see if there's anything from stdin
    // NOTE(review): getstdin() reads through stdio, which may buffer more than one
    // line; select() only sees the fd, so extra buffered lines wait for more input.
    if (stdin_open && FD_ISSET(STDIN, &rfds)) {
      printf("reading stdin!\n");

      // Reserve 2 bytes so appendcrlf() cannot write past the end of outgoingmsg.
      outgoingmsgrc = getstdin(NULL, outgoingmsg, sizeof(outgoingmsg) - 2);

      if (outgoingmsgrc == NO_INPUT) {
        printf("\nError!  No input.\n");
        stdin_open = 0; // EOF/error: stop watching stdin instead of busy-looping
      } else if (outgoingmsgrc == TOO_LONG) {
        // rejected: the truncated text is shown but not sent
        printf("Error!  Too long.  Would have allowed up to: [%s]\n", outgoingmsg);
      } else if (outgoingmsg[0] != '\0') { // an empty line has nothing to send
        appendcrlf(outgoingmsg);

        if (send(*serversockfd, outgoingmsg, strlen(outgoingmsg), 0) == -1) { // 0 means no flags
          printf("send error, exiting...\n");
          perror("send");
        }
      }
    }

    // new connection on the listening socket?
    if (FD_ISSET(*clientsockfd, &rfds)) {
      printf("fd %d is FD_ISSET and it is a...\n", *clientsockfd);
      printf("...new connection!\n");
      acceptnewclient(*clientsockfd, arr_clients, &num_clients, &fdmax);
    }

    // data (or hang-ups) from existing clients. A client accepted just above is
    // not in rfds (select() ran before it existed), so it is not read this time.
    for (int j = 0; j < MAXCLIENTS; j++) {
      int fd = arr_clients[j];
      if (fd > 0 && FD_ISSET(fd, &rfds)) {
        printf("fd %d is FD_ISSET and it is a...\n", fd);
        printf("...previous connection!\n");
        relayclientdata(fd, *serversockfd, clientsockfd, fdmax, arr_clients, &num_clients);
      }
    }
  }
}

// Program entry point.
// usage: <program> hostname port [-d]
//   hostname, port: the real IRC server to connect to
//   -d:             enable debug() output
// Connects to the server, opens the listening socket for bouncer clients on
// BOUNCERLISTENPORT, then hands both to dochat().
int main(int argc, char *argv[]) {
  if (argc < 3) {
    fprintf(stderr, "usage: %s hostname port [-d]\n", argv[0]);
    exit(1);
  }

  // Optional flags after hostname and port; unknown ones are reported, not silently dropped.
  for (int i = 3; i < argc; i++) {
    if (strcmp(argv[i], "-d") == 0) {
      debugmode = 1;
      debug("debug mode enabled"); // debug() adds its own newline
    } else {
      fprintf(stderr, "%s: ignoring unknown option '%s'\n", argv[0], argv[i]);
    }
  }

  // TODO: see if any of this can be shared (i.e. 1. avoid code duplication, and 2. see if variables can be shared between client/server sockets)

  // TODO: track fdmax - kind of doing this now with arr_clients and num_clients but might be pointlessly tracking both in some places (?)

  // I will try to keep to the notation of "server" meaning the real IRCd, "bouncer" meaning the bouncer, and "client" meaning the real IRC client

  // BOUNCER-TO-SERVER socket things

  // Create server socket
  int serversockfd = createserversocket(argv[1], argv[2]);
  printf("serversockfd: %d.\n", serversockfd);

  // Create client socket (after server so we can use its fd number later as fdmax)
  int clientsockfd = createclientsocket(BOUNCERLISTENPORT);
  printf("clientsockfd: %d.\n", clientsockfd);

  dochat(&serversockfd, &clientsockfd);

  printf("dochat() complete, closing socket...\n");

  close(clientsockfd);
  close(serversockfd);

  return 0;
}