/*
 *  Copyright © 2003-2015 Amichai Rothman
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 *  For additional info see http://www.freeutils.net/source/
 */

package net.freeutils.net;

import java.io.*;
import java.net.*;
import java.nio.channels.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * {@code NetUtils} is a network utility class.
 * <p>
 * It provides static helpers for closing resources, for sending data through
 * a socket and saving the response, for splitting a saved HTTP response into
 * headers and body, and a nested {@link Scanner} that probes a range of TCP
 * ports on a host. It can also be run from the command line (see {@link #main}).
 *
 * @author Amichai Rothman
 */
public class NetUtils {

    /**
     * Closes all given Closeables in their natural order,
     * silently ignoring nulls and thrown IOExceptions.
     *
     * @param closeables the Closeables to close (may be null or contain nulls)
     */
    public static void closeSilently(Closeable... closeables) {
        if (closeables == null)
            return;
        for (Closeable c : closeables) {
            try {
                if (c != null)
                    c.close();
            } catch (IOException ignore) {
                // best-effort close: there is nothing useful a caller could do here
            }
        }
    }

    /**
     * Closes the given socket if it is not null. A failure to close is
     * reported on {@code System.err} and otherwise ignored.
     *
     * @param sock the socket to close
     */
    public static void close(Socket sock) {
        try {
            if (sock != null)
                sock.close();
        } catch (IOException ioe) {
            System.err.println("Error closing socket");
            ioe.printStackTrace();
        }
    }

    /**
     * Connects to a remote host on a given port, and iterates over pairs of input
     * and output files, sending each input file in turn and saving the corresponding
     * response to the output file.
     * <p>
     * Note that {@link #send(Socket, InputStream, OutputStream, boolean)} reads each
     * response until the end of the stream and then closes the socket's streams
     * (which closes the socket), so any pair after the first fails with an
     * IOException.
     *
     * @param host the name of the remote host
     * @param port the port on the remote host to connect to
     * @param infiles the names of files containing data to be sent
     * @param outfiles the names of files where responses will be saved
     * @throws IOException if an IO error occurs
     * @throws IllegalArgumentException if there are fewer output files than input files
     */
    public static void doSession(String host, int port, String[] infiles, String[] outfiles) throws IOException {
        // fail before opening a connection rather than in the middle of the session
        if (outfiles.length < infiles.length)
            throw new IllegalArgumentException("got " + outfiles.length
                + " output files for " + infiles.length + " input files");
        Socket sock = null;
        try {
            // open socket
            System.out.println("Opening socket to " + host + ':' + port);
            sock = new Socket(host, port);
            sock.setSoTimeout(30000);
            // send and receive all files in session
            for (int i = 0; i < infiles.length; i++) {
                System.out.println("Sending " + infiles[i] + " receiving " + outfiles[i]);
                send(sock, infiles[i], outfiles[i]);
            }
            System.out.println("Done.");
        } finally {
            close(sock);
        }
    }

    /**
     * Downloads a file over HTTP.
     * <p>
     * A plain HTTP/1.0 GET request is sent over a raw socket. The full response is
     * written to {@code outfile + ".response"} and then split into
     * {@code outfile + ".headers.txt"} and {@code outfile + ".body"}.
     *
     * @param url the full HTTP url to be downloaded; if it has no scheme,
     *        {@code http://} is assumed
     * @param outfile the name of the file prefix to which the downloaded
     *        data is written
     * @throws IOException if an IO error occurs
     */
    public static void doSession(String url, String outfile) throws IOException {
        Socket sock = null;
        OutputStream out = null;
        try {
            // open socket
            if (url.indexOf("://") < 0)
                url = "http://" + url;
            URL u = new URL(url);
            // getPort() returns -1 when the url has no explicit port
            int port = u.getPort();
            if (port == -1)
                port = u.getDefaultPort();
            if (port == -1)
                throw new MalformedURLException("cannot determine port for " + url);
            System.out.println("Opening socket to " + u.getHost() + ':' + port);
            sock = new Socket(u.getHost(), port);
            sock.setSoTimeout(30000);
            // the request target is the path plus query, and must never be empty
            String target = u.getFile();
            if (!target.startsWith("/"))
                target = "/" + target;
            // the Host header carries the port only if the url specified one
            String hostHeader = u.getPort() == -1 ? u.getHost() : u.getHost() + ':' + u.getPort();
            // send request and receive response
            StringBuilder s = new StringBuilder();
            s.append("GET ").append(target).append(" HTTP/1.0\r\n")
                .append("Host: ").append(hostHeader).append("\r\nConnection: close\r\n\r\n");
            System.out.println("sending:\n" + s);
            ByteArrayInputStream bais = new ByteArrayInputStream(s.toString().getBytes("UTF-8"));
            out = new FileOutputStream(outfile + ".response");
            send(sock, bais, out, false);
            // close explicitly so that a failed write is reported before splitting
            out.close();
            out = null;
            splitHTTP(outfile + ".response", outfile + ".headers.txt", outfile + ".body");
            System.out.println("Done");
        } finally {
            close(sock);
            closeSilently(out);
        }
    }

    /**
     * Splits a file containing a full HTTP response into two files,
     * one containing the headers and the other containing the body.
     * <p>
     * The headers file receives everything up to and including the first
     * empty line (CRLF CRLF). The body file is written only if that
     * separator is found.
     *
     * @param infile the file containing a full HTTP response
     * @param headerfile the file to which the headers will be written
     * @param bodyfile the file to which the body will be written
     * @throws IOException if an IO error occurs
     */
    public static void splitHTTP(String infile, String headerfile, String bodyfile) throws IOException {
        // split headers and content into separate files
        InputStream in = null;
        OutputStream out = null;
        final byte[] separator = { '\r', '\n', '\r', '\n' };
        try {
            // open input file
            in = new BufferedInputStream(new FileInputStream(infile));
            // write headers
            int c;
            int sep = 0; // number of separator bytes matched so far
            out = new BufferedOutputStream(new FileOutputStream(headerfile));
            while (sep < separator.length && (c = in.read()) != -1) {
                out.write(c);
                if (c == separator[sep])
                    sep++;
                else // a mismatching byte may itself start a new separator (e.g. "\r\r\n\r\n")
                    sep = c == separator[0] ? 1 : 0;
            }
            // close explicitly (not silently) so that write errors are reported
            out.close();
            out = null;
            // write body
            if (sep == separator.length) {
                out = new BufferedOutputStream(new FileOutputStream(bodyfile));
                byte[] buf = new byte[4096];
                int count;
                while ((count = in.read(buf)) != -1)
                    out.write(buf, 0, count);
                out.close();
                out = null;
            }
        } finally {
            closeSilently(in, out);
        }
    }

    /**
     * Sends a file's content to a given socket, saving the response to a file.
     * <p>
     * The special input name {@code stdin} and output name {@code stdout}
     * (case insensitive) select the standard streams instead of files; the
     * standard streams are never closed by this method.
     *
     * @param sock the socket to communicate through
     * @param infile the name of a file containing data to be sent
     * @param outfile the name of a file where the response will be saved
     * @throws IOException if an IO error occurs
     */
    public static void send(Socket sock, String infile, String outfile) throws IOException {
        boolean stdin = infile.equalsIgnoreCase("stdin");
        boolean stdout = outfile.equalsIgnoreCase("stdout");
        InputStream in = null;
        OutputStream out = null;
        try {
            in = stdin ? System.in : new FileInputStream(infile);
            out = stdout ? System.out : new FileOutputStream(outfile);
            send(sock, in, out, stdin);
        } finally {
            // close only the files opened here, never the JVM's standard streams
            closeSilently(stdin ? null : in, stdout ? null : out);
        }
    }

    /**
     * Sends an input stream's content to a given socket,
     * saving the response to an output stream.
     * <p>
     * The response is read until the end of the socket's input stream, and the
     * output stream is then flushed. The given streams are not closed, but the
     * socket's own streams are closed before returning, which closes the socket.
     *
     * @param sock the socket to communicate through
     * @param is the input stream containing data to be sent
     * @param os the output stream to which the response will be written
     * @param flushEOL if true, sent data is flushed when an EOL ('\n')
     *        character is encountered
     * @throws IOException if an IO error occurs
     */
    public static void send(Socket sock, InputStream is, OutputStream os, boolean flushEOL) throws IOException {
        InputStream in = null;
        OutputStream out = null;
        try {
            // initialize
            in = sock.getInputStream();
            out = new BufferedOutputStream(sock.getOutputStream());
            int count;
            byte[] data = new byte[4096];
            // send data
            while ((count = is.read(data)) != -1) {
                out.write(data, 0, count);
                if (flushEOL) {
                    for (int i = 0; i < count; i++) {
                        if (data[i] == '\n') {
                            out.flush();
                            break; // a single flush covers the whole chunk
                        }
                    }
                }
            }
            out.flush();
            // receive response, writing straight through so that nothing
            // is left behind in an intermediate buffer
            while ((count = in.read(data)) != -1)
                os.write(data, 0, count);
            os.flush();
        } finally {
            closeSilently(in, out);
        }
    }

    /**
     * The {@code Scanner} class scans a range of ports.
     * <p>
     * A Scanner is single-use: the blocking and asynchronous scans share the same
     * port counter, so once a scan has covered the range, further scans find
     * nothing left to do.
     */
    public static class Scanner implements Runnable {

        /** The highest valid TCP port number. */
        private static final int MAX_PORT = 65535;

        /** The maximum number of concurrent connection attempts. */
        private static final int MAX_CONCURRENT = 1024;

        /** The timeout of a blocking connection attempt, in milliseconds. */
        private static final int CONNECT_TIMEOUT = 10000;

        InetAddress address;
        int minport;
        int maxport;
        int concurrent;
        AtomicInteger nextPort;
        Collection<Integer> ports;

        /**
         * Attempts to establish a socket connection to a given host on a given port.
         *
         * @param address the address of the host
         * @param port the port to scan on host
         * @return true if socket is open for communication
         */
        public static boolean canConnect(InetAddress address, int port) {
            Socket sock = new Socket();
            try {
                // attempt to open a socket
                sock.connect(new InetSocketAddress(address, port), CONNECT_TIMEOUT);
                return true;
            } catch (IOException ioe) {
                return false;
            } finally {
                close(sock);
            }
        }

        /**
         * Constructs a Scanner that can scan a range of ports on a given host.
         *
         * @param host the name of the host to scan
         * @param minport the minimum port number of the port range to scan;
         *                if negative, defaults to 1
         * @param maxport the maximum port number of the port range to scan;
         *                if negative, defaults to 1024; values above 65535
         *                are reduced to 65535
         * @param concurrent the number of concurrent ports to scan at a time (1024 max);
         *                if less than one or greater than the number of ports to scan,
         *                defaults to the number of ports in the range
         * @throws UnknownHostException if the host cannot be resolved
         */
        public Scanner(String host, int minport, int maxport, int concurrent) throws UnknownHostException {
            if (minport < 0) minport = 1;
            if (maxport < 0) maxport = 1024;
            if (maxport > MAX_PORT) maxport = MAX_PORT;
            // an inverted range is treated as empty rather than as a negative count
            int count = Math.max(0, maxport - minport + 1);
            if (concurrent < 1 || concurrent > count) concurrent = count;
            if (concurrent > MAX_CONCURRENT) concurrent = MAX_CONCURRENT;

            // resolve the host address
            this.address = InetAddress.getByName(host);
            this.minport = minport;
            this.maxport = maxport;
            this.concurrent = concurrent;
            this.nextPort = new AtomicInteger(minport);
            this.ports = new ConcurrentLinkedQueue<Integer>();

            System.out.println("Scanning host: " + address);
        }

        /**
         * Returns the successfully opened ports after a scan is complete.
         *
         * @return the successfully opened port numbers
         */
        public Integer[] getOpenPorts() {
            return ports.toArray(new Integer[0]);
        }

        /**
         * Records and prints the result of a single connection attempt.
         *
         * @param port the port that was attempted
         * @param open whether the connection was successfully established
         */
        private void report(int port, boolean open) {
            if (open) {
                ports.add(Integer.valueOf(port));
                System.out.println("Port " + port + " scan successful");
            } else {
                System.out.println("Can't connect to port " + port);
            }
        }

        /**
         * Starts a non-blocking connection attempt to the given port. If the attempt
         * is still in progress, it is registered with the selector and added to the
         * pending connections; if it completes immediately it is reported right away.
         *
         * @param sel the selector that will signal completion of the attempt
         * @param port the port to connect to
         * @param pending the in-progress channels, mapped to their port numbers
         */
        private void startConnect(Selector sel, int port, Map<SocketChannel, Integer> pending) {
            SocketChannel chan = null;
            try {
                chan = SocketChannel.open();
                chan.configureBlocking(false);
                if (chan.connect(new InetSocketAddress(address, port))) {
                    // connected immediately, so no connect event will follow
                    closeSilently(chan);
                    report(port, true);
                } else {
                    chan.register(sel, SelectionKey.OP_CONNECT);
                    pending.put(chan, port);
                }
            } catch (IOException e) {
                System.err.println("error starting connection to port " + port);
                closeSilently(chan); // don't leak the channel of a failed attempt
            }
        }

        /**
         * Scans a range of ports on a given host, attempting to establish socket connections.
         * The scan is performed asynchronously, using an NIO selector and channels on a single
         * thread.
         * <p>
         * No explicit timeout is applied to the connection attempts here; an attempt
         * ends when the selector reports it as completed or failed.
         */
        public void scanAsync() {
            if (concurrent < 1)
                return; // empty port range, nothing to scan
            Map<SocketChannel, Integer> pending = new HashMap<SocketChannel, Integer>();
            Selector sel = null;
            try {
                sel = Selector.open();
                boolean exhausted = false; // whether all ports in range were handed out
                while (!exhausted || !pending.isEmpty()) {
                    // throttle concurrent connection attempts
                    while (!exhausted && pending.size() < concurrent) {
                        int port = nextPort.getAndIncrement();
                        if (port > maxport)
                            exhausted = true;
                        else
                            startConnect(sel, port, pending);
                    }
                    // never wait on a selector that has nothing in progress,
                    // since no event would ever wake it up
                    if (pending.isEmpty())
                        continue;
                    // wait for events
                    sel.select();
                    // handle events
                    for (Iterator<SelectionKey> it = sel.selectedKeys().iterator(); it.hasNext(); ) {
                        SelectionKey key = it.next();
                        it.remove();
                        SocketChannel chan = (SocketChannel)key.channel();

                        if (!key.isValid()) {
                            pending.remove(chan);
                            key.cancel();
                            closeSilently(chan);
                        } else if (key.isConnectable()) {
                            int chport = pending.remove(chan);
                            boolean connected = false;
                            try {
                                connected = chan.finishConnect();
                            } catch (IOException e) {
                                // connect failed, reported below
                            } finally {
                                key.cancel();
                                closeSilently(chan);
                            }
                            report(chport, connected);
                        }
                    }
                }
            } catch (IOException ioe) {
                ioe.printStackTrace();
            } finally {
                // release whatever is still open if the scan was aborted
                for (SocketChannel chan : pending.keySet())
                    closeSilently(chan);
                try {
                    if (sel != null)
                        sel.close();
                } catch (IOException ignore) {
                    // best-effort cleanup
                }
            }
        }

        /**
         * Scans a range of ports on a given host, attempting to establish socket connections.
         * The scan is performed synchronously, using a separate blocking thread for each
         * concurrently attempted connection.
         * <p>
         * If the calling thread is interrupted while waiting, its interrupt status
         * is restored before this method returns.
         */
        public void scan() {
            Thread[] scanners = new Thread[concurrent];
            // start threads
            for (int i = 0; i < concurrent; i++) {
                scanners[i] = new Thread(this, "Scanner-" + i);
                scanners[i].start();
            }
            // wait for all threads to finish
            boolean interrupted = false;
            for (int i = 0; i < concurrent; i++) {
                try {
                    scanners[i].join();
                } catch (InterruptedException ie) {
                    interrupted = true; // remember it, and move on to the next thread
                }
            }
            if (interrupted)
                Thread.currentThread().interrupt();
        }

        /**
         * Repeatedly takes the next unscanned port in the range and attempts a
         * blocking connection to it, until the range is exhausted.
         */
        @Override
        public void run() {
            int port;
            while ((port = nextPort.getAndIncrement()) <= maxport) {
                report(port, canConnect(address, port));
                Thread.yield(); // yield CPU time for better user responsiveness
            }
        }
    }

    /**
     * The main application entry point.
     * <p>
     * The first argument selects the command ({@code get}, {@code url} or
     * {@code scan}). The process exits with status 1 if the command fails
     * with an exception, and with status 0 otherwise.
     *
     * @param args the command line arguments
     */
    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        int status = 0;
        try {
            String command = args.length < 1 ? "" : args[0];
            if (command.equalsIgnoreCase("get")) {
                if (args.length != 5) {
                    System.out.println("\n\rUsage: java NetUtils get <host> <port> <infile> <outfile>\n\r");
                    System.out.println("\t<host> the host to communicate with.");
                    System.out.println("\t<port> the port on which to communicate.");
                    System.out.println("\t<infile> a file containing data to send.");
                    System.out.println("\t<outfile> a file to output response to.");
                } else {
                    doSession(args[1],
                        Integer.parseInt(args[2]),
                        new String[] { args[3] },
                        new String[] { args[4] }
                    );
                }
            } else if (command.equalsIgnoreCase("url")) {
                if (args.length != 3) {
                    System.out.println("\n\rUsage: java NetUtils url <url> <outfile>\n\r");
                    System.out.println("\t<url> the URL to download.");
                    System.out.println("\t<outfile> a file to output response to.");
                } else {
                    doSession(args[1], args[2]);
                }
            } else if (command.equalsIgnoreCase("scan")) {
                if (args.length < 2) {
                    System.out.println("\n\rUsage: java NetUtils scan <host> [minport] [maxport] [concurrent]\n\r");
                    System.out.println("\t<host> the address of the host to scan.");
                    System.out.println("\t[minport] the port to start scanning with (default 1).");
                    System.out.println("\t[maxport] the port to end scanning with (default 1024).");
                    System.out.println("\t[concurrent] the number of concurrent ports to scan at a time (1024 max and default).");
                } else {
                    String host = args[1];
                    int minport = args.length > 2 ? Integer.parseInt(args[2]) : -1;
                    int maxport = args.length > 3 ? Integer.parseInt(args[3]) : -1;
                    int concurrent = args.length > 4 ? Integer.parseInt(args[4]) : -1;
                    long start = System.currentTimeMillis();
                    Scanner scanner = new Scanner(host, minport, maxport, concurrent);
                    scanner.scanAsync();
                    Integer[] ports = scanner.getOpenPorts();
                    long end = System.currentTimeMillis();
                    // report the effective range, not the raw (possibly defaulted) arguments
                    System.out.println("Scanned ports " + scanner.minport + "-" + scanner.maxport
                        + " on " + host + " in " + (end - start) + " milliseconds.");
                    System.out.println("\r\nSuccessfully connected to ports: "
                        + Arrays.asList(ports));
                }
            } else {
                // missing or unknown command
                System.out.println("Usage: java NetUtils [get|url|scan]\n\r");
            }
            long endTime = System.currentTimeMillis();
            System.out.println("Finished after " + (endTime-startTime) + " millis");
        } catch (Exception e) {
            e.printStackTrace();
            status = 1; // let scripts detect the failure
        }

        System.exit(status);
    }

}
