/*
**  SSL/TCPConnection.m
**
**  Copyright (c) 2001, 2002, 2003
**
**  Author: Ludovic Marcotte <ludovic@Sophos.ca>
**
**  This library is free software; you can redistribute it and/or
**  modify it under the terms of the GNU Lesser General Public
**  License as published by the Free Software Foundation; either
**  version 2.1 of the License, or (at your option) any later version.
**  
**  This library is distributed in the hope that it will be useful,
**  but WITHOUT ANY WARRANTY; without even the implied warranty of
**  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
**  Lesser General Public License for more details.
**  
**  You should have received a copy of the GNU Lesser General Public
**  License along with this library; if not, write to the Free Software
**  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/

#import "TCPSSLConnection.h"

#import <Pantomime/Constants.h>

#import <stdio.h>
#import <stdlib.h>
#import <netinet/in.h>
#import <signal.h>
#import <sys/ioctl.h>
#import <sys/socket.h>
#import <sys/time.h>
#import <sys/types.h>
#import <netdb.h>
#import <string.h>
#import <unistd.h>	/* For read() and write() and close() */

#ifdef MACOSX
#import <sys/uio.h>	/* For read() and write() */
#endif

static void sigpipe_handle(int x);

#define READ_BUFFER     4096

@implementation TCPSSLConnection

//
// Releases what the receiver still owns when it goes away: the SSL
// session object (if the initializer created one) and the retained
// host name.
//
- (void) dealloc
{
  // The session allocated with SSL_new() is not freed anywhere else.
  // SSL_free() also releases the BIO attached with SSL_set_bio(); that
  // BIO is created with BIO_NOCLOSE, so the descriptor is not touched.
  if ( ssl != NULL )
    {
      SSL_free(ssl);
      ssl = NULL;
    }

  RELEASE(name);

  [super dealloc];
}


//
// Convenience initializer. Forwards to the five-argument initializer,
// passing 30 for the connection, read and write timeouts.
//
- (id) initWithName: (NSString *) theName
	       port: (int) thePort
{
  const int aDefaultTimeout = 30;

  self = [self initWithName: theName
	       port: thePort
	       connectionTimeout: aDefaultTimeout
	       readTimeout: aDefaultTimeout
	       writeTimeout: aDefaultTimeout];

  return self;
}


//
// Designated initializer. Resolves theName with gethostbyname(), opens
// a TCP connection to thePort (waiting at most the connection timeout,
// which is used as a number of seconds for select()) and then performs
// the SSL handshake on the connected socket.
//
// Timeout values that are not strictly positive are replaced by 30 by
// the corresponding mutation methods.
//
// This method does not raise on failure. It logs the reason with
// NSDebugLog(), releases everything acquired so far (SSL session,
// SSL context, socket descriptor), autoreleases the receiver and
// returns nil.
//
- (id) initWithName: (NSString *) theName
	       port: (int) thePort
  connectionTimeout: (int) theConnectionTimeout
	readTimeout: (int) theReadTimeout
       writeTimeout: (int) theWriteTimeout
{
  struct sockaddr_in server;
  struct hostent *host_info;
  struct timeval timeout;
  fd_set fdset;
  socklen_t size;
  int nonblock, soError, value;

  self = [super init];

  if ( self == nil )
    {
      return nil;
    }

  // No descriptor is open yet; 0 would be a valid descriptor number,
  // so -1 is used to tell the cleanup code that there is nothing to close.
  fd = -1;

  if ( theName == nil || thePort <= 0 )
    {
      NSDebugLog(@"TCPConnection: Attempted to initialize with a nil name or a negative or zero port value.");
      goto failed;
    }

  // We set our ivars through our mutation methods
  [self setName: theName];
  [self setPort: thePort];
  [self setConnectionTimeout: theConnectionTimeout];
  [self setReadTimeout: theReadTimeout];
  [self setWriteTimeout: theWriteTimeout];

  // We first initialize our SSL context
  [self _initializeSSLContext];

  if ( ctx == NULL )
    {
      NSDebugLog(@"TCPConnection: Unable to create the SSL context.");
      goto failed;
    }

  // We get the file descriptor associated to a socket
  fd = socket(PF_INET, SOCK_STREAM, 0);

  if ( fd == -1 )
    {
      NSDebugLog(@"TCPConnection: An error occured while creating the endpoint for communications");
      goto failed;
    }

  // We get our hostent structure for our server name
  host_info = gethostbyname([[self name] cString]);

  if ( !host_info )
    {
      NSDebugLog(@"TCPConnection: Unable to get the hostent structure.");
      goto failed;
    }

  // The address is copied into a sockaddr_in, so anything that is not
  // an IPv4 address of the expected size must be rejected rather than
  // copied past the end of sin_addr.
  if ( host_info->h_addrtype != AF_INET ||
       host_info->h_length <= 0 ||
       (size_t)host_info->h_length > sizeof(server.sin_addr) )
    {
      NSDebugLog(@"TCPConnection: The host name did not resolve to an IPv4 address.");
      goto failed;
    }

  // Zero the whole structure so no field is passed to connect() uninitialized.
  memset(&server, 0, sizeof(server));
  server.sin_family = host_info->h_addrtype;
  memcpy((char *)&server.sin_addr, host_info->h_addr, host_info->h_length);
  server.sin_port = htons( [self port] );

  // We set the non-blocking I/O flag on the socket so that connect()
  // can be bounded by the connection timeout.
  nonblock = 1;

  if ( ioctl(fd, FIONBIO, &nonblock) == -1 )
    {
      NSDebugLog(@"TCPConnection: Unable to set the non-blocking I/O flag on the socket");
      goto failed;
    }

  // We initiate our connection to the socket
  if ( connect(fd, (struct sockaddr *)&server, sizeof(server)) == -1 )
    {
      if ( errno != EINPROGRESS )
        {
          NSDebugLog(@"TCPConnection: A general socket error occured.");
          goto failed;
        }

      // The socket is non-blocking and the connection cannot be completed
      // immediately: wait until it becomes writable or the timeout expires.
      FD_ZERO(&fdset);
      FD_SET(fd, &fdset);

      timeout.tv_sec = [self connectionTimeout];
      timeout.tv_usec = 0;

      value = select(fd + 1, NULL, &fdset, NULL, &timeout);

      if ( value == -1 )
        {
          NSDebugLog(@"TCPConnection: An error occured while calling select().");
          goto failed;
        }

      // select() has returned 0 which means that the timeout has expired.
      if ( value == 0 )
        {
          NSDebugLog(@"TCPConnection: The connection timeout has expired.");
          goto failed;
        }

      // The socket is writable: SO_ERROR tells whether connect() succeeded.
      soError = 0;
      size = sizeof(soError);

      if ( getsockopt(fd, SOL_SOCKET, SO_ERROR, &soError, &size) == -1 )
        {
          NSDebugLog(@"TCPConnection: An error occured while trying to get the socket options.");
          goto failed;
        }

      if ( soError != 0 )
        {
          NSDebugLog(@"TCPConnection: connect failed.");
          goto failed;
        }
    }

  // We clear the non-blocking I/O flag, for now.
  nonblock = 0;

  if ( ioctl(fd, FIONBIO, &nonblock) == -1 )
    {
      NSDebugLog(@"TCPConnection: An error occured while clearing the non-blocking I/O flag");
      goto failed;
    }

  // We then connect the SSL socket
  ssl = SSL_new(ctx);

  if ( ssl == NULL )
    {
      NSDebugLog(@"TCPConnection: Unable to create the SSL session.");
      goto failed;
    }

  sbio = BIO_new_socket(fd, BIO_NOCLOSE);

  if ( sbio == NULL )
    {
      NSDebugLog(@"TCPConnection: Unable to create the socket BIO.");
      goto failed;
    }

  // From here on the session owns the BIO; SSL_free() releases both.
  SSL_set_bio(ssl, sbio, sbio);

  if ( SSL_connect(ssl) <= 0 )
    {
      NSDebugLog(@"TCPConnection: An error occured while establishing the SSL connection.");
      goto failed;
    }

  return self;

 failed:
  // Nobody receives a reference to the object on failure, so -close can
  // never be called on it: everything must be released here.
  if ( ssl != NULL )
    {
      SSL_free(ssl);
      ssl = NULL;
      sbio = NULL;
    }

  if ( ctx != NULL )
    {
      SSL_CTX_free(ctx);
      ctx = NULL;
    }

  if ( fd != -1 )
    {
      close(fd);
      fd = -1;
    }

  AUTORELEASE(self);
  return nil;
}


//
// access / mutation methods
//

//
// Returns the host name currently stored by the receiver.
//
- (NSString *) name
{
  return name;
}

//
// Replaces the stored host name with theName. The new value is
// retained (not copied) and the previous one is released.
//
- (void) setName: (NSString *) theName
{
  // The new value is retained before the old one is released, so that
  // passing the object that is already stored cannot deallocate it
  // between the two operations.
  NSString *aPreviousName;

  aPreviousName = name;
  RETAIN(theName);
  name = theName;
  RELEASE(aPreviousName);
}

//
// Returns the port number currently stored by the receiver.
//
- (int) port
{
  return self->port;
}

//
// Stores thePort as is; no validation is done here.
//
- (void) setPort: (int) thePort
{
  self->port = thePort;
}

//
// Returns the connection timeout currently stored by the receiver.
//
- (int) connectionTimeout
{
  return self->connectionTimeout;
}

//
// Stores theConnectionTimeout when it is strictly positive;
// any other value is replaced by 30.
//
- (void) setConnectionTimeout: (int) theConnectionTimeout
{
  self->connectionTimeout = ( theConnectionTimeout > 0 ) ? theConnectionTimeout : 30;
}

//
// Returns the read timeout currently stored by the receiver.
//
- (int) readTimeout
{
  return self->readTimeout;
}

//
// Stores theReadTimeout when it is strictly positive;
// any other value is replaced by 30.
//
- (void) setReadTimeout: (int) theReadTimeout
{
  self->readTimeout = ( theReadTimeout > 0 ) ? theReadTimeout : 30;
}

//
// Returns the write timeout currently stored by the receiver.
//
- (int) writeTimeout
{
  return self->writeTimeout;
}

//
// Stores theWriteTimeout when it is strictly positive;
// any other value is replaced by 30.
//
- (void) setWriteTimeout: (int) theWriteTimeout
{
  self->writeTimeout = ( theWriteTimeout > 0 ) ? theWriteTimeout : 30;
}


//
// Accessor for the descriptor of the socket that
// backs this connection.
//
- (int) fd
{
  return self->fd;
}


//
// Placeholder: the target is accepted and ignored.
//
- (void) setStopTarget: (id) theTarget
{
  // Intentionally empty for now; theTarget is not stored.
  (void) theTarget;
}


//
// Placeholder: the selector is accepted and ignored.
//
- (void) setStopSelector: (SEL) theSelector
{
  // Intentionally empty for now; theSelector is not stored.
  (void) theSelector;
}


//
// other methods
//

//
// Closes the connection: sends the SSL close notification, frees the
// SSL session and the SSL context, then closes the socket descriptor.
// Calling this method again after the SSL state has been released
// does nothing.
//
- (void) close
{
  // Both handles are cleared below, so finding them both NULL means the
  // connection was already closed. Returning here avoids freeing the
  // context twice and closing a descriptor number that may since have
  // been reused for something else.
  if ( ssl == NULL && ctx == NULL )
    {
      return;
    }

  if ( ssl != NULL )
    {
      // Best effort: the result is ignored and we do not wait for the
      // peer's own close notification.
      SSL_shutdown(ssl);

      // Also frees the BIO attached with SSL_set_bio().
      SSL_free(ssl);
      ssl = NULL;
      sbio = NULL;
    }

  if ( ctx != NULL )
    {
      SSL_CTX_free(ctx);
      ctx = NULL;
    }

  if ( close([self fd]) < 0 )
    {
      NSDebugLog(@"TCPConnection: An error occured while closing the file descriptor associated with the socket");
    }
}


//
// Reads up to "theLength" bytes from the connection.
//
// Returns an autoreleased NSData holding the bytes that were read, or
// nil when theLength is not strictly positive, when the buffer cannot
// be allocated, or when no byte was read.
//
- (NSData *) readDataOfLength: (int) theLength
{
  NSData *aData;
  char *buf;
  int len;

  // A negative length would turn into a huge allocation request.
  if ( theLength <= 0 )
    {
      return nil;
    }

  buf = (char *) calloc((size_t)theLength, sizeof(char));

  if ( buf == NULL )
    {
      NSDebugLog(@"TCPConnection: Unable to allocate the read buffer.");
      return nil;
    }

  len = theLength;

  [self _readBytes: buf
	length: &len];

  // len is used for the data length rather than theLength, so the
  // result is never longer than what the read helper reports. It still
  // equals theLength if the helper leaves the count untouched.
  if ( len <= 0 )
    {
      free(buf);
      return nil;
    }

  // The NSData object takes ownership of buf and frees it.
  aData = [[NSData alloc] initWithBytesNoCopy: buf
			  length: len
			  freeWhenDone: YES];

  return AUTORELEASE(aData);
}


//
// Reads bytes up to and including the next line feed and returns them
// as an NSData object (carriage returns are kept).
//
// Returns nil when the buffer cannot be allocated or when nothing
// was read.
//
- (NSData *) readDataToEndOfLine
{
  char *buf;
  int len;

  buf = (char *) malloc( READ_BUFFER * sizeof(char));

  if ( buf == NULL )
    {
      NSDebugLog(@"TCPConnection: Unable to allocate the read buffer.");
      return nil;
    }

  len = 0;

  // The helper may reallocate the buffer, which is why its address is passed.
  [self _readBytesBySkippingCR: NO
	buf: &buf
	length: &len];

  if ( len <= 0 )
    {
      free(buf);
      return nil;
    }

  // The NSData object takes ownership of buf and frees it.
  return [NSData dataWithBytesNoCopy: buf
		 length: len
		 freeWhenDone: YES];
}


//
// Reads up to "theLength" bytes (excluding the null char) and returns
// them as a string built from their C string representation.
//
// Returns nil when theLength is not strictly positive, when the buffer
// cannot be allocated, or when the resulting string is empty.
//
- (NSString *) readStringOfLength: (int) theLength
{
  NSString *aString;
  char *buf;
  int len;

  // A negative length would give a buffer too small to hold even the
  // terminating null char that stringWithCString: looks for.
  if ( theLength <= 0 )
    {
      return nil;
    }

  // One extra, zeroed byte guarantees that the buffer is null terminated.
  buf = (char *) calloc((size_t)theLength + 1, sizeof(char));

  if ( buf == NULL )
    {
      NSDebugLog(@"TCPConnection: Unable to allocate the read buffer.");
      return nil;
    }

  len = theLength;

  [self _readBytes: buf
	length: &len];

  aString = [NSString stringWithCString: buf];
  free(buf);

  //NSLog(@"R: |%@|", aString);

  if ( [aString length] == 0 )
    {
      return nil;
    }

  return aString;
}


//
// Reads one line from the socket as a string, keeping carriage
// returns: this simply forwards to -readStringToEndOfLineSkippingCR:
// with NO. The read buffer starts at READ_BUFFER (4096) bytes.
//
- (NSString *) readStringToEndOfLine
{
  NSString *aLine;

  aLine = [self readStringToEndOfLineSkippingCR: NO];

  return aLine;
}


//
// Reads one line from the connection and returns it as a string built
// from its C string representation. When aBOOL is YES the read helper
// is asked to skip carriage returns.
//
// Returns nil when the buffer cannot be allocated or when the line
// that was read is empty.
//
- (NSString *) readStringToEndOfLineSkippingCR: (BOOL) aBOOL
{
  NSString *aString;

  char *buf;
  int len;

  buf = (char *) malloc( READ_BUFFER * sizeof(char));

  if ( buf == NULL )
    {
      NSDebugLog(@"TCPConnection: Unable to allocate the read buffer.");
      return nil;
    }

  len = 0;

  // The helper may reallocate the buffer, which is why its address is passed.
  [self _readBytesBySkippingCR: aBOOL
	buf: &buf
	length: &len];

  if ( len <= 0 )
    {
      free(buf);
      return nil;
    }

  // The string is built from the byte count reported by the helper
  // instead of relying on a null terminator being present in the buffer.
  aString = [NSString stringWithCString: buf  length: len];
  free(buf);

  //NSLog(@"R: |%@| %d", aString, len);

  if (aString == nil || [aString length] == 0)
    {
      return nil;
    }

  return aString;
}


//
// Sends theLine followed by CRLF. The terminated line is built with
// a format string and handed to -writeString:, whose result is returned.
//
- (BOOL) writeLine: (NSString *) theLine
{
  NSString *aTerminatedLine;

  aTerminatedLine = [NSString stringWithFormat: @"%@\r\n", theLine];

  return [self writeString: aTerminatedLine];
}


//
// Write a string to the socket. We currently write the cString
// representation of a Unicode string.
//
// Returns NO when theString is nil, when it has no C string
// representation, or when the write helper reports that fewer bytes
// than requested were written. Returns YES otherwise.
//
- (BOOL) writeString: (NSString *) theString
{
  const char *cString;
  int len, requested;

  //NSLog(@"S: |%@|", theString);

  // Messaging nil would yield a NULL pointer and strlen() would crash.
  if ( theString == nil )
    {
      return NO;
    }

  cString = [theString cString];

  if ( cString == NULL )
    {
      return NO;
    }

  requested = (int) strlen( cString );
  len = requested;

  [self _writeBytes: (char *) cString
	length: &len];

  // If the helper leaves the count untouched this is always YES.
  return ( len == requested ) ? YES : NO;
}


//
// Write the bytes of theData to the socket.
//
// Returns YES when there was nothing to write (nil or empty data) or
// when the write helper reports that every byte was written, and NO
// when it reports that fewer bytes than requested were written.
//
- (BOOL) writeData: (NSData *) theData
{
  const char *bytes;
  int len, requested;

  bytes = (const char *)[theData bytes];
  requested = (int)[theData length];

  if ( bytes == NULL || requested <= 0 )
    {
      return YES;
    }

  len = requested;

  [self _writeBytes: (char *) bytes
	length: &len];

  // If the helper leaves the count untouched this is always YES.
  return ( len == requested ) ? YES : NO;
}

@end


//
// private methods
// 
@implementation TCPSSLConnection (Private)

//
// Prepares the OpenSSL library and creates the SSL context stored in
// the "ctx" ivar. A SIGPIPE handler is installed as well.
//
// On failure "ctx" is left NULL and a message is logged; the caller is
// expected to check it.
//
- (void) _initializeSSLContext
{
  // The global library setup only needs to happen once per process,
  // not once per connection.
  static BOOL isLibraryInitialized = NO;

  if ( !isLibraryInitialized )
    {
      // Global system initialization
      SSL_library_init();
      SSL_load_error_strings();
      isLibraryInitialized = YES;
    }

  // Set up a SIGPIPE handler
  signal(SIGPIPE, sigpipe_handle);

  // Create our context
  ctx = SSL_CTX_new(SSLv23_method());

  if ( ctx == NULL )
    {
      NSDebugLog(@"TCPConnection: Unable to create the SSL context.");
    }
}


//
// Reads from the SSL session into theBytes until *theLength bytes have
// been stored, or until the session reports end of stream or an error.
//
// On return, *theLength holds the number of bytes actually stored in
// theBytes. It equals the requested count only when the read was
// complete.
//
- (void) _readBytes: (char *) theBytes
	     length: (int *) theLength
{
  int tot, bytes, anError;

  if ( theLength == NULL )
    {
      return;
    }

  if ( theBytes == NULL || ssl == NULL )
    {
      *theLength = 0;
      return;
    }

  tot = 0;

  while ( tot < *theLength )
    {
      bytes = SSL_read(ssl, theBytes + tot, *theLength - tot);

      if ( bytes > 0 )
	{
	  tot += bytes;
	  continue;
	}

      // A result <= 0 must never be added to tot: 0 would spin forever
      // and -1 would move the write position backwards in the buffer.
      anError = SSL_get_error(ssl, bytes);

      // The operation only needs to be repeated.
      if ( anError == SSL_ERROR_WANT_READ || anError == SSL_ERROR_WANT_WRITE )
	{
	  continue;
	}

      // End of stream or unrecoverable error.
      break;
    }

  *theLength = tot;
}


//
// Reads one line, byte by byte, into *buf.
//
// *buf must point to a malloc()ed buffer of at least READ_BUFFER bytes.
// It is grown with realloc() in READ_BUFFER increments when needed, so
// *buf may be a different pointer on return.
//
// When aBOOL is NO every byte is stored, including the terminating
// line feed. When aBOOL is YES carriage returns are dropped and the
// terminating line feed is not stored.
//
// Reading stops at the first line feed, when a byte cannot be read,
// or when the buffer cannot be grown. On return, *theLength holds the
// number of bytes stored in *buf.
//
- (void) _readBytesBySkippingCR: (BOOL) aBOOL
			    buf: (char **) buf
			 length: (int *) theLength
{
  int i, len, size;
  char *aLargerBuffer;
  char c;

  if ( theLength == NULL )
    {
      return;
    }

  if ( buf == NULL || *buf == NULL )
    {
      *theLength = 0;
      return;
    }

  memset(*buf, 0, READ_BUFFER);
  size = READ_BUFFER;
  i = 0;

  while ( YES )
    {
      c = 0;
      len = 1;

      [self _readBytes: &c
	    length: &len];

      // The helper did not deliver a byte: stop instead of looping
      // forever on a stale value of c.
      if ( len != 1 )
	{
	  break;
	}

      // We verify if we must expand our buffer
      if ( (i+1) >= (size - 2) )
	{
	  aLargerBuffer = (char *) realloc(*buf, size + READ_BUFFER);

	  if ( aLargerBuffer == NULL )
	    {
	      // *buf is still valid and still owned by the caller.
	      break;
	    }

	  // Only the newly added tail is cleared; the bytes that were
	  // already read must be left intact.
	  memset(aLargerBuffer + size, 0, READ_BUFFER);
	  *buf = aLargerBuffer;
	  size += READ_BUFFER;
	}

      if ( c == '\n' )
	{
	  // The line feed is only kept when we are not skipping CRs.
	  if ( !aBOOL )
	    {
	      (*buf)[i] = c;
	      i++;
	    }

	  break;
	}

      // We skip the \r when asked to
      if ( !aBOOL || c != '\r' )
	{
	  (*buf)[i] = c;
	  i++;
	}
    }

  *theLength = i;
}


//
// Writes *theLength bytes from theBytes to the SSL session, repeating
// the write until everything has been sent or until the session
// reports an error.
//
// On return, *theLength holds the number of bytes actually written.
// It equals the requested count only when the write was complete.
//
- (void) _writeBytes: (char *) theBytes
	      length: (int *) theLength
{
  int tot, bytes, anError;

  if ( theLength == NULL )
    {
      return;
    }

  if ( theBytes == NULL || ssl == NULL )
    {
      *theLength = 0;
      return;
    }

  tot = 0;

  while ( tot < *theLength )
    {
      bytes = SSL_write(ssl, theBytes + tot, *theLength - tot);

      if ( bytes > 0 )
	{
	  tot += bytes;
	  continue;
	}

      // A result <= 0 must never be added to tot: 0 would spin forever
      // and -1 would move the read position backwards in the buffer.
      anError = SSL_get_error(ssl, bytes);

      // The operation only needs to be repeated.
      if ( anError == SSL_ERROR_WANT_READ || anError == SSL_ERROR_WANT_WRITE )
	{
	  continue;
	}

      // Unrecoverable error or connection closed.
      break;
    }

  *theLength = tot;
}

@end


//
// C functions
//

// Signal handler installed for SIGPIPE. It deliberately has an empty
// body: the signal number is received and ignored.
static void sigpipe_handle(int x)
{
  (void) x;
}
