
// THIS IS FOR LINUX/UNIX
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <signal.h>
// THIS IS FOR LINUX/UNIX

#include <string>
#include <sstream>
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include "lbtypes/lbobject.hpp"
#include "lbtypes/lbint.hpp"
#include "lbtypes/lbstring.hpp"
#include "lbtypes/lbtypefactory.hpp"
#include "lbtypes/lbtypeinfo.hpp"
#include "lbtypes/LBDefineMacros.hpp"
#include "lbtypes/lbexception.hpp"
#include "lbtypes/lbvarargs.hpp"
#include "lbtypes/RefCountedPtr.hpp"

#include "lbnet/socket.hpp"


namespace Net
{
  using std::string;
  using std::ostringstream;
  using std::istringstream;

  using Luban::LBObject;
  using Luban::LBInt;
  using Luban::LBString;
  using Luban::LBVarArgs;
  using Luban::LBSymbol;
  using Luban::LBException;

  // Width in bytes of the fixed-size ASCII length header that writeobj()
  // sends in front of every serialized object and that readobj() reads back
  // before reading the payload.
  static const int OBJHEADERSIZE = 32;

  // NOTE(review): LBDEFINE presumably comes from lbtypes/LBDefineMacros.hpp and
  // registers Net::Socket with the Luban type system at version 1.0 -- confirm there.
  LBDEFINE(Net::Socket, 1, 0 )

  // Script-level factory for net::socket(string hostname, int port).
  // Validates the argument list, opens a connected client socket and returns
  // it. Throws LBException on bad arguments or when the connection could not
  // be established (the message then carries the socket's error text).
  Socket* Socket::staticConstructor(const LBVarArgs* args)
  {
    const char* const usage = "to open a socket net::socket(string hostname, int port)";
    if ( ! args || args->numArgs() != 2 )
      throw LBException(usage);

    const LBString* host = dynamic_cast<const LBString*>(args->getArg(0));
    const LBInt* portnum = dynamic_cast<const LBInt*>(args->getArg(1));
    if ( host == 0 || portnum == 0 )
      throw LBException(usage);

    Socket* opened = new Socket(host->str(), int(*portnum));
    if ( ! opened->isValid() )
      {
        const string reason = opened->errmsg();
        delete opened;
        throw LBException("Failed to open socket: " + reason);
      }
    return opened;
  }

  // NOTE(review): judging by its name, this macro supplies the default equals()
  // implementation for Net::Socket -- confirm in lbtypes/LBDefineMacros.hpp.
  LBDEFAULT_EQUALS_FUNC(Net::Socket)

  // Open a TCP client connection to hostname:port.
  // Lookup/connect failures do not throw: the socket is left invalid
  // (_socketid stays -1) and the reason is stored in _socketimp->_errmsg, so
  // callers must check isValid()/errmsg().
  // NOTE(review): gethostbyname() is IPv4-only and not thread-safe;
  // getaddrinfo() is the modern replacement.
  Socket::Socket(const string& hostname, int port )
    : _socketimp(new SocketImp(hostname, port))
  {
    // host lookup
    struct hostent *peerip = gethostbyname(hostname.c_str());
    if ( !peerip )
      {
        _socketimp->_errmsg = "unknown host: "+hostname;
        return;
      }

    // Only a 4-byte AF_INET address fits in sockaddr_in::sin_addr; copying
    // anything longer would overrun the structure.
    if ( peerip->h_addrtype != AF_INET
         || peerip->h_length != int(sizeof(struct in_addr)) )
      {
        _socketimp->_errmsg = "unsupported address type for host: "+hostname;
        return;
      }

    // create blank socket
    int sckid = socket(AF_INET, SOCK_STREAM, 0);
    if ( sckid < 0 )
      {
        _socketimp->_errmsg = "Failed to open socket";
        return;
      }

    struct sockaddr_in sv_addr;
    std::memset(&sv_addr, 0, sizeof(sv_addr));
    sv_addr.sin_family = AF_INET;
    std::memcpy(&sv_addr.sin_addr, peerip->h_addr, sizeof(sv_addr.sin_addr));
    sv_addr.sin_port = htons(port);

    // connect to remote server
    if ( connect(sckid, reinterpret_cast<struct sockaddr *>(&sv_addr), sizeof(sv_addr)) < 0 )
      {
        // The descriptor is not yet owned by _socketimp (whose destructor
        // would close it), so release it here to avoid leaking it.
        ::close(sckid);
        _socketimp->_errmsg = "cannot connect to server: "+hostname;
        return;
      }

    _socketimp->_socketid = sckid;
  }

  // Wrap an already-open descriptor together with the peer's host and port.
  // NOTE(review): presumably used for connections produced server-side -- confirm with callers.
  Socket::Socket(int socketid, const string& clienthost, int clientport )
    : _socketimp( new SocketImp(socketid, clienthost, clientport) )
  {
    // Nothing else to do: SocketImp closes the descriptor when destroyed.
  }

  // Human-readable description: "net::socket host:port" for an open socket,
  // "net::socket invalid" otherwise.
  string Socket::toString() const
  {
    if ( _socketimp->isValid() )
      {
        ostringstream desc;
        desc << "net::socket " << _socketimp->_host << ":" << _socketimp->_port;
        return desc.str();
      }
    return "net::socket invalid";
  }

  // Sockets cannot be serialized; always throws LBException.
  ostream& Socket::toStream(ostream& ost) const
  {
    throw LBException("net::socket object can not be streamed out");
  }

  // Sockets cannot be deserialized; always throws LBException.
  istream& Socket::fromStream(istream& ist, int major, int minor)
  {
    throw LBException("net::socket object can not be streamed in");
  }

  bool Socket::writeobj(const LBObject& obj)
  {
    if ( _socketimp->isValid() )
      {
	ostringstream ost;
	LBObject::instanceToStream(ost, obj);
	string strm = ost.str();
	int strmsz = strm.size();
	
	// contruct object head which indicating the total size to expect
	char objheader[OBJHEADERSIZE];
	sprintf(objheader, "  %30d", strmsz); 

	// #ifdef UNIX LINUX
	// write the header first
	int written = _socketimp->writenbytes( objheader, OBJHEADERSIZE);
	if ( written != OBJHEADERSIZE )
	  return false;

	// then write the serialized obj 
	written = _socketimp->writenbytes( strm.c_str(), strmsz);

	return  written == strmsz;

      }

    throw LBException("Can not write to invalid socket: "+_socketimp->_errmsg);
    return false;
  }

  // Receive one object written by writeobj(): read the OBJHEADERSIZE-byte
  // ASCII length header, then exactly that many payload bytes, and
  // deserialize them. Returns the new object.
  // Throws LBException if the socket is invalid, the header is short or
  // malformed, the payload is truncated, or deserialization fails.
  LBObject* Socket::readobj()
  {
    if ( ! _socketimp->isValid() )
      throw LBException("Can not read from invalid socket: "+_socketimp->_errmsg);

    // One spare byte so the header can be NUL-terminated before strtol()
    // scans it; otherwise the parse could run past the end of the buffer.
    char objheader[OBJHEADERSIZE + 1];
    int szread = _socketimp->readnbytes(objheader, OBJHEADERSIZE);
    if ( szread != OBJHEADERSIZE )
      throw LBException("Failed to read object stream header");
    objheader[OBJHEADERSIZE] = '\0';

    // The writer emits a space-padded decimal number, so parse base 10 and
    // reject sizes that are negative or do not fit in an int.
    char *cend = 0;
    const long hdrsz = strtol(objheader, &cend, 10);
    if ( cend == objheader || hdrsz < 0 || hdrsz > INT_MAX )
      throw LBException("Invalid object stream header");
    const int strmsz = int(hdrsz);

    // The size comes from the peer, so use a heap buffer rather than a
    // variable-length stack array.
    string st(strmsz, '\0');
    if ( strmsz > 0 )
      {
        szread = _socketimp->readnbytes(&st[0], strmsz);
        if ( szread != strmsz )
          throw LBException("Failed to read complete object stream from socket");
      }

    istringstream ist(st);
    string errs;
    LBObject *obj = LBObject::instanceFromStream(ist, errs);
    if ( obj )
      return obj;
    throw LBException("Corrupted object stream from socket: "+errs);
  }

  // Read a single byte into c; true only if exactly one byte arrived.
  bool  Socket::readChar(char& c)
  {
    const int got = _socketimp->readnbytes(&c, 1);
    return got == 1;
  }
	
  void Socket::put(const LBObject& obj)
  {
    if ( _socketimp->isValid() )
      {
	string objstr = obj.toString();
	int ntowrite = objstr.size();
	int n = _socketimp->writenbytes(objstr.c_str(), ntowrite);
	if ( n != ntowrite )
	  throw LBException("Failed to write object to socket");

	return;
      }
    throw LBException("Write to invalid socket");
    return;
  }
	
  void Socket::putLine(const LBObject& obj)
  {
    if ( _socketimp->isValid() )
      {
	string objstr = obj.toString();
	objstr += '\n';
	int ntowrite = objstr.size();
	int n = _socketimp->writenbytes( objstr.c_str(), ntowrite);
	if ( n != ntowrite )
	  {
	    _socketimp->invalidate();
	    throw LBException("Failed to write object to socket");
	  }
	return;
      }
    throw LBException("Write to invalid socket");
    return;
  }

  // Read one '\n'-terminated line, appending it (without the newline) to
  // aline. Reads byte by byte, so nothing after the newline is consumed.
  // Returns true if a newline or at least one character was received (a final
  // unterminated line counts). Returns false if the connection ended or
  // failed before any byte arrived, so callers can tell end-of-stream apart
  // from an empty line.
  bool Socket::readLine(string& aline)
  {
    char c;
    bool gotany = false;
    while ( readChar(c) )
      {
        gotany = true;
        if ( c == '\n' )
          break;
        aline += c;
      }
    return gotany;
  }

  // Close the connection if it is open; harmless on an already-closed socket.
  void Socket::close()   
  {
    // invalidate() closes the descriptor when it is not -1 and resets it to -1.
    _socketimp->invalidate();
  }

  // True while the underlying descriptor is open.
  bool Socket::isValid() const
  {
    const bool open = _socketimp->isValid();
    return open;
  }

  // Error text recorded for this socket (empty unless a failure was recorded).
  const string& Socket::errmsg() const
  {
    return ( _socketimp->_errmsg );
  }
    
  LBEXPORT_MEMBER_FUNC(Net::Socket, luban_writeobj, "writeobj", "void writeobj(obj1, obj2, obj3....)" );
  // Script method writeobj(obj1, obj2, ...): send each argument in order with
  // writeobj(). Throws on an empty argument list, a null argument or a failed
  // write; otherwise returns null (0).
  LBObject* Socket::luban_writeobj(const LBVarArgs *args)
  {
    if ( args == 0 || args->numArgs() == 0 )
      throw LBException("Nothing to write to socket");

    const int count = args->numArgs();
    for ( int idx = 0; idx != count; ++idx )
      {
        const LBObject* item = args->getArg(idx);
        if ( item == 0 )
          throw LBException("Can not write single null value to socket");
        if ( ! writeobj(*item) )
          throw LBException("Failed to write to socket");
      }
    return 0;
  }
    
  LBEXPORT_MEMBER_FUNC(Net::Socket, luban_put, "write", "void write(obj1, obj2, obj3....)" );
  // Script method write(obj1, obj2, ...): put() each argument's text in order.
  // Throws on an empty argument list or a null argument; returns null (0).
  LBObject* Socket::luban_put(const LBVarArgs *args)
  {
    if ( args == 0 || args->numArgs() == 0 )
      throw LBException("Nothing to write to socket");

    const int count = args->numArgs();
    for ( int idx = 0; idx != count; ++idx )
      {
        const LBObject* item = args->getArg(idx);
        if ( item == 0 )
          throw LBException("Can not write single null value to socket");
        put(*item);
      }
    return 0;
  }
    
  LBEXPORT_MEMBER_FUNC(Net::Socket, luban_putline, "writeline", "void writeline(obj1, obj2, obj3....)" );
  // Script method writeline(obj1, obj2, ...): putLine() each argument, so every
  // argument's text is followed by its own newline. Returns null (0).
  LBObject* Socket::luban_putline(const LBVarArgs *args)
  {
    if ( args == 0 || args->numArgs() == 0 )
      throw LBException("Nothing to write to socket");

    const int count = args->numArgs();
    for ( int idx = 0; idx != count; ++idx )
      {
        const LBObject* item = args->getArg(idx);
        if ( item == 0 )
          throw LBException("Can not write single null value to socket");
        putLine(*item);
      }
    return 0;
  }

  LBEXPORT_MEMBER_FUNC(Net::Socket, luban_get, "read", "string read(int n=1)" );
  // Script method read(int n=1): read up to n bytes, one at a time, and
  // return them as a string. Fewer bytes come back if the connection ends or
  // fails; n <= 0 yields an empty string. Throws on a wrong argument list.
  LBObject* Socket::luban_get(const LBVarArgs *args)
  {
    int toread = 1;
    if ( args && args->numArgs() )
      {
        if ( args->numArgs() != 1 )
          throw LBException("socket::read function take one integer argument");

        const LBInt *n = dynamic_cast<const LBInt*>(args->getArg(0));
        if ( ! n )
          throw LBException("socket::read function take one integer argument");
        toread = int(*n);
      }

    char c;
    string s;
    while ( --toread >= 0 && readChar(c) )
      s += c;
    return new LBString(s);
  }

  LBEXPORT_MEMBER_FUNC(Net::Socket, luban_getline, "readline", "string readline()" );
  // Script method readline(): return the next line (without '\n') as a string.
  // Throws if arguments are given or readLine() reports failure.
  LBObject* Socket::luban_getline(const LBVarArgs *args)
  {
    if ( args != 0 && args->numArgs() != 0 )
      throw LBException("socket::readline() does not take arguments");

    string ln;
    if ( ! readLine(ln) )
      throw LBException("Failed to readline from socket");
    return new LBString(ln);
  }


  LBEXPORT_MEMBER_FUNC(Net::Socket, luban_readobj, "readobj", "object readobj()" );
  // Script method readobj(): receive one object sent by writeobj().
  LBObject* Socket::luban_readobj(const LBVarArgs *args)
  {
    const bool hasArgs = args != 0 && args->numArgs() != 0;
    if ( hasArgs )
      throw LBException("socket::readobj() function does not take arguments");
    return readobj();
  }

  LBEXPORT_MEMBER_FUNC(Net::Socket, luban_close, "close", "void close()" );
  // Script method close(): close the connection; returns null (0).
  LBObject* Socket::luban_close(const LBVarArgs *args)
  {
    if ( args != 0 && args->numArgs() != 0 )
      throw LBException("socket::close() function does not take arguments");
    close();
    return 0;
  }



  // helper class
  // Record the target endpoint; the descriptor starts as -1 (not connected)
  // and is filled in by Socket's constructor once connect() succeeds.
  Socket::SocketImp::SocketImp(const string& hostname, int port)
    : _host(hostname),
      _port(port),
      _errmsg(),
      _socketid(-1)
  {
  }

  // Adopt an existing descriptor together with its peer's host and port.
  Socket::SocketImp::SocketImp(int socketid, const string& hostname, int port)
    : _host(hostname),
      _port(port),
      _errmsg(),
      _socketid(socketid)
  {
  }

  // Close the descriptor if it is still open.
  Socket::SocketImp::~SocketImp()
  {
    if ( isValid() )
      ::close(_socketid);
  }

  // Close the descriptor (if open) and mark it as -1 so later calls are no-ops.
  void Socket::SocketImp::invalidate()
  {
    if ( _socketid == -1 )
      return;
    ::close(_socketid);
    _socketid = -1;
  }

  // A descriptor of -1 is the "closed / never opened" marker.
  bool Socket::SocketImp::isValid() const
  {
    return !( _socketid == -1 );
  }


  // Read exactly n bytes into ptr, looping over short reads.
  // Returns n on success. If the peer closes the connection first, returns
  // the number of bytes read before EOF. If read() fails, returns its
  // negative result. In both of those cases the socket is invalidated.
  // Returns 0 immediately when the socket is already invalid.
  int Socket::SocketImp::readnbytes(char *ptr, int n)
  {
    if ( !isValid() )
      return 0;

    int ntoread = n;
    while ( ntoread > 0 )
      {
        const ssize_t oneread = ::read(_socketid, ptr, ntoread);
        if ( oneread < 0 )
          {
            // Interrupted by a signal before any data arrived: not a socket
            // error, so retry instead of tearing the connection down.
            if ( errno == EINTR )
              continue;
            invalidate();
            return int(oneread);
          }
        if ( oneread == 0 )
          {
            // orderly shutdown by the peer
            invalidate();
            break;
          }
        ntoread -= int(oneread);
        ptr += oneread;
      }
    return n - ntoread;
  }

  // Write exactly n bytes from ptr, looping over short writes.
  // Returns the number of bytes actually written (n on success). On a write
  // failure the socket is invalidated and the partial count is returned.
  // Returns 0 immediately when the socket is already invalid.
  int Socket::SocketImp::writenbytes(const char *ptr, int n)
  {
    if ( !isValid() )
      return 0;

    // Ignore SIGPIPE while writing so a vanished peer is reported through the
    // return value rather than by the default signal action.
    // NOTE(review): signal() changes process-wide state and is not thread-safe;
    // send(..., MSG_NOSIGNAL) or SO_NOSIGPIPE would avoid that where available.
    void (*oldhndl)(int) = signal(SIGPIPE, SIG_IGN);

    int ntowrite = n;
    while ( ntowrite > 0 )
      {
        const ssize_t onewrite = ::write(_socketid, ptr, ntowrite);
        if ( onewrite < 0 && errno == EINTR )
          continue;   // interrupted before anything was written: retry
        if ( onewrite <= 0 )
          {
            invalidate();
            break;
          }
        ntowrite -= int(onewrite);
        ptr += onewrite;
      }

    // restore the caller's SIGPIPE disposition (if we managed to read it)
    if ( oldhndl != SIG_ERR )
      signal(SIGPIPE, oldhndl);
    return n - ntowrite;
  }

}
