001/*
002 * Copyright (c) 2013 Nu Echo Inc. All rights reserved.
003 */
004
005package com.nuecho.rivr.core.servlet;
006
007import java.io.*;
008import java.util.*;
009
010import javax.servlet.*;
011import javax.servlet.http.*;
012
013import org.slf4j.*;
014
015import com.nuecho.rivr.core.channel.*;
016import com.nuecho.rivr.core.channel.synchronous.*;
017import com.nuecho.rivr.core.channel.synchronous.step.*;
018import com.nuecho.rivr.core.dialogue.*;
019import com.nuecho.rivr.core.servlet.session.*;
020import com.nuecho.rivr.core.util.*;
021
022/**
023 * Abstract servlet interacting with a web client acting as the controller of a
024 * {@link SynchronousDialogueChannel}.
025 * <p>
026 * This abstract servlet must be extended in order to provide a specific
027 * implementation. For each session,
028 * <ol>
 * <li>the servlet creates the {@link Session} and places it in the
030 * {@link SessionContainer}</li>
031 * <li>it creates a {@link Dialogue} with a {@link DialogueFactory}</li>
032 * <li>it creates a {@link DialogueContext} with a
033 * {@link DialogueContextFactory}</li>
034 * <li>it creates a {@link SynchronousDialogueChannel} and starts the dialogue
035 * upon initial HTTP request</li>
036 * <li>it renders the various {@link Step steps} from the dialogue channel into
037 * appropriate HTTP responses</li>
 * <li>it translates HTTP requests into {@link InputTurn InputTurns} using the
 * {@link InputTurnFactory}</li>
 * <li>once the dialogue is done, the servlet performs the necessary clean-up.</li>
041 * </ol>
042 * <h3>init args</h3>
043 * <p>
044 * The following servlet initial arguments are supported:
045 * <dl>
046 * <dt>com.nuecho.rivr.core.dialogueTimeout</dt>
047 * <dd>Maximum time for dialogue to produce an {@link OutputTurn}. Value
048 * specified must be followed by unit (ms, s, m, h, d, y), e.g. <code>10s</code>
049 * for 10 seconds. Default value: <code>10 s</code></dd>
050 * <dt>com.nuecho.rivr.core.controllerTimeout</dt>
051 * <dd>Maximum time for controller to produce an {@link InputTurn}. Value
052 * specified must be followed by unit (ms, s, m, h, d, y), e.g. <code>10s</code>
053 * for 10 seconds. Default value: <code>5 m</code></dd>
054 * <dt>com.nuecho.rivr.core.sessionTimeout</dt>
055 * <dd>Maximum inactivity time for a session. Value specified must be followed
056 * by unit (ms, s, m, h, d, y), e.g. <code>10s</code> for 10 seconds. Default
057 * value: <code>30 m</code></dd>
058 * </dl>
059 * <dl>
060 * <dt>com.nuecho.rivr.core.sessionScanPeriod</dt>
061 * <dd>Time between each scan for dead sessions in the session container. Value
062 * specified must be followed by unit (ms, s, m, h, d, y), e.g. <code>10s</code>
063 * for 10 seconds. Default value: <code>2 m</code></dd>
064 * </dl>
065 * <dl>
066 * <dt>com.nuecho.rivr.core.webappServerSessionTrackingEnabled</dt>
 * <dd>Whether a {@link javax.servlet.http.HttpSession} should be created for
 * each dialogue or not. This is useful for load-balancers using the JSESSIONID
 * cookie to enforce server affinity (or stickiness). Value should be
 * <code>true</code> or <code>false</code>. Default value:
 * <code>true</code></dd>
072 * </dl>
073 *
074 * @param <F> type of {@link FirstTurn}
075 * @param <L> type of {@link LastTurn}
076 * @param <O> type of {@link OutputTurn}
077 * @param <I> type of {@link InputTurn}
078 * @param <C> type of {@link DialogueContext}
079 * @author Nu Echo Inc.
080 */
081public abstract class DialogueServlet<I extends InputTurn, O extends OutputTurn, F extends FirstTurn, L extends LastTurn, C extends DialogueContext<I, O>>
082        extends HttpServlet {
083
084    private static final String FALSE = "false";
085    private static final String TRUE = "true";
086
087    private static final String MDC_KEY_DIALOGUE_ID = "dialogueId";
088
089    private static final String SESSION_LOGGER_NAME = "com.nuecho.rivr.session";
090    private static final String DIALOGUE_LOGGER_NAME = "com.nuecho.rivr.dialogue";
091
092    private static final String SERVLET_LOGGER_NAME = "com.nuecho.rivr.servlet";
093    private static final String RESPONSES_LOGGER_NAME = "com.nuecho.rivr.servlet.responses";
094
095    private static final long serialVersionUID = 1L;
096    private static final String SESSION_CONTAINER_NAME = "com.nuecho.rivr.sessionContainer";
097
098    private static final String INITIAL_ARGUMENT_PREFIX = "com.nuecho.rivr.core.";
099    private static final String INITIAL_ARGUMENT_DIALOGUE_TIMEOUT = INITIAL_ARGUMENT_PREFIX + "dialogueTimeout";
100    private static final String INITIAL_ARGUMENT_SESSION_TIMEOUT = INITIAL_ARGUMENT_PREFIX + "sessionTimeout";
101    private static final String INITIAL_ARGUMENT_SESSION_SCAN_PERIOD = INITIAL_ARGUMENT_PREFIX + "sessionScanPeriod";
102    private static final String INITIAL_ARGUMENT_CONTROLLER_TIMEOUT = INITIAL_ARGUMENT_PREFIX + "controllerTimeout";
103
104    private static final String INITIAL_ARGUMENT_ENABLE_WEBAPP_SERVER_SESSION_TRACKING = INITIAL_ARGUMENT_PREFIX
105                                                                                         + "webappServerSessionTrackingEnabled";
106
107    private ErrorHandler<L> mErrorHandler;
108    private DialogueFactory<I, O, F, L, C> mDialogueFactory;
109    private DialogueContextFactory<C, I, O> mDialogueContextFactory;
110    private ILoggerFactory mLoggerFactory;
111    private SessionContainer<I, O, F, L, C> mSessionContainer;
112    private InputTurnFactory<I, F> mInputTurnFactory;
113
114    private Duration mDialogueTimeout = Duration.seconds(10);
115    private Duration mControllerTimeout = Duration.minutes(5);
116
117    private Duration mSessionTimeout = Duration.minutes(30);
118    private Duration mSessionScanPeriod = Duration.minutes(2);
119
120    private boolean mWebappServerSessionTrackingEnabled = true;
121    private Logger mLogger;
122    private Logger mResponseLogger;
123
124    private boolean mDestroyed;
125
126    /**
127     * Performs initialization.
128     *
129     * @throws DialogueServletInitializationException when servlet can't be
130     *             initialized properly.
131     */
132    protected abstract void initDialogueServlet() throws DialogueServletInitializationException;
133
134    /**
135     * Performs shutdown.
136     */
137    protected abstract void destroyDialogueServlet();
138
139    /**
140     * Provides the {@link StepRenderer} appropriate for the context.
141     *
142     * @param request the request
143     * @param session the session
144     * @return the <code>StepRenderer</code> object.
145     */
146    protected abstract StepRenderer<I, O, L, C> getStepRenderer(HttpServletRequest request,
147                                                                Session<I, O, F, L, C> session);
148
149    /**
150     * Initializes the servlet. The first thing done in this method is to call
151     * {@link #initDialogueServlet()}. This method is called by the servlet
152     * container.
153     */
154    @Override
155    public final void init() throws ServletException {
156
157        Throwable initError = null;
158
159        try {
160            initDialogueServlet();
161        } catch (DialogueServletInitializationException exception) {
162            initError = exception;
163        }
164
165        if (mLoggerFactory == null) {
166            mLoggerFactory = LoggerFactory.getILoggerFactory();
167        }
168
169        mLogger = mLoggerFactory.getLogger(SERVLET_LOGGER_NAME);
170        mResponseLogger = mLoggerFactory.getLogger(RESPONSES_LOGGER_NAME);
171
172        if (initError != null) {
173            mLogger.error("Unable to initialize dialogue servlet.", initError);
174            destroy();
175            throw new ServletException("Unable to initialize dialogue servlet.", initError);
176        }
177
178        ensureFieldIsSet(mInputTurnFactory, "InputTurnFactory");
179        ensureFieldIsSet(mDialogueFactory, "DialogueFactory");
180        ensureFieldIsSet(mDialogueContextFactory, "DialogueContextFactory");
181        ensureFieldIsSet(mErrorHandler, "ErrorHandler");
182
183        Logger sessionContainerLogger = mLoggerFactory.getLogger(SESSION_LOGGER_NAME);
184
185        Duration sessionScanPeriod = getDuration(INITIAL_ARGUMENT_SESSION_SCAN_PERIOD);
186        if (sessionScanPeriod != null) {
187            setSessionScanPeriod(sessionScanPeriod);
188        }
189
190        Duration sessionTimeout = getDuration(INITIAL_ARGUMENT_SESSION_TIMEOUT);
191        if (sessionTimeout != null) {
192            setSessionTimeout(sessionTimeout);
193        }
194
195        mSessionContainer = new SessionContainer<I, O, F, L, C>(sessionContainerLogger,
196                                                                mSessionTimeout,
197                                                                mSessionScanPeriod,
198                                                                SESSION_CONTAINER_NAME);
199
200        Duration dialogueTimeout = getDuration(INITIAL_ARGUMENT_DIALOGUE_TIMEOUT);
201        if (dialogueTimeout != null) {
202            setDialogueTimeout(dialogueTimeout);
203        }
204
205        Duration controllerTimeout = getDuration(INITIAL_ARGUMENT_CONTROLLER_TIMEOUT);
206        if (controllerTimeout != null) {
207            setControllerTimeout(controllerTimeout);
208        }
209
210        Boolean enableWebappServerSessionTracking = getBoolean(INITIAL_ARGUMENT_ENABLE_WEBAPP_SERVER_SESSION_TRACKING);
211        if (enableWebappServerSessionTracking != null) {
212            setWebappServerSessionTrackingEnabled(enableWebappServerSessionTracking);
213        }
214
215        mLogger.info("Dialogue servlet initialized.");
216
217    }
218
219    /**
220     * Destroys the servlet. This methods calls
221     * {@link #destroyDialogueServlet()}. This method is called by the servlet
222     * container.
223     */
224    @Override
225    public final synchronized void destroy() {
226        if (mDestroyed) return;
227        if (mSessionContainer != null) {
228            mSessionContainer.stop();
229        }
230        destroyDialogueServlet();
231
232        mLogger.info("Dialogue servlet destroyed.");
233        mDestroyed = true;
234    }
235
236    private Duration getDuration(String key) throws ServletException {
237        ServletConfig servletConfig = getServletConfig();
238        String duration = servletConfig.getInitParameter(key);
239        if (duration == null) return null;
240        try {
241            return Duration.parse(duration);
242        } catch (IllegalArgumentException exception) {
243            throw new ServletException("Unable to parse duration for init-arg '" + key + "'", exception);
244        }
245    }
246
247    private Boolean getBoolean(String key) throws ServletException {
248        ServletConfig servletConfig = getServletConfig();
249        String booleanString = servletConfig.getInitParameter(key);
250        if (booleanString == null) return null;
251        if (booleanString.equalsIgnoreCase(TRUE)) return Boolean.TRUE;
252        if (booleanString.equalsIgnoreCase(FALSE)) return Boolean.FALSE;
253        throw new ServletException("Unable to parse boolean for init-arg '"
254                                   + key
255                                   + "'.  Should be '"
256                                   + TRUE
257                                   + "' of '"
258                                   + FALSE
259                                   + "' but not '"
260                                   + booleanString
261                                   + "'.");
262    }
263
264    private void ensureFieldIsSet(Object fieldValue, String fieldName) {
265        if (fieldValue == null) throw new IllegalStateException(fieldName + " is not set.");
266    }
267
268    @Override
269    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
270        process(request, response);
271    }
272
273    @Override
274    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
275        process(request, response);
276    }
277
278    public final ILoggerFactory getLoggerFactory() {
279        return mLoggerFactory;
280    }
281
282    protected void renderOutputTurn(O outputTurn,
283                                    HttpServletRequest request,
284                                    final HttpServletResponse response,
285                                    Session<I, O, F, L, C> session) throws IOException, StepRendererException {
286        ServletResponseContent responseContent = getStepRenderer(request, session).createDocumentForOutputTurn(outputTurn,
287                                                                                                               request,
288                                                                                                               response,
289                                                                                                               session.getDialogueContext());
290        commitToResponse(response, responseContent);
291    }
292
293    protected void renderLastTurn(L result,
294                                  HttpServletRequest request,
295                                  HttpServletResponse response,
296                                  Session<I, O, F, L, C> session) throws IOException, StepRendererException {
297        ServletResponseContent responseContent = getStepRenderer(request, session).createDocumentForLastTurn(result,
298                                                                                                             request,
299                                                                                                             response,
300                                                                                                             session.getDialogueContext());
301        commitToResponse(response, responseContent);
302        session.stop();
303    }
304
305    protected void renderError(Throwable error,
306                               HttpServletRequest request,
307                               HttpServletResponse response,
308                               Session<I, O, F, L, C> session) throws IOException, StepRendererException {
309
310        L fatalErrorTurn = mErrorHandler.handleError(error);
311
312        ServletResponseContent responseContent = getStepRenderer(request, session).createDocumentForLastTurn(fatalErrorTurn,
313                                                                                                             request,
314                                                                                                             response,
315                                                                                                             session.getDialogueContext());
316        commitToResponse(response, responseContent);
317        session.stop();
318    }
319
320    public final void setInputTurnFactory(InputTurnFactory<I, F> inputTurnFactory) {
321        Assert.notNull(inputTurnFactory, "inputTurnFactory");
322        mInputTurnFactory = inputTurnFactory;
323    }
324
325    public final void setDialogueFactory(DialogueFactory<I, O, F, L, C> dialogueFactory) {
326        Assert.notNull(dialogueFactory, "dialogueFactory");
327        mDialogueFactory = dialogueFactory;
328    }
329
330    public final void setDialogueContextFactory(DialogueContextFactory<C, I, O> dialogueContextFactory) {
331        Assert.notNull(dialogueContextFactory, "dialogueContextFactory");
332        mDialogueContextFactory = dialogueContextFactory;
333    }
334
335    public final void setLoggerFactory(ILoggerFactory loggerFactory) {
336        Assert.notNull(loggerFactory, "loggerFactory");
337        mLoggerFactory = loggerFactory;
338    }
339
340    /**
341     * Sets maximum duration the servlet thread can wait for the dialogue
342     * response.
343     *
344     * @param dialogueTimeout the timeout. Cannot be <code>null</code>. A value
345     *            of <code>Duration.ZERO</code> (or equivalent) means to wait
346     *            forever.
347     */
348    public final void setDialogueTimeout(Duration dialogueTimeout) {
349        Assert.notNull(dialogueTimeout, "dialogueTimeout");
350        mDialogueTimeout = dialogueTimeout;
351    }
352
353    /**
354     * Sets maximum duration the dialogue thread can wait for the controller
355     * response.
356     *
357     * @param controllerTimeout the timeout. Cannot be <code>null</code>. A
358     *            value of <code>Duration.ZERO</code> (or equivalent) means to
359     *            wait forever.
360     * @since 1.0.1
361     */
362    public final void setControllerTimeout(Duration controllerTimeout) {
363        Assert.notNull(controllerTimeout, "controllerTimeout");
364        mControllerTimeout = controllerTimeout;
365    }
366
367    public final void setSessionTimeout(Duration sessionTimeout) {
368        Assert.notNull(sessionTimeout, "sessionTimeout");
369        mSessionTimeout = sessionTimeout;
370    }
371
372    public final void setSessionScanPeriod(Duration sessionScanPeriod) {
373        Assert.notNull(sessionScanPeriod, "sessionScanPeriod");
374        mSessionScanPeriod = sessionScanPeriod;
375    }
376
377    public final void setErrorHandler(ErrorHandler<L> errorHandler) {
378        Assert.notNull(errorHandler, "errorHandler");
379        mErrorHandler = errorHandler;
380    }
381
382    /**
383     * Indicates if the servlet should create an HttpSession object for each
384     * dialogue. Note: Nothing is stored in the <code>HttpSession</code>.
385     * However, the creation of a session would force the web container to track
386     * the session using a cookie (JSESSIONID) or to do URL-rewriting. This is
387     * only relevant if there is more than one web container fronted by a load
388     * balancer.
389     *
390     * @param enableWebappServerSessionTracking true if HttpSession are to be
391     *            used for session tracking.
392     * @since 1.0.1
393     */
394    public final void setWebappServerSessionTrackingEnabled(boolean enableWebappServerSessionTracking) {
395        mWebappServerSessionTrackingEnabled = enableWebappServerSessionTracking;
396    }
397
398    private void process(HttpServletRequest request, HttpServletResponse response) throws ServletException {
399        Session<I, O, F, L, C> session;
400        try {
401            session = getSession(request);
402            MDC.put(MDC_KEY_DIALOGUE_ID, session.getId());
403            process(request, response, session);
404        } catch (SessionNotFoundException exception) {
405            throw new ServletException("Cannot find session.", exception);
406        } finally {
407            MDC.remove(MDC_KEY_DIALOGUE_ID);
408        }
409
410    }
411
412    private void process(HttpServletRequest request, HttpServletResponse response, Session<I, O, F, L, C> session)
413            throws ServletException {
414        try {
415
416            Step<O, L> step;
417            C dialogueContext = session.getDialogueContext();
418
419            try {
420                if (dialogueContext == null) {
421                    step = startDialogue(request, response, session);
422                } else {
423                    step = continueDialogue(request, response, session);
424                }
425            } catch (Timeout exception) {
426                renderError(exception, request, response, session);
427                return;
428            } catch (InterruptedException exception) {
429                Thread.currentThread().interrupt();
430                renderError(exception, request, response, session);
431                return;
432            }
433
434            if (step instanceof OutputTurnStep) {
435                OutputTurnStep<O, L> outputTurnStep = (OutputTurnStep<O, L>) step;
436                renderOutputTurn(outputTurnStep.getOutputTurn(), request, response, session);
437            } else if (step instanceof LastTurnStep) {
438                LastTurnStep<O, L> lastTurnStep = (LastTurnStep<O, L>) step;
439                renderLastTurn(lastTurnStep.getLastTurn(), request, response, session);
440            } else if (step instanceof ErrorStep) {
441                ErrorStep<O, L> errorStep = (ErrorStep<O, L>) step;
442                Throwable throwable = errorStep.getThrowable();
443                renderError(throwable, request, response, session);
444            }
445        } catch (Exception exception) {
446            throw new ServletException("Error while rendering step.", exception);
447        }
448    }
449
450    private Step<O, L> continueDialogue(HttpServletRequest request,
451                                        HttpServletResponse response,
452                                        Session<I, O, F, L, C> session) throws ServletException, Timeout,
453            InterruptedException {
454        Assert.notNull(session, "session");
455        I inputTurn = createInputTurn(request, response);
456        SynchronousDialogueChannel<I, O, F, L, C> dialogueChannel = session.getDialogueChannel();
457        Assert.notNull(dialogueChannel, "dialogueChannel");
458        return dialogueChannel.doTurn(inputTurn, mDialogueTimeout);
459    }
460
461    private Step<O, L> startDialogue(HttpServletRequest request,
462                                     HttpServletResponse response,
463                                     Session<I, O, F, L, C> session) throws ServletException, Timeout,
464            InterruptedException {
465        SynchronousDialogueChannel<I, O, F, L, C> dialogueChannel;
466        dialogueChannel = new SynchronousDialogueChannel<I, O, F, L, C>();
467        session.setDialogueChannel(dialogueChannel);
468
469        Logger logger = mLoggerFactory.getLogger(DIALOGUE_LOGGER_NAME);
470        dialogueChannel.setLogger(logger);
471
472        dialogueChannel.setDefaultReceiveFromControllerTimeout(mControllerTimeout);
473        dialogueChannel.setDefaultReceiveFromDialogueTimeout(mDialogueTimeout);
474
475        C dialogueContext = createContext(request, session, dialogueChannel, logger);
476
477        DialogueInitializationInfo<I, O, C> initializationInfo;
478        initializationInfo = createInitializationInfo(request, response, dialogueContext);
479        Dialogue<I, O, F, L, C> dialogue;
480        try {
481            dialogue = mDialogueFactory.create(initializationInfo);
482        } catch (DialogueFactoryException exception) {
483            throw new ServletException("Unable to create dialogue.", exception);
484        }
485        F firstTurn = createFirstTurn(request, response);
486        return dialogueChannel.start(dialogue, firstTurn, mDialogueTimeout, dialogueContext);
487    }
488
489    private C createContext(HttpServletRequest request,
490                            Session<I, O, F, L, C> session,
491                            SynchronousDialogueChannel<I, O, F, L, C> dialogueChannel,
492                            Logger logger) {
493        C dialogueContext = mDialogueContextFactory.createDialogueContext(request,
494                                                                          session.getId(),
495                                                                          dialogueChannel,
496                                                                          logger);
497        session.setDialogueContext(dialogueContext);
498        return dialogueContext;
499    }
500
501    private WebDialogueInitializationInfo<I, O, C> createInitializationInfo(HttpServletRequest request,
502                                                                            HttpServletResponse response,
503                                                                            C dialogueContext) {
504        return new WebDialogueInitializationInfo<I, O, C>(dialogueContext, request, response, getServletContext(), this);
505    }
506
507    protected Session<I, O, F, L, C> getSession(HttpServletRequest request) throws SessionNotFoundException {
508        String pathInfo = request.getPathInfo();
509
510        if (pathInfo != null && !pathInfo.equals("/")) {
511            if (pathInfo.startsWith("/")) {
512                pathInfo = pathInfo.substring(1);
513            }
514
515            int firstSlash = pathInfo.indexOf('/');
516            if (firstSlash != -1) {
517                pathInfo = pathInfo.substring(0, firstSlash);
518            }
519
520            return getExistingSession(pathInfo);
521        } else {
522            String sessionId = UUID.randomUUID().toString();
523
524            Session<I, O, F, L, C> session = new Session<I, O, F, L, C>(mSessionContainer, sessionId);
525            mSessionContainer.addSession(session);
526            if (mWebappServerSessionTrackingEnabled) {
527                session.setAssociatedHttpSession(request.getSession());
528            }
529
530            return session;
531        }
532    }
533
534    protected Session<I, O, F, L, C> getExistingSession(String sessionId) throws SessionNotFoundException {
535        Session<I, O, F, L, C> session = mSessionContainer.getSession(sessionId);
536
537        if (session == null) throw new SessionNotFoundException("Unable to find session [" + sessionId + "]");
538
539        return session;
540    }
541
542    private I createInputTurn(HttpServletRequest request, HttpServletResponse response) throws ServletException {
543        try {
544            return mInputTurnFactory.createInputTurn(request, response);
545        } catch (InputTurnFactoryException exception) {
546            throw new ServletException(exception);
547        }
548    }
549
550    private F createFirstTurn(HttpServletRequest request, HttpServletResponse response) throws ServletException {
551        try {
552            return mInputTurnFactory.createFirstTurn(request, response);
553        } catch (InputTurnFactoryException exception) {
554            throw new ServletException(exception);
555        }
556    }
557
558    private void commitToResponse(final HttpServletResponse response, ServletResponseContent responseContent)
559            throws IOException {
560        ServletOutputStream outputStream = response.getOutputStream();
561
562        if (mResponseLogger.isDebugEnabled()) {
563            mResponseLogger.debug("Content-length: {}", responseContent.getContentLength());
564            mResponseLogger.debug("Content-type: {}", responseContent.getContentType());
565            mResponseLogger.debug("Content: {}", responseContent.getContentAsString());
566        }
567
568        response.setContentType(responseContent.getContentType());
569        Integer contentLength = responseContent.getContentLength();
570        if (contentLength != null) {
571            response.setContentLength(contentLength);
572        }
573
574        responseContent.writeTo(outputStream);
575    }
576}