チャットサーバ(Parallel and Concurrent Programming in Haskell Chapter 12後半)

英語の原文はこちらのページで読める
http://chimera.labs.oreilly.com/books/1230000000929/ch12.html#sec_chat

このページで紹介しているコードはほとんど上記ページからの引用。


入力された値を数倍して返すというシンプルなサーバを元にチャットサーバを実装する。

仕様は以下

  • 存在するチャネルは一つだけ(簡単のため)
  • クライアントが接続してきたときに名前を尋ねる。入力された名前が既に使われているなら他の名前にしてもらう
  • クライアントからの入力は以下のようなコマンドとして処理する
    • /tell name message: 指定のユーザへメッセージを送る
    • /kick name : 指定のユーザを追い出す
    • /quit : 接続をやめる
    • message : すべてのユーザにメッセージを投げる
  • クライアントが接続したとき・接続を切ったときは他のすべてのクライアントに通知
  • 正しくエラーを処理し不整合を起こさない
    • ex. 二人の同名のクライアントが同時につなげてきたときはどちらか一方がつながる
  • 二人のクライアントが同時に互いをkickしたときは片方のみがkickされる

アーキテクチャ

基本的には前節で作った簡易サーバと同じように動く。receiveスレッドがクライアントからの入力を受けてチャネルに入れる。serverスレッドはチャネルのメッセージを見て処理を行う。

kickの際には2つのクライアントが同時にkickされないようにするため、serverスレッドがkickの処理を行う前に既にkickされたクライアントではないかを確認するようにする。

また、クライアント毎にチャネルは一つだけにすることでメッセージの順序を保てるようにする。

クライアントのデータ

各クライアント毎に持つデータを定義する。

type ClientName = String

data Client = Client
  { clientName     :: ClientName
  , clientHandle   :: Handle
  , clientKicked   :: TVar (Maybe String)
  , clientSendChan :: TChan Message
  }

data Message = Notice String
             | Tell ClientName String
             | Broadcast ClientName String
             | Command String

clientKickedはこのクライアントがkickされたかどうかを保持するデータ。Nothingであればkickされていない。
チャネルに入れるメッセージは4種類。Noticeはサーバからサーバへ渡すメッセージ、Tellは特定のクライアント宛て、Broadcastは全員宛て、Commandはクライアントからの生の入力。

サーバーのデータ

サーバの状態を表すデータを定義。接続している各クライアントの名前とそのクライアントデータのマップ。

data Server = Server
  { clients :: TVar (Map ClientName Client)
  }

実装

main

クライアントからのリクエストを受ける部分はほとんど変わらない。サーバ型のデータを作成しそれをtalkに渡すようにしたところのみ追加された。

main :: IO ()
main = withSocketsDo $ do 
  server <- newServer                                          
  sock <- listenOn $ PortNumber $ fromIntegral port            
  printf "Listening on port %d\n" port                         
  forever $ do                                                 
    (handle, host, port) <- accept sock                        
    printf "Accepted connection from %s: %s\n" host (show port)
    forkFinally (talk handle server) (\_ -> hClose handle)     

port :: Int 
port = 44444

クライアントの接続時

クライアントの接続時は少しやることが増える。名前を聞いて既存のクライアントと重複していないか調べ、重複していれば別の名前を入れてもらい、そうでなければクライアントデータを作成してクライアントのマップに追加する。また、接続時には他のクライアントに対して通知を行う。

newClient :: ClientName -> Handle -> STM Client  
newClient name handle = do                       
  c <- newTChan                                  
  k <- newTVar Nothing                           
  return Client { clientName      = name         
                , clientHandle    = handle       
                , clientKicked    = k            
                , clientSendChan  = c            
                }                                

checkAddClient :: Server -> ClientName -> Handle -> IO (Maybe Client)
checkAddClient server@Server{..} name handle = atomically $ do       
  clientmap <- readTVar clients                                      
  if M.member name clientmap                                         
    then return Nothing                                              
    else do                                                          
      client <- newClient name handle                                
      writeTVar clients $ M.insert name client clientmap             
      broadcast server $ Notice $ name ++ " has connected"           
      return $ Just client                                           

removeClient :: Server -> ClientName -> IO ()            
removeClient server@Server{..} name = atomically $ do    
  modifyTVar' clients $ M.delete name                    
  broadcast server $ Notice $ name ++ " has disconnected"

newClientはシンプル。空のチャネルと空のTVarを作ってClientに包んでいる。

checkAddClientは接続時に名前を聞いて重複がなければJust clientを返す。重複している名前ならNothing。

removeClientはクライアントのマップから指定のクライアントを除いて他のクライアントへ通知。

ここで使われているbroadcastは以下。
クライアントのマップを取り出してそこにあるすべてのクライアントのチャネルに対してメッセージを追加する。

broadcast :: Server -> Message -> STM()                          
broadcast Server{..} msg = do                                    
  clientmap <- readTVar clients                                  
  mapM_ (\client -> sendMessage client msg) $ M.elems clientmap  

sendMessage :: Client -> Message -> STM ()                
sendMessage Client{..} msg = writeTChan clientSendChan msg

sendMessageで使われているClient{..}という表現はRecordWildCardsというghcの拡張を使ったもの。これを指定しておくと、レコードの形で定義されたデータを表現するのに、そのレコード名を書くだけでよくなる。もしこの拡張を使わないと以下のように書くことになる。

sendMessage client msg = writeTChan (clientSendChan client) msg

これでもよいが、レコードが多い場合には至るところにこのようなパターンのコードが書かれてしまうため冗長になる。

以上の関数を使ってクライアントの接続処理を行うtalkを書く。

talk :: Handle -> Server -> IO ()                                                   
talk handle server@Server{..} = do                                                  
  hSetNewlineMode handle universalNewlineMode    (1)
  hSetBuffering handle LineBuffering                                                
  readName                                                                          
  where                                                                             
    readName = do                                                                   
      hPutStrLn handle "What is your name ?"                                        
      name <- hGetLine handle                                                       
      if null name                                                                  
        then readName                                                               
        else mask $ \restore -> do               (2)                                   
          ok <- checkAddClient server name handle                                   
          case ok of                                                                
            Nothing -> restore $ do                                                 
              hPrintf handle "the name %s is in use, choose another\n" name         
              readName                                                              
            Just client ->                                                          
              restore (runClient server client) `finally` removeClient server name  

(1)は\r\nを\nに変換するためのコード。
readNameがメインの処理でクライアントから入力された名前に応じて処理を分ける。

注目すべき点は(2)でmaskをしている所。checkAddClientの処理中に例外を受け取ってしまった場合、クライアントのマップには名前とデータが追加されるものの実際は存在しないという状態のまま残ってしまう。なのでここでmaskしてrunClientのときに受け取るようにしている。

メッセージのやりとり

クライアントの接続が完了したのでメッセージの送受信を行う段階。
serverスレッド、receiveスレッドをそれぞれ作成してraceによって協調させる。

まずはreceiveスレッド

receiveThread = forever $ do                                                           
  msg <- hGetLine clientHandle                                                         
  atomically $ sendMessage client $ Command msg                                        

receiveスレッドはクライアントのハンドルからデータを受け取ってすべてをCommandに包んでチャネルに入れるだけ。

次にserverスレッド

serverThread = join $ atomically $ do                                                  
  k <- readTVar clientKicked                                                           
  case k of                                                                            
    Just reason -> return $ hPutStrLn clientHandle $ "you have been kicked: " ++ reason
    Nothing -> do                                                                      
      msg <- readTChan clientSendChan                                                  
      return $ do                                                                      
        continue <- handleMessage server client msg                                    
        when continue $ serverThread                                                   

serverスレッドは以下を順に処理を行う。

  • kickされたか否かを確認
    • kickされている => kickされたことをそのクライアントに通知
    • kickされていない => 以下へ
  • チャネルからメッセージを読み込む(TChanなのでメッセージが存在しなければretryすることになる)
  • メッセージの内容にしたがって処理(handleMessage)
  • 返り値によってserverスレッドを終了するか判断

チャネルにデータがない場合はretryされるため、チャネルにデータが入った段階で再度頭から処理が実行される。つまり、チャネルのメッセージを処理する前に必ずkickの確認が入る。

また、全体をjoinしているのは前回作成した簡易サーバと同じ理由。前回は以下のような形に書いていた。

serverThread = do
  action <- atomically $ do ...
  action

今回もこれと同じ状況になっていてSTM内でIOが実行できない。なのでatomicallyでSTMからIOを取り出して次にそれを実行する。これを簡潔に書くとjoinを使った形になる。ちなみにjoinの型はこちら。

join :: Monad m => m (m a) -> m a

ここではmはIOとなるので、joinして一枚剥いだときにIOが実行されることになる。

最後にrunClient自体は以下のようになる。

runClient :: Server -> Client -> IO ()                                                     
runClient server@Server{..} client@Client{..} = do                                         
  race serverThread receiveThread                                                          
  return ()                                                                                

serverThreadとreceiveThreadをraceによってそれぞれ別スレッドとして実行する。raceを使っていることでどちらかが終了すればもう片方も終了させることができる。例外時も同様に処理されるためクリーンアップの処理を別に書く必要がなくてとても簡潔に書ける。

メッセージ自体の処理

serverThreadで呼ばれたhandleMessageがメッセージそのものを見て処理を分けている。

handleMessage :: Server -> Client -> Message -> IO Bool                    
handleMessage server client@Client{..} message =                           
  case message of                                                          
    Notice msg        -> output $ "*** " ++ msg                            
    Tell name msg     -> output $ "*" ++ name ++ "*: " ++ msg              
    Broadcast name msg-> output $ "<" ++ name ++ ">: " ++ msg              
    Command msg ->                                                         
      case words msg of                                                    
        ["/kick", who] -> do                                               
          atomically $ kick server who clientName                          
          return True                                                      
        "/tell" : who : what -> do                                         
          tell server client who $ unwords what                            
          return True                                                      
        ["/quit"] ->                                                       
          return False                                                     
        ('/':_):_ -> do                                                    
          hPutStrLn clientHandle $ "unrecognized command: " ++ msg         
          return True                                                      
        _ -> do                                                            
          atomically $ broadcast server $ Broadcast clientName msg         
          return True                                                      
  where                                                                    
    output s = do                                                          
      hPutStrLn clientHandle s                                             
      return True                                                          

返り値はBoolになっていてクライアントとの接続を終了したい場合はFalseを返す。

Commandはreceiveスレッドによってチャネルに追加されるメッセージで、クライアントから自分にきたメッセージがそのまま入る。なのでその文字列をチェックして以下のようにする。

  • /kick => 指定されたクライアントのステータスをkickの状態に変更
  • /tell => 指定されたクライアントのチャネルにのみメッセージを追加する(Tellに包んで)
  • /quit => 接続終了
  • /上記以外 => エラー
  • 上記以外 => すべてのクライアントのチャネルにメッセージを追加(Broadcastに包んで)

上記のように処理されて他のサーバから入ってきたメッセージをNotice, Tell, Broadcastに応じて表示を変えつつクライアントに返す。

kickとtellに関しては以下のように定義。
クライアントのマップから対応する相手のデータを読みだして、そのステータスを変更またはチャネルにデータを追加する。

sendToName :: Server -> ClientName -> Message -> STM Bool               
sendToName server@Server{..} name msg = do                              
  clientmap <- readTVar clients                                         
  case M.lookup name clientmap of                                       
    Nothing -> return False                                             
    Just client -> sendMessage client msg >> return True                
                                                                        
kick :: Server -> ClientName -> ClientName -> STM ()                    
kick server@Server{..} who by = do                                      
  clientmap <- readTVar clients                                         
  case M.lookup who clientmap of                                        
    Nothing ->                                                          
      void $ sendToName server by $ Notice (who ++ " is not connected") 
    Just Client{..} -> do                                               
      writeTVar clientKicked $ Just $ "by " ++ by                       
      void $ sendToName server by $ Notice ("you kicked " ++ who)       
                                                                        
tell :: Server -> Client -> ClientName -> String -> IO ()               
tell server@Server{..} Client{..} who s = do                            
  ok <- atomically $ sendToName server who $ Tell clientName s          
  if ok                                                                 
    then return ()                                                      
    else hPutStrLn clientHandle $ who ++ " is not connected"            

Haskellによる並列・並行プログラミング

Haskellによる並列・並行プログラミング