source: ralphm-patches/s2s.patch @ 9:7f5cf72c97fc

Last change on this file since 9:7f5cf72c97fc was 9:7f5cf72c97fc, checked in by Ralph Meijer <ralphm@…>, 11 years ago

Save recent work.

File size: 39.4 KB
  • new file wokkel/server.py

    diff -r 313d45b505a7 wokkel/server.py
    - +  
     1# -*- test-case-name: wokkel.test.test_server -*-
     2#
     3# Copyright (c) 2003-2008 Ralph Meijer
     4# See LICENSE for details.
     5
     6"""
     7XMPP Server-to-Server protocol.
     8
     9This module implements several aspects of XMPP server-to-server communications
     10as described in XMPP Core (RFC 3920). Refer to that document for the meaning
     11of the used terminology.
     12"""
     13
     14# hashlib is new in Python 2.5, try that first.
     15try:
     16    from hashlib import sha256
     17except ImportError:
     18    from Crypto.Hash.SHA256 import new as sha256
     19
     20import hmac
     21
     22from zope.interface import implements
     23
     24from twisted.application import service
     25from twisted.internet import defer, reactor
     26from twisted.names.srvconnect import SRVConnector
     27from twisted.python import log
     28from twisted.words.protocols.jabber import error, ijabber, jid, xmlstream
     29from twisted.words.xish import domish
     30
     31from wokkel.generic import DeferredXmlStreamFactory, XmlPipe
     32from wokkel.compat import XmlStreamServerFactory
     33
     34NS_DIALBACK = 'jabber:server:dialback'
     35
     36def generateKey(secret, receivingServer, originatingServer, streamID):
     37    """
     38    Generate a dialback key for server-to-server XMPP Streams.
     39
     40    The dialback key is generated using the algorithm described in
     41    U{XEP-0185<http://www.xmpp.org/extensions/xep-0185.html>}. The used
     42    terminology for the parameters is described in RFC-3920.
     43
     44    @param secret: the shared secret known to the Originating Server and
     45                   Authoritive Server.
     46    @type secret: C{str}
     47    @param receivingServer: the Receiving Server host name.
     48    @type receivingServer: C{str}
     49    @param originatingServer: the Originating Server host name.
     50    @type originatingServer: C{str}
     51    @param streamID: the Stream ID as generated by the Receiving Server.
     52    @type streamID: C{str}
     53    @return: hexadecimal digest of the generated key.
     54    @type: C{str}
     55    """
     56
     57    hashObject = sha256()
     58    hashObject.update(secret)
     59    hashedSecret = hashObject.hexdigest()
     60    message = " ".join([receivingServer, originatingServer, streamID])
     61    hash = hmac.HMAC(hashedSecret, message, digestmod=sha256)
     62    return hash.hexdigest()
     63
     64
     65def trapStreamError(xs, observer):
     66    """
     67    Trap stream errors.
     68
     69    This wraps an observer to catch exceptions. In case of a
     70    L{error.StreamError}, it is send over the given XML stream. All other
     71    exceptions yield a C{'internal-server-error'} stream error, that is
     72    sent over the stream, while the exception is logged.
     73
     74    @return: Wrapped observer
     75    """
     76
     77    def wrappedObserver(element):
     78        try:
     79            observer(element)
     80        except error.StreamError, exc:
     81            xs.sendStreamError(exc)
     82        except:
     83            log.err()
     84            exc = error.StreamError('internal-server-error')
     85            xs.sendStreamError(exc)
     86
     87    return wrappedObserver
     88
     89
     90class XMPPServerConnector(SRVConnector):
     91    def __init__(self, reactor, domain, factory):
     92        SRVConnector.__init__(self, reactor, 'xmpp-server', domain, factory)
     93
     94
     95    def pickServer(self):
     96        host, port = SRVConnector.pickServer(self)
     97
     98        if not self.servers and not self.orderedServers:
     99            # no SRV record, fall back..
     100            port = 5269
     101
     102        return host, port
     103
     104
     105class DialbackFailed(Exception):
     106    pass
     107
     108
     109
     110class OriginatingDialbackInitializer(object):
     111    """
     112    Server Dialback Initializer for the Orginating Server.
     113    """
     114
     115    implements(ijabber.IInitiatingInitializer)
     116
     117    _deferred = None
     118
     119    def __init__(self, xs, thisHost, otherHost, secret):
     120        self.xmlstream = xs
     121        self.thisHost = thisHost
     122        self.otherHost = otherHost
     123        self.secret = secret
     124
     125
     126    def initialize(self):
     127        self._deferred = defer.Deferred()
     128        self.xmlstream.addObserver(xmlstream.STREAM_ERROR_EVENT,
     129                                   self.onStreamError)
     130        self.xmlstream.addObserver("/result[@xmlns='%s']" % NS_DIALBACK,
     131                                   self.onResult)
     132
     133        key = generateKey(self.secret, self.otherHost,
     134                          self.thisHost, self.xmlstream.sid)
     135
     136        result = domish.Element((NS_DIALBACK, 'result'))
     137        result['from'] = self.thisHost
     138        result['to'] = self.otherHost
     139        result.addContent(key)
     140
     141        self.xmlstream.send(result)
     142
     143        return self._deferred
     144
     145
     146    def onResult(self, result):
     147        self.xmlstream.removeObserver(xmlstream.STREAM_ERROR_EVENT,
     148                                      self.onStreamError)
     149        if result['type'] == 'valid':
     150            self.xmlstream.otherEntity = jid.internJID(self.otherHost)
     151            self._deferred.callback(None)
     152        else:
     153            self._deferred.errback(DialbackFailed())
     154
     155
     156    def onStreamError(self, failure):
     157        self.xmlstream.removeObserver("/result[@xmlns='%s']" % NS_DIALBACK,
     158                                      self.onResult)
     159        self._deferred.errback(failure)
     160
     161
     162
     163class ReceivingDialbackInitializer(object):
     164    """
     165    Server Dialback Initializer for the Receiving Server.
     166    """
     167
     168    implements(ijabber.IInitiatingInitializer)
     169
     170    _deferred = None
     171
     172    def __init__(self, xs, thisHost, otherHost, originalStreamID, key):
     173        self.xmlstream = xs
     174        self.thisHost = thisHost
     175        self.otherHost = otherHost
     176        self.originalStreamID = originalStreamID
     177        self.key = key
     178
     179
     180    def initialize(self):
     181        self._deferred = defer.Deferred()
     182        self.xmlstream.addObserver(xmlstream.STREAM_ERROR_EVENT,
     183                                   self.onStreamError)
     184        self.xmlstream.addObserver("/verify[@xmlns='%s']" % NS_DIALBACK,
     185                                   self.onVerify)
     186
     187        verify = domish.Element((NS_DIALBACK, 'verify'))
     188        verify['from'] = self.thisHost
     189        verify['to'] = self.otherHost
     190        verify['id'] = self.originalStreamID
     191        verify.addContent(self.key)
     192
     193        self.xmlstream.send(verify)
     194        return self._deferred
     195
     196
     197    def onVerify(self, verify):
     198        self.xmlstream.removeObserver(xmlstream.STREAM_ERROR_EVENT,
     199                                      self.onStreamError)
     200        if verify['id'] != self.originalStreamID:
     201            self.xmlstream.sendStreamError(error.StreamError('invalid-id'))
     202            self._deferred.errback(DialbackFailed())
     203        elif verify['to'] != self.thisHost:
     204            self.xmlstream.sendStreamError(error.StreamError('host-unknown'))
     205            self._deferred.errback(DialbackFailed())
     206        elif verify['from'] != self.otherHost:
     207            self.xmlstream.sendStreamError(error.StreamError('invalid-from'))
     208            self._deferred.errback(DialbackFailed())
     209        elif verify['type'] == 'valid':
     210            self._deferred.callback(None)
     211        else:
     212            self._deferred.errback(DialbackFailed())
     213
     214
     215    def onStreamError(self, failure):
     216        self.xmlstream.removeObserver("/verify[@xmlns='%s']" % NS_DIALBACK,
     217                                      self.onVerify)
     218        self._deferred.errback(failure)
     219
     220
     221
     222class XMPPServerConnectAuthenticator(xmlstream.ConnectAuthenticator):
     223    """
     224    Authenticator for an outgoing XMPP server-to-server connection.
     225
     226    This authenticator connects to C{otherHost} (the Receiving Server) and then
     227    initiates dialback as C{thisHost} (the Originating Server) using
     228    L{OriginatingDialbackInitializer}.
     229
     230    @ivar thisHost: The domain this server connects from (the Originating
     231                    Server) .
     232    @ivar otherHost: The domain of the server this server connects to (the
     233                     Receiving Server).
     234    @ivar secret: The shared secret that is used for verifying the validity
     235                  of this new connection.
     236    """
     237    namespace = 'jabber:server'
     238
     239    def __init__(self, thisHost, otherHost, secret):
     240        self.thisHost = thisHost
     241        self.otherHost = otherHost
     242        self.secret = secret
     243        xmlstream.ConnectAuthenticator.__init__(self, otherHost)
     244
     245
     246    def connectionMade(self):
     247        self.xmlstream.thisEntity = jid.internJID(self.thisHost)
     248        self.xmlstream.prefixes = {xmlstream.NS_STREAMS: 'stream',
     249                                   NS_DIALBACK: 'db'}
     250        xmlstream.ConnectAuthenticator.connectionMade(self)
     251
     252
     253    def associateWithStream(self, xs):
     254        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
     255        init = OriginatingDialbackInitializer(xs, self.thisHost,
     256                                              self.otherHost, self.secret)
     257        xs.initializers = [init]
     258
     259
     260
     261class XMPPServerVerifyAuthenticator(xmlstream.ConnectAuthenticator):
     262    """
     263    Authenticator for an outgoing connection to verify an incoming connection.
     264
     265    This authenticator connects to C{otherHost} (the Authoritative Server) and
     266    then initiates dialback as C{thisHost} (the Receiving Server) using
     267    L{ReceivingDialbackInitializer}.
     268
     269    @ivar thisHost: The domain this server connects from (the Receiving
     270                    Server) .
     271    @ivar otherHost: The domain of the server this server connects to (the
     272                     Authoritative Server).
     273    @ivar originalStreamID: The stream ID of the incoming connection that is
     274                            being verified.
     275    @ivar key: The key provided by the Receving Server to be verified.
     276    """
     277    namespace = 'jabber:server'
     278
     279    def __init__(self, thisHost, otherHost, originalStreamID, key):
     280        self.thisHost = thisHost
     281        self.otherHost = otherHost
     282        self.originalStreamID = originalStreamID
     283        self.key = key
     284        xmlstream.ConnectAuthenticator.__init__(self, otherHost)
     285
     286
     287    def connectionMade(self):
     288        self.xmlstream.thisEntity = jid.internJID(self.thisHost)
     289        self.xmlstream.prefixes = {xmlstream.NS_STREAMS: 'stream',
     290                                   NS_DIALBACK: 'db'}
     291        xmlstream.ConnectAuthenticator.connectionMade(self)
     292
     293
     294    def associateWithStream(self, xs):
     295        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
     296        init = ReceivingDialbackInitializer(xs, self.thisHost, self.otherHost,
     297                                            self.originalStreamID, self.key)
     298        xs.initializers = [init]
     299
     300
     301
     302class XMPPServerListenAuthenticator(xmlstream.ListenAuthenticator):
     303    """
     304    Authenticator for an incoming XMPP server-to-server connection.
     305
     306    This authenticator handles two types of incoming connections. Regular
     307    server-to-server connections are from the Originating Server to the
     308    Receiving Server, where this server is the Receiving Server. These
     309    connections start out by receiving a dialback key, verifying the
     310    key with the Authoritative Server, and then accept normal XMPP stanzas.
     311
     312    The other type of connections is from a Receiving Server to an
     313    Authoritative Server, where this server acts as the Authoritative Server.
     314    These connections are used to verify the validity of an outgoing connection
     315    from this server. In this case, this server receives a verification
     316    request, checks the key and then returns the result.
     317
     318    @ivar service: The service that keeps the list of domains we accept
     319                   connections for.
     320    """
     321    namespace = 'jabber:server'
     322
     323    def __init__(self, service):
     324        xmlstream.ListenAuthenticator.__init__(self)
     325        self.service = service
     326
     327
     328    def streamStarted(self, rootElement):
     329        xmlstream.ListenAuthenticator.streamStarted(self, rootElement)
     330
     331        # Compatibility fix for pre-8.2 implementations of ListenAuthenticator
     332        if not self.xmlstream.sid:
     333            from twisted.python import randbytes
     334            self.xmlstream.sid = randbytes.secureRandom(8).encode('hex')
     335
     336        if self.xmlstream.thisEntity:
     337            targetDomain = self.xmlstream.thisEntity.host
     338        else:
     339            targetDomain = self.service.defaultDomain
     340
     341        def prepareStream(domain):
     342            self.xmlstream.namespace = self.namespace
     343            self.xmlstream.prefixes = {xmlstream.NS_STREAMS: 'stream',
     344                                       NS_DIALBACK: 'db'}
     345            self.xmlstream.thisEntity = jid.internJID(domain)
     346
     347        try:
     348            if xmlstream.NS_STREAMS != rootElement.uri or \
     349               self.namespace != self.xmlstream.namespace or \
     350               ('db', NS_DIALBACK) not in rootElement.localPrefixes.iteritems():
     351                raise error.StreamError('invalid-namespace')
     352
     353            if targetDomain not in self.service.domains:
     354                raise error.StreamError('host-unknown')
     355        except error.StreamError, exc:
     356            prepareStream(self.service.defaultDomain)
     357            self.xmlstream.sendStreamError(exc)
     358            return
     359
     360        self.xmlstream.addObserver("//verify[@xmlns='%s']" % NS_DIALBACK,
     361                                   trapStreamError(self.xmlstream,
     362                                                   self.onVerify))
     363        self.xmlstream.addObserver("//result[@xmlns='%s']" % NS_DIALBACK,
     364                                   self.onResult)
     365
     366        prepareStream(targetDomain)
     367        self.xmlstream.sendHeader()
     368
     369        if self.xmlstream.version >= (1, 0):
     370            features = domish.Element((xmlstream.NS_STREAMS, 'features'))
     371            self.xmlstream.send(features)
     372
     373
     374    def onVerify(self, verify):
     375        try:
     376            receivingServer = jid.JID(verify['from']).host
     377            originatingServer = jid.JID(verify['to']).host
     378        except (KeyError, jid.InvalidFormat):
     379            raise error.StreamError('improper-addressing')
     380
     381        if originatingServer not in self.service.domains:
     382            raise error.StreamError('host-unknown')
     383
     384        if (self.xmlstream.otherEntity and
     385            receivingServer != self.xmlstream.otherEntity.host):
     386            raise error.StreamError('invalid-from')
     387
     388        streamID = verify.getAttribute('id', '')
     389        key = unicode(verify)
     390
     391        calculatedKey = generateKey(self.service.secret, receivingServer,
     392                                    originatingServer, streamID)
     393        validity = (key == calculatedKey) and 'valid' or 'invalid'
     394
     395        reply = domish.Element((NS_DIALBACK, 'verify'))
     396        reply['from'] = originatingServer
     397        reply['to'] = receivingServer
     398        reply['id'] = streamID
     399        reply['type'] = validity
     400        self.xmlstream.send(reply)
     401
     402
     403    def onResult(self, result):
     404        def reply(validity):
     405            reply = domish.Element((NS_DIALBACK, 'result'))
     406            reply['from'] = result['to']
     407            reply['to'] = result['from']
     408            reply['type'] = validity
     409            self.xmlstream.send(reply)
     410
     411        def valid(xs):
     412            reply('valid')
     413            self.xmlstream.otherEntity = jid.internJID(originatingServer)
     414            self.xmlstream.dispatch(self.xmlstream,
     415                                    xmlstream.STREAM_AUTHD_EVENT)
     416
     417        def invalid(failure):
     418            log.err(failure)
     419            reply('invalid')
     420
     421        receivingServer = result['to']
     422        originatingServer = result['from']
     423        key = unicode(result)
     424
     425        d = self.service.validateConnection(receivingServer, originatingServer,
     426                                            self.xmlstream.sid, key)
     427        d.addCallbacks(valid, invalid)
     428        return d
     429
     430
     431
     432class DeferredS2SClientFactory(DeferredXmlStreamFactory):
     433    """
     434    Deferred firing factory for initiating XMPP server-to-server connection.
     435
     436    The deferred has its callbacks called upon succesful authentication with
     437    the other server. In case of failed authentication or connection, the
     438    deferred will have its errbacks called instead.
     439    """
     440
     441    logTraffic = False
     442
     443    def __init__(self, authenticator):
     444        DeferredXmlStreamFactory.__init__(self, authenticator)
     445
     446        self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT,
     447                          self.onConnectionMade)
     448
     449        self.serial = 0
     450
     451
     452    def onConnectionMade(self, xs):
     453        xs.serial = self.serial
     454        self.serial += 1
     455
     456        def logDataIn(buf):
     457            log.msg("RECV (%d): %r" % (xs.serial, buf))
     458
     459        def logDataOut(buf):
     460            log.msg("SEND (%d): %r" % (xs.serial, buf))
     461
     462        if self.logTraffic:
     463            xs.rawDataInFn = logDataIn
     464            xs.rawDataOutFn = logDataOut
     465
     466
     467
     468def initiateS2S(factory):
     469    domain = factory.authenticator.otherHost
     470    c = XMPPServerConnector(reactor, domain, factory)
     471    c.connect()
     472    return factory.deferred
     473
     474
     475
     476class XMPPS2SServerFactory(XmlStreamServerFactory):
     477    """
     478    XMPP Server-to-Server Server factory.
     479
     480    This factory accepts XMPP server-to-server connections.
     481    """
     482
     483    logTraffic = False
     484
     485    def __init__(self, service):
     486        self.service = service
     487
     488        def authenticatorFactory():
     489            return XMPPServerListenAuthenticator(service)
     490
     491        XmlStreamServerFactory.__init__(self, authenticatorFactory)
     492        self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT,
     493                          self.onConnectionMade)
     494        self.addBootstrap(xmlstream.STREAM_AUTHD_EVENT,
     495                          self.onAuthenticated)
     496
     497        self.serial = 0
     498
     499
     500    def onConnectionMade(self, xs):
     501        """
     502        Called when a server-to-server connection was made.
     503
     504        This enables traffic debugging on incoming streams.
     505        """
     506        xs.serial = self.serial
     507        self.serial += 1
     508
     509        def logDataIn(buf):
     510            log.msg("RECV (%d): %r" % (xs.serial, buf))
     511
     512        def logDataOut(buf):
     513            log.msg("SEND (%d): %r" % (xs.serial, buf))
     514
     515        if self.logTraffic:
     516            xs.rawDataInFn = logDataIn
     517            xs.rawDataOutFn = logDataOut
     518
     519        xs.addObserver(xmlstream.STREAM_ERROR_EVENT, self.onError)
     520
     521
     522    def onAuthenticated(self, xs):
     523        thisHost = xs.thisEntity.host
     524        otherHost = xs.otherEntity.host
     525
     526        log.msg("Incoming connection %d from %r to %r established" %
     527                (xs.serial, otherHost, thisHost))
     528
     529        xs.addObserver(xmlstream.STREAM_END_EVENT, self.onConnectionLost,
     530                                                   0, xs)
     531        xs.addObserver('/*', self.onElement, 0, xs)
     532
     533
     534    def onConnectionLost(self, xs, reason):
     535        thisHost = xs.thisEntity.host
     536        otherHost = xs.otherEntity.host
     537
     538        log.msg("Incoming connection %d from %r to %r disconnected" %
     539                (xs.serial, otherHost, thisHost))
     540
     541
     542    def onError(self, reason):
     543        log.err(reason, "Stream Error")
     544
     545
     546    def onElement(self, xs, element):
     547        """
     548        Called when an element was received from one of the connected streams.
     549
     550        """
     551        if element.handled:
     552            return
     553        else:
     554            self.service.dispatch(xs, element)
     555
     556
     557
     558class ServerService(object):
     559    """
     560    Service for managing XMPP server to server connections.
     561    """
     562
     563    logTraffic = False
     564
     565    def __init__(self, router, secret, domain):
     566        self.router = router
     567        self.secret = secret
     568        self.defaultDomain = domain
     569        self.domains = set([domain])
     570
     571        self._outgoingStreams = {}
     572        self._outgoingQueues = {}
     573        self._outgoingConnecting = set()
     574        self.serial = 0
     575
     576        pipe = XmlPipe()
     577        self.xmlstream = pipe.source
     578        self.router.addRoute(None, pipe.sink)
     579        self.xmlstream.addObserver('/*', self.send)
     580
     581
     582    def outgoingInitialized(self, xs):
     583        thisHost = xs.thisEntity.host
     584        otherHost = xs.otherEntity.host
     585
     586        log.msg("Outgoing connection %d from %r to %r established" %
     587                (xs.serial, thisHost, otherHost))
     588
     589        self._outgoingStreams[thisHost, otherHost] = xs
     590        xs.addObserver(xmlstream.STREAM_END_EVENT,
     591                       lambda _: self.outgoingDisconnected(xs))
     592
     593        if (thisHost, otherHost) in self._outgoingQueues:
     594            for element in self._outgoingQueues[thisHost, otherHost]:
     595                xs.send(element)
     596            del self._outgoingQueues[thisHost, otherHost]
     597
     598
     599    def outgoingDisconnected(self, xs):
     600        thisHost = xs.thisEntity.host
     601        otherHost = xs.otherEntity.host
     602
     603        log.msg("Outgoing connection %d from %r to %r disconnected" %
     604                (xs.serial, thisHost, otherHost))
     605
     606        del self._outgoingStreams[thisHost, otherHost]
     607
     608
     609    def initiateOutgoingStream(self, thisHost, otherHost):
     610        """
     611        Initiate an outgoing XMPP server-to-server connection.
     612        """
     613
     614        def resetConnecting(_):
     615            self._outgoingConnecting.remove((thisHost, otherHost))
     616
     617        if (thisHost, otherHost) in self._outgoingConnecting:
     618            return
     619
     620        authenticator = XMPPServerConnectAuthenticator(thisHost,
     621                                                       otherHost,
     622                                                       self.secret)
     623        factory = DeferredS2SClientFactory(authenticator)
     624        factory.addBootstrap(xmlstream.STREAM_AUTHD_EVENT,
     625                             self.outgoingInitialized)
     626        factory.logTraffic = self.logTraffic
     627
     628        self._outgoingConnecting.add((thisHost, otherHost))
     629
     630        d = initiateS2S(factory)
     631        d.addBoth(resetConnecting)
     632        return d
     633
     634
     635    def validateConnection(self, thisHost, otherHost, sid, key):
     636        """
     637        Validate an incoming XMPP server-to-server connection.
     638        """
     639
     640        def connected(xs):
     641            # Set up stream for immediate disconnection.
     642            def disconnect(_):
     643                xs.transport.loseConnection()
     644            xs.addObserver(xmlstream.STREAM_AUTHD_EVENT, disconnect)
     645            xs.addObserver(xmlstream.INIT_FAILED_EVENT, disconnect)
     646
     647        authenticator = XMPPServerVerifyAuthenticator(thisHost, otherHost,
     648                                                      sid, key)
     649        factory = DeferredS2SClientFactory(authenticator)
     650        factory.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, connected)
     651        factory.logTraffic = self.logTraffic
     652
     653        d = initiateS2S(factory)
     654        return d
     655
     656
     657    def send(self, stanza):
     658        """
     659        Send stanza to the proper XML Stream.
     660
     661        This uses addressing embedded in the stanza to find the correct stream
     662        to forward the stanza to.
     663        """
     664
     665        otherHost = jid.internJID(stanza["to"]).host
     666        thisHost = jid.internJID(stanza["from"]).host
     667
     668        if (thisHost, otherHost) not in self._outgoingStreams:
     669            # There is no connection with the destination (yet). Cache the
     670            # outgoing stanza until the connection has been established.
     671            # XXX: If the connection cannot be established, the queue should
     672            #      be emptied at some point.
     673            if (thisHost, otherHost) not in self._outgoingQueues:
     674                self._outgoingQueues[(thisHost, otherHost)] = []
     675            self._outgoingQueues[(thisHost, otherHost)].append(stanza)
     676            self.initiateOutgoingStream(thisHost, otherHost)
     677        else:
     678            self._outgoingStreams[(thisHost, otherHost)].send(stanza)
     679
     680
     681    def dispatch(self, xs, stanza):
     682        """
     683        Send on element to be routed within the server.
     684        """
     685        stanzaFrom = stanza.getAttribute('from')
     686        stanzaTo = stanza.getAttribute('to')
     687
     688        if not stanzaFrom or not stanzaTo:
     689            xs.sendStreamError(error.StreamError('improper-addressing'))
     690        else:
     691            try:
     692                sender = jid.internJID(stanzaFrom)
     693                recipient = jid.internJID(stanzaTo)
     694            except jid.InvalidFormat:
     695                log.msg("Dropping error stanza with malformed JID")
     696
     697            if sender.host != xs.otherEntity.host:
     698                xs.sendStreamError(error.StreamError('invalid-from'))
     699            else:
     700                self.xmlstream.send(stanza)
  • new file wokkel/test/test_server.py

    diff -r 313d45b505a7 wokkel/test/test_server.py
    - +  
     1# Copyright (c) 2003-2008 Ralph Meijer
     2# See LICENSE for details.
     3
     4"""
     5Tests for L{wokkel.server}.
     6"""
     7
     8from twisted.internet import defer
     9from twisted.python import failure
     10from twisted.test.proto_helpers import StringTransport
     11from twisted.trial import unittest
     12from twisted.words.protocols.jabber import error, jid, xmlstream
     13from twisted.words.xish import domish
     14
     15from wokkel import component, server
     16
     17NS_STREAMS = 'http://etherx.jabber.org/streams'
     18NS_DIALBACK = "jabber:server:dialback"
     19
     20class GenerateKeyTest(unittest.TestCase):
     21    """
     22    Tests for L{server.generateKey}.
     23    """
     24
     25    def testBasic(self):
     26        originating = "example.org"
     27        receiving = "xmpp.example.com"
     28        sid = "D60000229F"
     29        secret = "s3cr3tf0rd14lb4ck"
     30
     31        key = server.generateKey(secret, receiving, originating, sid)
     32
     33        self.assertEqual(key,
     34            '37c69b1cf07a3f67c04a5ef5902fa5114f2c76fe4a2686482ba5b89323075643')
     35
     36
     37
     38class XMPPServerListenAuthenticatorTest(unittest.TestCase):
     39    """
     40    Tests for L{server.XMPPServerListenAuthenticator}.
     41    """
     42
     43    secret = "s3cr3tf0rd14lb4ck"
     44    originating = "example.org"
     45    receiving = "xmpp.example.com"
     46    sid = "D60000229F"
     47    key = '37c69b1cf07a3f67c04a5ef5902fa5114f2c76fe4a2686482ba5b89323075643'
     48
     49    def setUp(self):
     50        self.output = []
     51
     52        class MyService(object):
     53            pass
     54
     55        self.service = MyService()
     56        self.service.defaultDomain = self.receiving
     57        self.service.domains = [self.receiving, 'pubsub.'+self.receiving]
     58        self.service.secret = self.secret
     59
     60        self.authenticator = server.XMPPServerListenAuthenticator(self.service)
     61        self.xmlstream = xmlstream.XmlStream(self.authenticator)
     62        self.xmlstream.send = self.output.append
     63        self.xmlstream.transport = StringTransport()
     64
     65
     66    def test_attributes(self):
     67        """
     68        Test attributes of authenticator and stream objects.
     69        """
     70        self.assertEqual(self.service, self.authenticator.service)
     71        self.assertEqual(self.xmlstream.initiating, False)
     72
     73
     74    def test_streamStartedVersion0(self):
     75        """
     76        The authenticator supports pre-XMPP 1.0 streams.
     77        """
     78        self.xmlstream.connectionMade()
     79        self.xmlstream.dataReceived(
     80            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     81                           "xmlns:db='jabber:server:dialback' "
     82                           "xmlns='jabber:server' "
     83                           "to='xmpp.example.com'>")
     84        self.assertEqual((0, 0), self.xmlstream.version)
     85
     86
     87    def test_streamStartedVersion1(self):
     88        """
     89        The authenticator supports XMPP 1.0 streams.
     90        """
     91        self.xmlstream.connectionMade()
     92        self.xmlstream.dataReceived(
     93            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     94                           "xmlns:db='jabber:server:dialback' "
     95                           "xmlns='jabber:server' "
     96                           "to='xmpp.example.com' "
     97                           "version='1.0'>")
     98        self.assertEqual((1, 0), self.xmlstream.version)
     99
     100
     101    def test_streamStartedSID(self):
     102        """
     103        The response stream will have a stream ID.
     104        """
     105        self.xmlstream.connectionMade()
     106        self.assertIdentical(None, self.xmlstream.sid)
     107
     108        self.xmlstream.dataReceived(
     109            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     110                           "xmlns:db='jabber:server:dialback' "
     111                           "xmlns='jabber:server' "
     112                           "to='xmpp.example.com' "
     113                           "version='1.0'>")
     114        self.assertNotIdentical(None, self.xmlstream.sid)
     115
     116
     117    def test_streamStartedSentResponseHeader(self):
     118        """
     119        A stream header is sent in response to the incoming stream header.
     120        """
     121        self.xmlstream.connectionMade()
     122        self.assertFalse(self.xmlstream._headerSent)
     123
     124        self.xmlstream.dataReceived(
     125            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     126                           "xmlns:db='jabber:server:dialback' "
     127                           "xmlns='jabber:server' "
     128                           "to='xmpp.example.com'>")
     129        self.assertTrue(self.xmlstream._headerSent)
     130
     131
     132    def test_streamStartedNotSentFeatures(self):
     133        """
     134        No features are sent in response to an XMPP < 1.0 stream header.
     135        """
     136        self.xmlstream.connectionMade()
     137        self.xmlstream.dataReceived(
     138            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     139                           "xmlns:db='jabber:server:dialback' "
     140                           "xmlns='jabber:server' "
     141                           "to='xmpp.example.com'>")
     142        self.assertEqual(1, len(self.output))
     143
     144
     145    def test_streamStartedSentFeatures(self):
     146        """
     147        Features are sent in response to an XMPP >= 1.0 stream header.
     148        """
     149        self.xmlstream.connectionMade()
     150        self.xmlstream.dataReceived(
     151            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     152                           "xmlns:db='jabber:server:dialback' "
     153                           "xmlns='jabber:server' "
     154                           "to='xmpp.example.com' "
     155                           "version='1.0'>")
     156        self.assertEqual(2, len(self.output))
     157        features = self.output[-1]
     158        self.assertEqual(NS_STREAMS, features.uri)
     159        self.assertEqual('features', features.name)
     160
     161
     162    def test_streamRootElement(self):
     163        """
     164        Test stream error on wrong stream namespace.
     165        """
     166        self.xmlstream.connectionMade()
     167        self.xmlstream.dataReceived(
     168            "<stream:stream xmlns:stream='badns' "
     169                           "xmlns:db='jabber:server:dialback' "
     170                           "xmlns='jabber:server' "
     171                           "to='xmpp.example.com'>")
     172
     173        self.assertEqual(3, len(self.output))
     174        exc = error.exceptionFromStreamError(self.output[1])
     175        self.assertEqual('invalid-namespace', exc.condition)
     176
     177
     178    def test_streamDefaultNamespace(self):
     179        """
     180        Test stream error on missing dialback namespace.
     181        """
     182        self.xmlstream.connectionMade()
     183        self.xmlstream.dataReceived(
     184            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     185                           "xmlns:db='jabber:server:dialback' "
     186                           "xmlns='badns' "
     187                           "to='xmpp.example.com'>")
     188
     189        self.assertEqual(3, len(self.output))
     190        exc = error.exceptionFromStreamError(self.output[1])
     191        self.assertEqual('invalid-namespace', exc.condition)
     192
     193
     194    def test_streamNoDialbackNamespace(self):
     195        """
     196        Test stream error on missing dialback namespace.
     197        """
     198        self.xmlstream.connectionMade()
     199        self.xmlstream.dataReceived(
     200            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     201                           "xmlns='jabber:server' "
     202                           "to='xmpp.example.com'>")
     203
     204        self.assertEqual(3, len(self.output))
     205        exc = error.exceptionFromStreamError(self.output[1])
     206        self.assertEqual('invalid-namespace', exc.condition)
     207
     208
     209    def test_streamBadDialbackNamespace(self):
     210        """
     211        Test stream error on missing dialback namespace.
     212        """
     213        self.xmlstream.connectionMade()
     214        self.xmlstream.dataReceived(
     215            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     216                           "xmlns:db='badns' "
     217                           "xmlns='jabber:server' "
     218                           "to='xmpp.example.com'>")
     219
     220        self.assertEqual(3, len(self.output))
     221        exc = error.exceptionFromStreamError(self.output[1])
     222        self.assertEqual('invalid-namespace', exc.condition)
     223
     224
     225    def test_streamToUnknownHost(self):
     226        """
     227        Test stream error on stream's to attribute having unknown host.
     228        """
     229        self.xmlstream.connectionMade()
     230        self.xmlstream.dataReceived(
     231            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     232                           "xmlns:db='jabber:server:dialback' "
     233                           "xmlns='jabber:server' "
     234                           "to='badhost'>")
     235
     236        self.assertEqual(3, len(self.output))
     237        exc = error.exceptionFromStreamError(self.output[1])
     238        self.assertEqual('host-unknown', exc.condition)
     239
     240
     241    def test_streamToOtherLocalHost(self):
     242        """
     243        The authenticator supports XMPP 1.0 streams.
     244        """
     245        self.xmlstream.connectionMade()
     246        self.xmlstream.dataReceived(
     247            "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
     248                           "xmlns:db='jabber:server:dialback' "
     249                           "xmlns='jabber:server' "
     250                           "to='pubsub.xmpp.example.com' "
     251                           "version='1.0'>")
     252
     253        self.assertEqual(2, len(self.output))
     254        self.assertEqual(jid.JID('pubsub.xmpp.example.com'),
     255                         self.xmlstream.thisEntity)
     256
     257    def test_onResult(self):
     258        def cb(result):
     259            self.assertEqual(1, len(self.output))
     260            reply = self.output[0]
     261            self.assertEqual(self.originating, reply['to'])
     262            self.assertEqual(self.receiving, reply['from'])
     263            self.assertEqual('valid', reply['type'])
     264
     265        def validateConnection(thisHost, otherHost, sid, key):
     266            self.assertEqual(thisHost, self.receiving)
     267            self.assertEqual(otherHost, self.originating)
     268            self.assertEqual(sid, self.sid)
     269            self.assertEqual(key, self.key)
     270            return defer.succeed(None)
     271
     272        self.xmlstream.sid = self.sid
     273        self.service.validateConnection = validateConnection
     274
     275        result = domish.Element((NS_DIALBACK, 'result'))
     276        result['to'] = self.receiving
     277        result['from'] = self.originating
     278        result.addContent(self.key)
     279
     280        d = self.authenticator.onResult(result)
     281        d.addCallback(cb)
     282        return d
     283
     284
     285    def test_onResultFailure(self):
     286        class TestError(Exception):
     287            pass
     288
     289        def cb(result):
     290            reply = self.output[0]
     291            self.assertEqual('invalid', reply['type'])
     292            self.assertEqual(1, len(self.flushLoggedErrors(TestError)))
     293
     294
     295        def validateConnection(thisHost, otherHost, sid, key):
     296            return defer.fail(TestError())
     297
     298        self.xmlstream.sid = self.sid
     299        self.service.validateConnection = validateConnection
     300
     301        result = domish.Element((NS_DIALBACK, 'result'))
     302        result['to'] = self.receiving
     303        result['from'] = self.originating
     304        result.addContent(self.key)
     305
     306        d = self.authenticator.onResult(result)
     307        d.addCallback(cb)
     308        return d
     309
     310
     311
     312class FakeService(object):
     313    domains = set(['example.org', 'pubsub.example.org'])
     314    defaultDomain = 'example.org'
     315    secret = 'mysecret'
     316
     317    def __init__(self):
     318        self.dispatched = []
     319
     320    def dispatch(self, xs, element):
     321        self.dispatched.append(element)
     322
     323
     324
     325class XMPPS2SServerFactoryTest(unittest.TestCase):
     326    """
     327    Tests for L{component.XMPPS2SServerFactory}.
     328    """
     329
     330    def setUp(self):
     331        self.service = FakeService()
     332        self.factory = server.XMPPS2SServerFactory(self.service)
     333        self.xmlstream = self.factory.buildProtocol(None)
     334        self.transport = StringTransport()
     335        self.xmlstream.thisEntity = jid.JID('example.org')
     336        self.xmlstream.otherEntity = jid.JID('example.com')
     337
     338
     339    def test_makeConnection(self):
     340        """
     341        A new connection increases the stream serial count. No logs by default.
     342        """
     343        self.xmlstream.makeConnection(self.transport)
     344        self.assertEqual(0, self.xmlstream.serial)
     345        self.assertEqual(1, self.factory.serial)
     346        self.assertIdentical(None, self.xmlstream.rawDataInFn)
     347        self.assertIdentical(None, self.xmlstream.rawDataOutFn)
     348
     349
     350    def test_makeConnectionLogTraffic(self):
     351        """
     352        Setting logTraffic should set up raw data loggers.
     353        """
     354        self.factory.logTraffic = True
     355        self.xmlstream.makeConnection(self.transport)
     356        self.assertNotIdentical(None, self.xmlstream.rawDataInFn)
     357        self.assertNotIdentical(None, self.xmlstream.rawDataOutFn)
     358
     359
     360    def test_onError(self):
     361        """
     362        An observer for stream errors should trigger onError to log it.
     363        """
     364        self.xmlstream.makeConnection(self.transport)
     365
     366        class TestError(Exception):
     367            pass
     368
     369        reason = failure.Failure(TestError())
     370        self.xmlstream.dispatch(reason, xmlstream.STREAM_ERROR_EVENT)
     371        self.assertEqual(1, len(self.flushLoggedErrors(TestError)))
     372
     373
     374    def test_connectionInitialized(self):
     375        """
     376        """
     377        self.xmlstream.makeConnection(self.transport)
     378        self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT)
     379
     380
     381    def test_connectionLost(self):
     382        """
     383        """
     384        self.xmlstream.makeConnection(self.transport)
     385        self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT)
     386        self.xmlstream.dispatch(None, xmlstream.STREAM_END_EVENT)
     387
     388
     389    def test_Element(self):
     390        self.xmlstream.makeConnection(self.transport)
     391        self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT)
     392
     393        stanza = domish.Element((None, "presence"))
     394        self.xmlstream.dispatch(stanza)
     395        self.assertEqual(1, len(self.service.dispatched))
     396        self.assertIdentical(stanza, self.service.dispatched[-1])
     397
     398
     399    def test_ElementNotAuthenticated(self):
     400        self.xmlstream.makeConnection(self.transport)
     401
     402        stanza = domish.Element((None, "presence"))
     403        self.xmlstream.dispatch(stanza)
     404        self.assertEqual(0, len(self.service.dispatched))
     405
     406
     407
     408class ServerServiceTest(unittest.TestCase):
     409
     410    def setUp(self):
     411        self.output = []
     412
     413        self.xmlstream = xmlstream.XmlStream(xmlstream.Authenticator())
     414        self.xmlstream.thisEntity = jid.JID('example.org')
     415        self.xmlstream.otherEntity = jid.JID('example.com')
     416        self.xmlstream.send = self.output.append
     417
     418        self.router = component.Router()
     419        self.service = server.ServerService(self.router,
     420                                            secret='mysecret',
     421                                            domain='example.org')
     422        self.service.xmlstream = self.xmlstream
     423
     424
     425    def test_defaultDomainInDomains(self):
     426        """
     427        The default domain is part of the domains considered local.
     428        """
     429        self.assertIn(self.service.defaultDomain, self.service.domains)
     430
     431
     432    def test_dispatch(self):
     433        stanza = domish.Element((None, "presence"))
     434        stanza['to'] = 'user@example.org'
     435        stanza['from'] = 'other@example.com'
     436        self.service.dispatch(self.xmlstream, stanza)
     437
     438        self.assertEqual(1, len(self.output))
     439        self.assertIdentical(stanza, self.output[-1])
     440
     441
     442    def test_dispatchNoTo(self):
     443        errors = []
     444        self.xmlstream.sendStreamError = errors.append
     445
     446        stanza = domish.Element((None, "presence"))
     447        stanza['from'] = 'other@example.com'
     448        self.service.dispatch(self.xmlstream, stanza)
     449
     450        self.assertEqual(1, len(errors))
Note: See TracBrowser for help on using the repository browser.