#include "cons.h"
#include <stdio.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/time.h>
#include <sgtty.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <fcntl.h>
#include <pwd.h>

/*
 * Error-reporting helpers.
 *
 *   cough(msg) - if stdin was switched to raw mode (screwy is set), put the
 *                saved sg_flags (oldflags) back, then report errno via
 *                perror(msg) and carry on.
 *   puke(msg)  - same as cough(msg), then exit(-1).
 *
 * errno is saved across the terminal ioctl()s so that perror() reports the
 * failure that triggered the call, not whatever the ioctl()s left behind.
 * Both expand to a single statement (do { } while (0)), so they are safe as
 * the body of an unbraced if/else.
 */
#define cough(msg) do { \
                     int cough_errno_ = errno; \
                     if (screwy) { \
                       ioctl(0,TIOCGETP,&sty); \
                       sty.sg_flags = oldflags; \
                       ioctl(0,TIOCSETP,&sty); \
                     } \
                     errno = cough_errno_; \
                     perror(msg); \
                   } while (0)
#define puke(msg) do { cough(msg); exit(-1); } while (0)

/* Connection state shared by main() and who(). */
struct sockaddr_in conserv_port;  /* address/port of the console server */
struct hostent *hp;               /* result of the last gethostbyname() */
struct protoent *proto;           /* not referenced in the code shown here */
int s;                            /* current socket descriptor */
int flag = 0;                     /* copy-loop termination flag */
int screwy = 0;                   /* nonzero once stdin has been made raw */
int rmask, wmask;                 /* select() bit masks (int-mask style) */
int nf;                           /* select()/write() results, unused */
int prtnum = 0;                   /* port number parsed from a reply */

/* Terminal state. */
short oldflags;                   /* stdin's sg_flags before going raw */
struct sgttyb sty;                /* scratch buffer for TIOCGETP/TIOCSETP */

/* Buffers and strings. */
char ch[BUFSIZ];                  /* general-purpose I/O buffer */
char *command;                    /* "attach", "force", "spy", "who", ... */
char *ptr;                        /* cursor while reading a reply line */
char *usrnm;                      /* login name of the invoking user */
char *server;                     /* console (system) name being requested */
char lclhost[64];                 /* name of this host */
char pass[32];                    /* password typed by the user */
char ports[128];                  /* port-list reply from the server */
char *portptr;                    /* cursor into ports */
char *index();                    /* BSD character search in a string */
char *hostlist[] = HOSTS;         /* console servers to try (from cons.h) */

/* who() is defined later in this file and called from main(). */
int who();

/* Put stdin's saved sg_flags back if it was switched to raw mode. */
static void restore_tty(void)
{
  if (screwy) {
    ioctl(0, TIOCGETP, &sty);
    sty.sg_flags = oldflags;
    ioctl(0, TIOCSETP, &sty);
  }
}

/* Switch stdin to raw, no-echo mode, remembering the old flags. */
static void set_raw_tty(void)
{
  ioctl(0, TIOCGETP, &sty);
  oldflags = sty.sg_flags;
  screwy = 1;
  sty.sg_flags = (sty.sg_flags | RAW) & ~ECHO;
  ioctl(0, TIOCSETP, &sty);
}

/*
 * Read one '\n'-terminated line from fd into buf, one byte at a time so that
 * nothing after the newline is consumed.  The newline is kept and buf is
 * always NUL-terminated.  Returns 0 on success, or -1 if the peer closed the
 * connection, read() failed, or the line does not fit in size - 1 bytes.
 */
static int read_reply_line(int fd, char *buf, size_t size)
{
  size_t len = 0;

  if (size == 0)
    return -1;
  while (len + 1 < size) {
    if (read(fd, &buf[len], 1) != 1) {
      buf[len] = '\0';
      return -1;
    }
    if (buf[len++] == '\n') {
      buf[len] = '\0';
      return 0;
    }
  }
  buf[len] = '\0';
  return -1;
}

/* Restore the terminal, print msg verbatim to stderr, and exit(-1). */
static void bail(const char *msg)
{
  restore_tty();
  fprintf(stderr, "%s", msg);
  exit(-1);
}

/*
 * Console client entry point.
 *
 * usage: prog [attach|force|spy|who] system-name [num1 num2]
 *
 * Asks each console server in hostlist for the port serving system-name
 * (a "name\n" line answered by a port-list line).  For "who", the listing is
 * delegated to who().  Otherwise it connects to the first nonzero port and
 * sends "server:command:user:host:password\n".  After an "ok\r\n" reply it
 * copies bytes between the raw-mode terminal and the socket until either
 * side reaches EOF.  num1/num2, if given, are sent as the new escape
 * character values.
 */
int main(int argc, char *argv[])
{
  char opt, ich1 = 0, ich2 = 0;
  int nr, hn, i, len;
  char tmp_msg[132];
  struct passwd *pw;
  char *entered;
  fd_set rfds;

  /* start poor excuse for command line parsing */
  if (argc < 2 || argc > 5) {
    fprintf(stderr,"usage: %s [attach|force|spy|who] system-name [num1 num2]\n",argv[0]);
    exit(-1);
  }
  if (argc > 3) {   /* there must be new escape characters supplied */
    ich1=(char)atol(argv[argc-2]);
    ich2=(char)atol(argv[argc-1]);
    server=argv[argc-3];
  }
  else server=argv[argc-1];

  if ((pw = getpwuid(getuid())) == NULL) {
    fprintf(stderr, "%s: can't determine your user name\n", argv[0]);
    exit(-1);
  }
  usrnm = pw->pw_name;
  /* gethostname() need not NUL-terminate on truncation */
  gethostname(lclhost, sizeof lclhost - 1);
  lclhost[sizeof lclhost - 1] = '\0';
  command="attach";

  if (argc == 3 || argc == 5)
  {
    if ((opt=argv[1][0]) == '-') opt=argv[1][1];
    switch(opt)
    {
      case 'a' : /* attach */
          command="attach";
          break;
      case 'A' : /* attach with log replay */
          command="Attach";
          break;
      case 'f' : /* force attach */
          command="force";
          break;
      case 'F' : /* force attach with log replay */
          command="Force";
          break;
      case 's' : /* spy */
          command="spy";
          break;
      case 'S' : /* spy with log replay */
          command="Spy";
          break;
      case 'w' : /* who */
      case 'W' :
          command="who";
          break;
      default  : /* huh? */
          fprintf(stderr,"usage: %s [attach|force|spy|who] system-name\n",argv[0]);
          exit(-1);
          break;
    }
  }

  if (strcmp(server,"who") == 0) command="who";

  /* the system name is sent as a single "name\n" line held in ch */
  if (strlen(server) + 2 > sizeof ch) {
    fprintf(stderr, "%s: system name too long\n", argv[0]);
    exit(-1);
  }

  /* for each console server */
  for (hn=0; hostlist[hn] != NULL; ++hn) {

    /* set variables for socket connection */
    bzero(&conserv_port, sizeof(conserv_port));
    conserv_port.sin_family = AF_INET;
    hp = gethostbyname(hostlist[hn]);
    if (hp == NULL) {
      snprintf(tmp_msg, sizeof tmp_msg, "can't get hostname for %s", hostlist[hn]);
      cough(tmp_msg);
      continue;
    }
    bcopy(hp->h_addr, &conserv_port.sin_addr, hp->h_length);
    conserv_port.sin_port = htons(PORT);

    /* set up the socket */
    if ((s = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
      snprintf(tmp_msg, sizeof tmp_msg, "can't create socket to %s", hostlist[hn]);
      cough(tmp_msg);
      continue;
    }
    if (connect(s, (struct sockaddr *)&conserv_port, sizeof(conserv_port)) != 0) {
      snprintf(tmp_msg, sizeof tmp_msg, "can't connect to %s", hostlist[hn]);
      cough(tmp_msg);
      close(s);   /* don't leak a descriptor per unreachable server */
      continue;
    }

    /* send server name then wait for port number(s) */
    len = snprintf(ch, sizeof ch, "%s\n", server);
    if (write(s, ch, len) != len) {
      cough("write");
      close(s);
      continue;
    }
    if (read_reply_line(s, ports, sizeof ports) < 0) {
      cough("read");
      close(s);
      continue;
    }
    close(s);

    if (strcmp(command,"who") == 0) {
      who();
      if (strcmp(server,"who") != 0) break;  /* for single server who */
    }
    else {
      portptr=ports;
      if (sscanf(portptr, "%d", &prtnum) != 1)
        prtnum = 0;
      if (prtnum != 0) break;
    }
  }

  if (strcmp(command,"who") == 0) exit(0);
  else if (prtnum == 0) {
    /* no server knew the system: ports holds the last reply/error text */
    fprintf(stderr,"%s",ports);
    exit(-1);
  }

  conserv_port.sin_port = htons(prtnum);

  entered = getpass("Enter password:");
  if (entered == NULL) puke("getpass");
  /* pass[] holds 31 characters; truncate rather than overflow */
  snprintf(pass, sizeof pass, "%s", entered);

  /* change stdin to raw with no echo */
  set_raw_tty();

  /* set up the socket */
  if ((s = socket(AF_INET, SOCK_STREAM, 0)) < 0) puke ("socket");
  if (connect(s, (struct sockaddr *)&conserv_port, sizeof(conserv_port)) != 0)
    puke("connect b");

  /* send sign-on stuff, then wait for "ok^M\n" before allowing a write */
  len = snprintf(ch, sizeof ch, "%s:%s:%s:%s:%s\n", server, command, usrnm, lclhost, pass);
  if (len < 0 || (size_t)len >= sizeof ch) {
    snprintf(tmp_msg, sizeof tmp_msg, "%s: sign-on line too long\n", argv[0]);
    bail(tmp_msg);
  }
  if (write(s, ch, len) != len) puke ("write");

  if (read_reply_line(s, ch, sizeof ch) < 0) puke ("read");
  if (strcmp(ch, "ok\r\n") != 0)
    bail(ch);   /* the reply is the server's refusal message */

  if (ich1 == 0 || ich2 == 0) {  /* if escape sequence is unchanged */
    printf("Escape sequence is ctrl-underscore, escape.\r\n");
    printf("Enter  ctrl-_ esc ?   for help.\r\n");
    printf("%s",ch);
  }
  else {  /* change escape sequence */
    printf("Escape sequence is redefined, ascii values %d, %d.\r\n",ich1,ich2);
    printf("%s",ch);
    /* current escape (ctrl-_ esc), the 'k' command, then the two new
       characters: 5 bytes.  NOTE(review): protocol inferred from the
       original 5-byte write and the default escape announced above. */
    len = snprintf(ch, sizeof ch, "\037\033k%c%c", ich1, ich2);
    write(s, ch, len);  /* blindly assume it worked - we'll find out soon nuff */
  }
  /* stdout may be fully buffered; flush before raw write(1, ...) below */
  fflush(stdout);

  /* copy socket -> stdout (parity stripped) and stdin -> socket until
     either side reaches EOF or fails */
  for (;;) {
    FD_ZERO(&rfds);
    FD_SET(0, &rfds);
    FD_SET(s, &rfds);
    if (select(s + 1, &rfds, NULL, NULL, NULL) < 0) {
      if (errno == EINTR) continue;
      break;
    }
    if (FD_ISSET(s, &rfds)) {  /* anything from socket? */
      if ((nr = read(s, ch, sizeof ch)) <= 0)
        break;
      /* strip input - (conserver only strips output) */
      for (i = 0; i < nr; ++i) ch[i] &= 127;
      write(1, ch, nr);
    }
    if (FD_ISSET(0, &rfds)) {  /* anything from stdin? */
      if ((nr = read(0, ch, sizeof ch)) <= 0)
        break;
      write(s, ch, nr);
    }
  }
  close(s);

  /* change stdin back to original state */
  restore_tty();
  fprintf(stderr,"\nConnection closed.\n");
  return 0;
}

/* who is implemented separately from main() because it has to connect to
   several ports, one per entry in the server's port list.  Otherwise it
   speaks the same sign-on protocol as main(). */

/*
 * List the users of every port in the global `ports` reply.  `ports` holds
 * decimal port numbers separated by ':' and terminated by '\n'.  For each
 * port, send the sign-on line, check for "ok\r\n", and copy the server's
 * output to stdout (parity stripped) until it closes the connection.
 *
 * Reads globals conserv_port (address filled in by the caller), ports,
 * server, command, usrnm, lclhost and pass.  Updates portptr, prtnum, s, ch,
 * sty, oldflags and screwy.  Exits with -1 if the reply is not a port list
 * (printing it as the server's message), on a socket error, or on a refused
 * sign-on.  Otherwise restores the terminal and returns 0, so the caller's
 * loop can go on to the next console server.
 */
int who()
{
  int nr, i, len;
  char *colon;

  /* change stdin to raw with no echo */
  ioctl(0,TIOCGETP,&sty);
  oldflags = sty.sg_flags;
  screwy=1;
  sty.sg_flags = (sty.sg_flags | RAW);
  sty.sg_flags = (sty.sg_flags & ~ECHO);
  ioctl(0,TIOCSETP,&sty);

  portptr=ports;
  /* stop at the '\n' that ends the list of ports */
  while (*portptr != '\n' && *portptr != '\0') {
    if (sscanf(portptr, "%d", &prtnum) != 1)
      prtnum = 0;   /* don't reuse the previous port on a parse failure */
    if (prtnum == 0) {
      fprintf(stderr,"%s",ports);
      ioctl(0,TIOCGETP,&sty);
      sty.sg_flags = oldflags;
      ioctl(0,TIOCSETP,&sty);
      exit(-1);
    }
    conserv_port.sin_port = htons(prtnum);

    /* set up the socket */
    if ((s = socket(AF_INET, SOCK_STREAM, 0)) < 0) puke ("socket");
    if (connect(s, (struct sockaddr *)&conserv_port, sizeof(conserv_port)) != 0)
      puke("connect c");

    /* send sign-on stuff, then wait for "ok^M\n" */
    len = snprintf(ch, sizeof ch, "%s:%s:%s:%s:%s\n", server, command, usrnm, lclhost, pass);
    if (len < 0 || (size_t)len >= sizeof ch) {
      ioctl(0,TIOCGETP,&sty);
      sty.sg_flags = oldflags;
      ioctl(0,TIOCSETP,&sty);
      fprintf(stderr, "sign-on line too long\n");
      exit(-1);
    }
    if (write(s, ch, len) != len) puke ("write");

    /* read the one-line reply, never past the end of ch */
    len = 0;
    do {
      if (len >= (int)sizeof ch - 1 || read(s, &ch[len], 1) != 1)
        puke ("read");
    } while (ch[len++] != '\n');
    ch[len] = '\0';
    if (strcmp(ch, "ok\r\n") != 0)
    {
      fprintf(stderr,"%s",ch);
      ioctl(0,TIOCGETP,&sty);
      sty.sg_flags = oldflags;
      ioctl(0,TIOCSETP,&sty);
      exit(-1);
    }

    /* copy the listing to stdout until the server closes the connection.
       stdin is deliberately not watched: nothing here consumes it, and
       selecting on it would spin once a key had been pressed. */
    while ((nr = read(s, ch, sizeof ch)) > 0) {
      for (i = 0; i < nr; ++i) ch[i] &= 127;  /* clear parity */
      write(1, ch, nr);
    }
    close(s);

    /* advance past the next ':' if there is one */
    if ((colon = index(portptr, ':')) == NULL)
      break;
    portptr = colon + 1;
  }

  /* change stdin back to original state */
  ioctl(0,TIOCGETP,&sty);
  sty.sg_flags = oldflags;
  ioctl(0,TIOCSETP,&sty);
  if (command[0] != 'w') fprintf(stderr,"\nConnection closed.\n");
  return 0;
}
