diff --git a/librespot/core.py b/librespot/core.py index 1e5f233..276f526 100644 --- a/librespot/core.py +++ b/librespot/core.py @@ -1280,89 +1280,158 @@ class Session(Closeable, MessageListener, SubListener): self.__inner.device_id)) def connect(self) -> None: - """Connect to the Spotify Server""" - acc = Session.Accumulator() - # Send ClientHello - nonce = Random.get_random_bytes(0x10) - client_hello_proto = Keyexchange.ClientHello( - build_info=Version.standard_build_info(), - client_nonce=nonce, - cryptosuites_supported=[ - Keyexchange.Cryptosuite.CRYPTO_SUITE_SHANNON - ], - login_crypto_hello=Keyexchange.LoginCryptoHelloUnion( - diffie_hellman=Keyexchange.LoginCryptoDiffieHellmanHello( - gc=self.__keys.public_key_bytes(), server_keys_known=1), ), - padding=b"\x1e", - ) - client_hello_bytes = client_hello_proto.SerializeToString() - self.connection.write(b"\x00\x04") - self.connection.write_int(2 + 4 + len(client_hello_bytes)) - self.connection.write(client_hello_bytes) - self.connection.flush() - acc.write(b"\x00\x04") - acc.write_int(2 + 4 + len(client_hello_bytes)) - acc.write(client_hello_bytes) - # Read APResponseMessage - ap_response_message_length = self.connection.read_int() - acc.write_int(ap_response_message_length) - ap_response_message_bytes = self.connection.read( - ap_response_message_length - 4) - acc.write(ap_response_message_bytes) - ap_response_message_proto = Keyexchange.APResponseMessage() - ap_response_message_proto.ParseFromString(ap_response_message_bytes) - shared_key = util.int_to_bytes( - self.__keys.compute_shared_key( - ap_response_message_proto.challenge.login_crypto_challenge. - diffie_hellman.gs)) - # Check gs_signature - rsa = RSA.construct((int.from_bytes(self.__server_key, "big"), 65537)) - pkcs1_v1_5 = PKCS1_v1_5.new(rsa) - sha1 = SHA1.new() - sha1.update(ap_response_message_proto.challenge.login_crypto_challenge. - diffie_hellman.gs) - if not pkcs1_v1_5.verify( - sha1, - ap_response_message_proto.challenge.login_crypto_challenge. - diffie_hellman.gs_signature, - ): - raise RuntimeError("Failed signature check!") - # Solve challenge - buffer = io.BytesIO() - for i in range(1, 6): - mac = HMAC.new(shared_key, digestmod=SHA1) - mac.update(acc.read()) - mac.update(bytes([i])) - buffer.write(mac.digest()) - buffer.seek(0) - mac = HMAC.new(buffer.read(20), digestmod=SHA1) - mac.update(acc.read()) - challenge = mac.digest() - client_response_plaintext_proto = Keyexchange.ClientResponsePlaintext( - crypto_response=Keyexchange.CryptoResponseUnion(), - login_crypto_response=Keyexchange.LoginCryptoResponseUnion( - diffie_hellman=Keyexchange.LoginCryptoDiffieHellmanResponse( - hmac=challenge)), - pow_response=Keyexchange.PoWResponseUnion(), - ) - client_response_plaintext_bytes = ( - client_response_plaintext_proto.SerializeToString()) - self.connection.write_int(4 + len(client_response_plaintext_bytes)) - self.connection.write(client_response_plaintext_bytes) - self.connection.flush() - try: - self.connection.set_timeout(1) - scrap = self.connection.read(4) - if len(scrap) == 4: - payload = self.connection.read( - struct.unpack(">i", scrap)[0] - 4) - failed = Keyexchange.APResponseMessage() - failed.ParseFromString(payload) - raise RuntimeError(failed) - except socket.timeout: - pass - finally: - self.connection.set_timeout(0) + """Connect to the Spotify Server. + + This will retry the initial handshake a few times instead of + crashing immediately on transient socket errors or short reads. + """ + max_attempts = 3 + last_exc: typing.Optional[BaseException] = None + + for attempt in range(1, max_attempts + 1): + try: + acc = Session.Accumulator() + # Send ClientHello + nonce = Random.get_random_bytes(0x10) + client_hello_proto = Keyexchange.ClientHello( + build_info=Version.standard_build_info(), + client_nonce=nonce, + cryptosuites_supported=[ + Keyexchange.Cryptosuite.CRYPTO_SUITE_SHANNON + ], + login_crypto_hello=Keyexchange.LoginCryptoHelloUnion( + diffie_hellman=Keyexchange.LoginCryptoDiffieHellmanHello( + gc=self.__keys.public_key_bytes(), + server_keys_known=1, + ), + ), + padding=b"\x1e", + ) + client_hello_bytes = client_hello_proto.SerializeToString() + self.connection.write(b"\x00\x04") + self.connection.write_int(2 + 4 + len(client_hello_bytes)) + self.connection.write(client_hello_bytes) + self.connection.flush() + acc.write(b"\x00\x04") + acc.write_int(2 + 4 + len(client_hello_bytes)) + acc.write(client_hello_bytes) + # Read APResponseMessage + ap_response_message_length = self.connection.read_int() + acc.write_int(ap_response_message_length) + ap_response_message_bytes = self.connection.read( + ap_response_message_length - 4 + ) + acc.write(ap_response_message_bytes) + ap_response_message_proto = Keyexchange.APResponseMessage() + ap_response_message_proto.ParseFromString( + ap_response_message_bytes + ) + shared_key = util.int_to_bytes( + self.__keys.compute_shared_key( + ap_response_message_proto.challenge.login_crypto_challenge. + diffie_hellman.gs + ) + ) + # Check gs_signature + rsa = RSA.construct( + (int.from_bytes(self.__server_key, "big"), 65537) + ) + pkcs1_v1_5 = PKCS1_v1_5.new(rsa) + sha1 = SHA1.new() + sha1.update( + ap_response_message_proto.challenge.login_crypto_challenge. + diffie_hellman.gs + ) + if not pkcs1_v1_5.verify( + sha1, + ap_response_message_proto.challenge.login_crypto_challenge. + diffie_hellman.gs_signature, + ): + raise RuntimeError("Failed signature check!") + # Solve challenge + buffer = io.BytesIO() + for i in range(1, 6): + mac = HMAC.new(shared_key, digestmod=SHA1) + mac.update(acc.read()) + mac.update(bytes([i])) + buffer.write(mac.digest()) + buffer.seek(0) + mac = HMAC.new(buffer.read(20), digestmod=SHA1) + mac.update(acc.read()) + challenge = mac.digest() + client_response_plaintext_proto = ( + Keyexchange.ClientResponsePlaintext( + crypto_response=Keyexchange.CryptoResponseUnion(), + login_crypto_response=Keyexchange.LoginCryptoResponseUnion( + diffie_hellman=Keyexchange.LoginCryptoDiffieHellmanResponse( + hmac=challenge + ) + ), + pow_response=Keyexchange.PoWResponseUnion(), + ) + ) + client_response_plaintext_bytes = ( + client_response_plaintext_proto.SerializeToString() + ) + self.connection.write_int( + 4 + len(client_response_plaintext_bytes) + ) + self.connection.write(client_response_plaintext_bytes) + self.connection.flush() + try: + self.connection.set_timeout(1) + scrap = self.connection.read(4) + if len(scrap) == 4: + payload = self.connection.read( + struct.unpack(">i", scrap)[0] - 4 + ) + failed = Keyexchange.APResponseMessage() + failed.ParseFromString(payload) + raise RuntimeError(failed) + except socket.timeout: + # Normal path: server did not send an error APResponse. + pass + finally: + self.connection.set_timeout(0) + + # If we reach here, the handshake succeeded. + return + + except (ConnectionResetError, OSError, struct.error) as exc: + last_exc = exc + self.logger.warning( + "Handshake attempt %d/%d failed: %s", + attempt, + max_attempts, + exc, + ) + # Close current connection; a new access point will be + # selected on the next attempt. + if self.connection is not None: + try: + self.connection.close() + except Exception: + pass + self.connection = None + + if attempt < max_attempts: + # Pick a new access point and try again after a + # short delay. + address = ApResolver.get_random_accesspoint() + self.logger.info( + "Retrying connection, new access point: %s", address + ) + self.connection = Session.ConnectionHolder.create( + address, None + ) + time.sleep(1) + + # All attempts failed: raise a clear error instead of crashing + # with a low-level struct.error. + raise RuntimeError( + "Failed to connect to Spotify access point after " + f"{max_attempts} attempts" + ) from last_exc buffer.seek(20) with self.__auth_lock: self.cipher_pair = CipherPair(buffer.read(32), buffer.read(32)) @@ -2230,7 +2299,15 @@ class Session(Closeable, MessageListener, SubListener): :returns: Bytes data from socket """ - return self.__socket.recv(length) + # Ensure we either read the requested number of bytes + # or raise a clear error if the connection is closed. + data = b"" + while len(data) < length: + chunk = self.__socket.recv(length - len(data)) + if not chunk: + break + data += chunk + return data def read_int(self) -> int: """Read integer from socket @@ -2239,7 +2316,12 @@ class Session(Closeable, MessageListener, SubListener): :returns: integer from socket """ - return struct.unpack(">i", self.read(4))[0] + data = self.read(4) + if len(data) != 4: + raise ConnectionResetError( + "Unexpected end of stream while reading 4-byte integer" + ) + return struct.unpack(">i", data)[0] def read_short(self) -> int: """Read short integer from socket @@ -2248,7 +2330,12 @@ class Session(Closeable, MessageListener, SubListener): :returns: short integer from socket """ - return struct.unpack(">h", self.read(2))[0] + data = self.read(2) + if len(data) != 2: + raise ConnectionResetError( + "Unexpected end of stream while reading 2-byte integer" + ) + return struct.unpack(">h", data)[0] def set_timeout(self, seconds: float) -> None: """Set socket's timeout