/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package org.jboss.netty.handler.codec.spdy;

import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.handler.codec.embedder.DecoderEmbedder;
import org.junit.Test;

import java.util.List;
import java.util.Map;

import static org.jboss.netty.handler.codec.spdy.SpdyCodecUtil.*;
import static org.junit.Assert.*;

/**
 * Tests {@link SpdySessionHandler} by driving it through a {@link DecoderEmbedder}
 * whose pipeline ends in an {@link EchoHandler}.
 *
 * <p>Naming used by the test methods: {@code localStreamId} is a stream ID the test
 * itself uses when offering inbound frames ({@code 1, 3, 5, ...} for a server handler,
 * {@code 2, 4, 6, ...} for a client handler), while {@code remoteStreamId} has the
 * parity of the streams {@link EchoHandler} opens when the channel connects.
 */
public class SpdySessionHandlerTest {

    /** Settings ID whose presence in a received SETTINGS frame makes {@link EchoHandler} close the channel. */
    private static final int closeSignal = SPDY_SETTINGS_MAX_ID;

    /** SETTINGS frame carrying {@link #closeSignal}; offering it asks {@link EchoHandler} to close the channel. */
    private static final SpdySettingsFrame closeMessage = new DefaultSpdySettingsFrame();

    static {
        closeMessage.setValue(closeSignal, 0);
    }

    /**
     * Creates the embedded pipeline under test (session handler followed by an
     * {@link EchoHandler}) and discards everything it emitted during setup,
     * presumably the frames {@code EchoHandler.channelConnected} writes.
     */
    private static DecoderEmbedder<Object> newSessionHandler(SpdyVersion spdyVersion, boolean server) {
        DecoderEmbedder<Object> sessionHandler = new DecoderEmbedder<Object>(
                new SpdySessionHandler(spdyVersion, server), new EchoHandler(closeSignal, server));
        sessionHandler.pollAll();
        return sessionHandler;
    }

    /** Asserts that {@code msg} is non-null and an instance of {@code expectedType}, with a readable failure message. */
    private static void assertFrameType(Object msg, Class<?> expectedType) {
        assertNotNull("expected a " + expectedType.getSimpleName() + " but got null", msg);
        assertTrue("expected a " + expectedType.getSimpleName() + " but got " + msg.getClass().getName(),
                expectedType.isInstance(msg));
    }

    private static void assertDataFrame(Object msg, int streamId, boolean last) {
        assertFrameType(msg, SpdyDataFrame.class);
        SpdyDataFrame spdyDataFrame = (SpdyDataFrame) msg;
        assertEquals(streamId, spdyDataFrame.getStreamId());
        assertEquals(last, spdyDataFrame.isLast());
    }

    private static void assertSynReply(Object msg, int streamId, boolean last, SpdyHeaders headers) {
        assertFrameType(msg, SpdySynReplyFrame.class);
        assertHeaders(msg, streamId, last, headers);
    }

    private static void assertRstStream(Object msg, int streamId, SpdyStreamStatus status) {
        assertFrameType(msg, SpdyRstStreamFrame.class);
        SpdyRstStreamFrame spdyRstStreamFrame = (SpdyRstStreamFrame) msg;
        assertEquals(streamId, spdyRstStreamFrame.getStreamId());
        assertEquals(status, spdyRstStreamFrame.getStatus());
    }

    private static void assertPing(Object msg, int id) {
        assertFrameType(msg, SpdyPingFrame.class);
        SpdyPingFrame spdyPingFrame = (SpdyPingFrame) msg;
        assertEquals(id, spdyPingFrame.getId());
    }

    private static void assertGoAway(Object msg, int lastGoodStreamId) {
        assertFrameType(msg, SpdyGoAwayFrame.class);
        SpdyGoAwayFrame spdyGoAwayFrame = (SpdyGoAwayFrame) msg;
        assertEquals(lastGoodStreamId, spdyGoAwayFrame.getLastGoodStreamId());
    }

    /**
     * Asserts that {@code msg} is a HEADERS-type frame with the given stream ID and last
     * flag, and that its headers contain exactly the names and values in {@code headers}.
     *
     * <p>Note: this consumes the received frame's headers (each checked name is removed
     * from it). If {@code headers} is the same object as the received frame's headers,
     * the expected headers are emptied as well.
     */
    private static void assertHeaders(Object msg, int streamId, boolean last, SpdyHeaders headers) {
        assertFrameType(msg, SpdyHeadersFrame.class);
        SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg;
        assertEquals(streamId, spdyHeadersFrame.getStreamId());
        assertEquals(last, spdyHeadersFrame.isLast());
        for (String name: headers.names()) {
            List<String> expectedValues = headers.getAll(name);
            List<String> receivedValues = spdyHeadersFrame.headers().getAll(name);
            assertTrue(receivedValues.containsAll(expectedValues));
            receivedValues.removeAll(expectedValues);
            assertTrue(receivedValues.isEmpty());
            spdyHeadersFrame.headers().remove(name);
        }
        // Anything left over is a header that was received but not expected.
        assertTrue(spdyHeadersFrame.headers().isEmpty());
    }

    private static void testSpdySessionHandler(SpdyVersion spdyVersion, boolean server) {
        DecoderEmbedder<Object> sessionHandler = newSessionHandler(spdyVersion, server);

        int localStreamId = server ? 1 : 2;
        int remoteStreamId = server ? 2 : 1;

        SpdySynStreamFrame spdySynStreamFrame = new DefaultSpdySynStreamFrame(localStreamId, 0, (byte) 0);
        spdySynStreamFrame.headers().set("Compression", "test");

        SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(localStreamId);
        spdyDataFrame.setLast(true);

        // Check if session handler returns INVALID_STREAM if it receives
        // a data frame for a Stream-ID that is not open
        sessionHandler.offer(new DefaultSpdyDataFrame(localStreamId));
        assertRstStream(sessionHandler.poll(), localStreamId, SpdyStreamStatus.INVALID_STREAM);
        assertNull(sessionHandler.peek());

        // Check if session handler returns PROTOCOL_ERROR if it receives
        // a data frame for a Stream-ID before receiving a SYN_REPLY frame
        sessionHandler.offer(new DefaultSpdyDataFrame(remoteStreamId));
        assertRstStream(sessionHandler.poll(), remoteStreamId, SpdyStreamStatus.PROTOCOL_ERROR);
        assertNull(sessionHandler.peek());
        remoteStreamId += 2;

        // Check if session handler returns STREAM_IN_USE if it receives
        // multiple SYN_REPLY frames for the same active Stream-ID
        sessionHandler.offer(new DefaultSpdySynReplyFrame(remoteStreamId));
        assertNull(sessionHandler.peek());
        sessionHandler.offer(new DefaultSpdySynReplyFrame(remoteStreamId));
        assertRstStream(sessionHandler.poll(), remoteStreamId, SpdyStreamStatus.STREAM_IN_USE);
        assertNull(sessionHandler.peek());
        remoteStreamId += 2;

        // Check if session handler accepts a new stream and lets the echoed
        // SYN_REPLY and HEADERS frames (with their headers) through
        sessionHandler.offer(spdySynStreamFrame);
        assertSynReply(sessionHandler.poll(), localStreamId, false, spdySynStreamFrame.headers());
        assertNull(sessionHandler.peek());
        SpdyHeadersFrame spdyHeadersFrame = new DefaultSpdyHeadersFrame(localStreamId);
        spdyHeadersFrame.headers().add("HEADER","test1");
        spdyHeadersFrame.headers().add("HEADER","test2");
        sessionHandler.offer(spdyHeadersFrame);
        assertHeaders(sessionHandler.poll(), localStreamId, false, spdyHeadersFrame.headers());
        assertNull(sessionHandler.peek());
        localStreamId += 2;

        // Check if session handler enforces the concurrent-stream limit set by
        // EchoHandler and returns REFUSED_STREAM if it receives a SYN_STREAM
        // frame it does not wish to accept
        spdySynStreamFrame.setStreamId(localStreamId);
        spdySynStreamFrame.setLast(true);
        spdySynStreamFrame.setUnidirectional(true);
        sessionHandler.offer(spdySynStreamFrame);
        assertRstStream(sessionHandler.poll(), localStreamId, SpdyStreamStatus.REFUSED_STREAM);
        assertNull(sessionHandler.peek());

        // Check if session handler rejects HEADERS for closed streams: the
        // last DATA frame is echoed back, closing the stream on both sides
        int testStreamId = spdyDataFrame.getStreamId();
        sessionHandler.offer(spdyDataFrame);
        assertDataFrame(sessionHandler.poll(), testStreamId, spdyDataFrame.isLast());
        assertNull(sessionHandler.peek());
        spdyHeadersFrame.setStreamId(testStreamId);
        sessionHandler.offer(spdyHeadersFrame);
        assertRstStream(sessionHandler.poll(), testStreamId, SpdyStreamStatus.INVALID_STREAM);
        assertNull(sessionHandler.peek());

        // Check if session handler drops active streams if it receives
        // a RST_STREAM frame for that Stream-ID
        sessionHandler.offer(new DefaultSpdyRstStreamFrame(remoteStreamId, 3));
        assertNull(sessionHandler.peek());
        remoteStreamId += 2;

        // Check if session handler honors UNIDIRECTIONAL streams
        spdySynStreamFrame.setLast(false);
        sessionHandler.offer(spdySynStreamFrame);
        assertNull(sessionHandler.peek());
        spdySynStreamFrame.setUnidirectional(false);

        // Check if session handler returns PROTOCOL_ERROR if it receives
        // multiple SYN_STREAM frames for the same active Stream-ID
        sessionHandler.offer(spdySynStreamFrame);
        assertRstStream(sessionHandler.poll(), localStreamId, SpdyStreamStatus.PROTOCOL_ERROR);
        assertNull(sessionHandler.peek());
        localStreamId += 2;

        // Check if session handler returns PROTOCOL_ERROR if it receives
        // a SYN_STREAM frame with an invalid Stream-ID (wrong parity)
        spdySynStreamFrame.setStreamId(localStreamId - 1);
        sessionHandler.offer(spdySynStreamFrame);
        assertRstStream(sessionHandler.poll(), localStreamId - 1, SpdyStreamStatus.PROTOCOL_ERROR);
        assertNull(sessionHandler.peek());
        spdySynStreamFrame.setStreamId(localStreamId);

        // Check if session handler returns PROTOCOL_ERROR if it receives
        // an invalid HEADERS frame
        spdyHeadersFrame.setStreamId(localStreamId);
        spdyHeadersFrame.setInvalid();
        sessionHandler.offer(spdyHeadersFrame);
        assertRstStream(sessionHandler.poll(), localStreamId, SpdyStreamStatus.PROTOCOL_ERROR);
        assertNull(sessionHandler.peek());

        sessionHandler.finish();
    }

    @Test
    public void testSpdyClientSessionHandler() {
        testSpdySessionHandler(SpdyVersion.SPDY_3, false);
        testSpdySessionHandler(SpdyVersion.SPDY_3_1, false);
    }

    @Test
    public void testSpdyClientSessionHandlerPing() {
        testSpdySessionHandlerPing(SpdyVersion.SPDY_3, false);
        testSpdySessionHandlerPing(SpdyVersion.SPDY_3_1, false);
    }

    @Test
    public void testSpdyClientSessionHandlerGoAway() {
        testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3, false);
        testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3_1, false);
    }

    @Test
    public void testSpdyServerSessionHandler() {
        testSpdySessionHandler(SpdyVersion.SPDY_3, true);
        testSpdySessionHandler(SpdyVersion.SPDY_3_1, true);
    }

    @Test
    public void testSpdyServerSessionHandlerPing() {
        testSpdySessionHandlerPing(SpdyVersion.SPDY_3, true);
        testSpdySessionHandlerPing(SpdyVersion.SPDY_3_1, true);
    }

    @Test
    public void testSpdyServerSessionHandlerGoAway() {
        testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3, true);
        testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3_1, true);
    }

    private static void testSpdySessionHandlerPing(SpdyVersion spdyVersion, boolean server) {
        DecoderEmbedder<Object> sessionHandler = newSessionHandler(spdyVersion, server);

        // PING IDs reuse the stream-ID parity convention described on the class.
        int localStreamId = server ? 1 : 2;
        int remoteStreamId = server ? 2 : 1;

        SpdyPingFrame localPingFrame = new DefaultSpdyPingFrame(localStreamId);
        SpdyPingFrame remotePingFrame = new DefaultSpdyPingFrame(remoteStreamId);

        // Check if session handler returns identical local PINGs
        sessionHandler.offer(localPingFrame);
        assertPing(sessionHandler.poll(), localPingFrame.getId());
        assertNull(sessionHandler.peek());

        // Check if session handler ignores un-initiated remote PINGs
        sessionHandler.offer(remotePingFrame);
        assertNull(sessionHandler.peek());

        sessionHandler.finish();
    }

    private static void testSpdySessionHandlerGoAway(SpdyVersion spdyVersion, boolean server) {
        DecoderEmbedder<Object> sessionHandler = newSessionHandler(spdyVersion, server);

        int localStreamId = server ? 1 : 2;

        SpdySynStreamFrame spdySynStreamFrame = new DefaultSpdySynStreamFrame(localStreamId, 0, (byte) 0);
        spdySynStreamFrame.headers().set("Compression", "test");

        SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(localStreamId);
        spdyDataFrame.setLast(true);

        // Send an initial request
        sessionHandler.offer(spdySynStreamFrame);
        assertSynReply(sessionHandler.poll(), localStreamId, false, spdySynStreamFrame.headers());
        assertNull(sessionHandler.peek());
        sessionHandler.offer(spdyDataFrame);
        assertDataFrame(sessionHandler.poll(), localStreamId, true);
        assertNull(sessionHandler.peek());

        // Check if session handler sends a GOAWAY frame when closing
        sessionHandler.offer(closeMessage);
        assertGoAway(sessionHandler.poll(), localStreamId);
        assertNull(sessionHandler.peek());
        localStreamId += 2;

        // Check if session handler returns REFUSED_STREAM if it receives
        // SYN_STREAM frames after sending a GOAWAY frame
        spdySynStreamFrame.setStreamId(localStreamId);
        sessionHandler.offer(spdySynStreamFrame);
        assertRstStream(sessionHandler.poll(), localStreamId, SpdyStreamStatus.REFUSED_STREAM);
        assertNull(sessionHandler.peek());

        // Check if session handler ignores Data frames after sending
        // a GOAWAY frame
        spdyDataFrame.setStreamId(localStreamId);
        sessionHandler.offer(spdyDataFrame);
        assertNull(sessionHandler.peek());

        sessionHandler.finish();
    }

    /**
     * Test peer that, on connection, opens 4 streams that are half-closed on its side
     * and then sets the number of concurrent streams to 1. Afterwards it answers
     * SYN_STREAM with a SYN_REPLY carrying the same headers. It ignores SYN_REPLY and
     * echoes DATA, PING and HEADERS frames back. It closes the channel when a SETTINGS
     * frame containing {@code closeSignal} arrives.
     */
    private static class EchoHandler extends SimpleChannelUpstreamHandler {
        private final int closeSignal;
        private final boolean server;

        EchoHandler(int closeSignal, boolean server) {
            this.closeSignal = closeSignal;
            this.server = server;
        }

        @Override
        public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {

            // Initiate 4 new streams (IDs 2, 4, 6, 8 for a server; 1, 3, 5, 7 for a
            // client), each with the last flag set. A fresh frame is written per
            // stream instead of mutating a frame already handed to the pipeline.
            int firstStreamId = server ? 2 : 1;
            for (int i = 0; i < 4; i++) {
                SpdySynStreamFrame spdySynStreamFrame =
                        new DefaultSpdySynStreamFrame(firstStreamId + 2 * i, 0, (byte) 0);
                spdySynStreamFrame.setLast(true);
                Channels.write(e.getChannel(), spdySynStreamFrame);
            }

            // Limit the number of concurrent streams to 1
            SpdySettingsFrame spdySettingsFrame = new DefaultSpdySettingsFrame();
            spdySettingsFrame.setValue(SpdySettingsFrame.SETTINGS_MAX_CONCURRENT_STREAMS, 1);
            Channels.write(e.getChannel(), spdySettingsFrame);
        }

        @Override
        public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
                throws Exception {
            Object msg = e.getMessage();
            if (msg instanceof SpdySynStreamFrame) {
                // Reply on the same stream, mirroring the last flag and all headers.
                SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg;

                int streamId = spdySynStreamFrame.getStreamId();
                SpdySynReplyFrame spdySynReplyFrame = new DefaultSpdySynReplyFrame(streamId);
                spdySynReplyFrame.setLast(spdySynStreamFrame.isLast());
                for (Map.Entry<String, String> entry: spdySynStreamFrame.headers()) {
                    spdySynReplyFrame.headers().add(entry.getKey(), entry.getValue());
                }

                Channels.write(e.getChannel(), spdySynReplyFrame, e.getRemoteAddress());
                return;
            }

            if (msg instanceof SpdySynReplyFrame) {
                return;
            }

            if (msg instanceof SpdyDataFrame ||
                msg instanceof SpdyPingFrame ||
                msg instanceof SpdyHeadersFrame) {

                // Echo the very same frame instance back.
                Channels.write(e.getChannel(), msg, e.getRemoteAddress());
                return;
            }

            if (msg instanceof SpdySettingsFrame) {

                SpdySettingsFrame spdySettingsFrame = (SpdySettingsFrame) msg;
                if (spdySettingsFrame.isSet(closeSignal)) {
                    Channels.close(e.getChannel());
                }
            }
        }
    }
}
