diff options
| author | Janito Vaqueiro Ferreira Filho <janito@mullvad.net> | 2018-05-16 12:48:15 -0300 |
|---|---|---|
| committer | Janito Vaqueiro Ferreira Filho <janito@mullvad.net> | 2018-05-16 12:48:15 -0300 |
| commit | 5493dd910d29a6f44addba645de0d2c75992b9e5 (patch) | |
| tree | e73f915de6ae74b0ace93c7143d8603eb2f10c6a | |
| parent | 37e54738b50e9f84996e323e1d069bb3bcde421d (diff) | |
| parent | c91323f774f24be5a6bfe6a373b552095043eca1 (diff) | |
| download | mullvadvpn-5493dd910d29a6f44addba645de0d2c75992b9e5.tar.xz mullvadvpn-5493dd910d29a6f44addba645de0d2c75992b9e5.zip | |
Merge branch 'pubsub-ws-ipc'
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | talpid-ipc/Cargo.toml | 1 | ||||
| -rw-r--r-- | talpid-ipc/src/client.rs | 449 | ||||
| -rw-r--r-- | talpid-ipc/src/lib.rs | 3 |
4 files changed, 382 insertions, 72 deletions
diff --git a/Cargo.lock b/Cargo.lock index fe0cd000f6..38339b7f23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1392,6 +1392,7 @@ dependencies = [ "error-chain 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", "jsonrpc-core 8.0.1 (git+https://github.com/paritytech/jsonrpc?tag=v8.0.1)", "jsonrpc-macros 8.0.0 (git+https://github.com/paritytech/jsonrpc?tag=v8.0.1)", + "jsonrpc-pubsub 8.0.0 (git+https://github.com/paritytech/jsonrpc?tag=v8.0.1)", "jsonrpc-ws-server 8.0.0 (git+https://github.com/paritytech/jsonrpc?tag=v8.0.1)", "log 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.45 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/talpid-ipc/Cargo.toml b/talpid-ipc/Cargo.toml index 79e164bb51..4d02460ad2 100644 --- a/talpid-ipc/Cargo.toml +++ b/talpid-ipc/Cargo.toml @@ -11,6 +11,7 @@ serde = "1.0" serde_json = "1.0" log = "0.4" jsonrpc-core = { git = "https://github.com/paritytech/jsonrpc", tag = "v8.0.1" } +jsonrpc-pubsub = { git = "https://github.com/paritytech/jsonrpc", tag = "v8.0.1" } jsonrpc-ws-server = { git = "https://github.com/paritytech/jsonrpc", tag = "v8.0.1" } ws = { git = "https://github.com/tomusdrw/ws-rs" } url = "1.4" diff --git a/talpid-ipc/src/client.rs b/talpid-ipc/src/client.rs index 9e76bfe443..9c22c5e7fa 100644 --- a/talpid-ipc/src/client.rs +++ b/talpid-ipc/src/client.rs @@ -1,10 +1,12 @@ -use std::sync::{mpsc, Arc, Mutex, MutexGuard}; +use std::collections::HashMap; +use std::sync::mpsc; use std::thread; use error_chain::ChainedError; +use jsonrpc_pubsub::SubscriptionId; use serde; use serde_json::{self, Result as JsonResult, Value as JsonValue}; -use url; +use url::Url; use ws; type JsonMap = serde_json::map::Map<String, JsonValue>; @@ -12,16 +14,81 @@ type JsonMap = serde_json::map::Map<String, JsonValue>; mod errors { error_chain! { errors { + ConnectError(details: &'static str) { + description("Failed to connect to RPC server") + display("Failed to connect to RPC server: {}", details) + } + + ConnectionHandlerStopped { + description("The WebSocket connection handler thread has stopped") + } + ErrorResponse(error_message: String) { description("Received an RPC error response") display("Received an RPC error response: {}", error_message) } + DeserializeResponseError { + description("Failed to deserialize response") + } + + DeserializeSubscriptionEvent(event: String) { + description("Failed to deserialize RPC subscription event") + display("Failed to deserialize RPC subscription event {}", event) + } + + ForwardSubscriptionEvent(event: String) { + description("Failed to forward RPC subscription event") + display("Failed to forward RPC subscription event {}", event) + } + InvalidJsonRpcResponse(details: &'static str) { description("Received an invalid JSON-RPC response") display("Received an invalid JSON-RPC response: {}", details) } + InvalidServerIdUrl(server_id: ::IpcServerId) { + description("Unable to parse given server ID as a URL") + display("Unable to parse given server ID as a URL: {}", server_id) + } + + InvalidSubscriptionEvent(details: &'static str) { + description("Received an invalid JSON-RPC PubSub event") + display("Received an invalid JSON-RPC PubSub event: {}", details) + } + + InvalidSubscriptionId(raw_id: ::serde_json::Value) { + description("Received an invalid JSON-RPC subscription ID for subscribe request") + display( + "Received an invalid JSON-RPC subscription ID for subscribe request: {}", + raw_id, + ) + } + + MissingResponse { + description("No response received") + } + + SendRequestError(method: String) { + description("Failed to send a request to call a remote JSON-RPC procedure") + display( + "Failed to send a request to call the \"{}\" remote JSON-RPC procedure", + method + ) + } + + SerializeArgumentsError { + description("Failed to serialize JSON-RPC request arguments") + } + + SerializeSubscriptionId { + description("Failed to serialize JSON-RPC subscription ID") + } + + UnsubscribeError { + description("Failed to unsubscribe from a remote event") + } + WebSocketError { description("Error with WebSocket connection") } @@ -30,6 +97,13 @@ mod errors { } pub use self::errors::*; +#[derive(Debug, Eq, PartialEq)] +pub enum SubscriptionHandlerResult { + Active, + Finished, +} + +type SubscriptionHandler = Box<Fn(JsonValue) -> SubscriptionHandlerResult + Send>; struct ActiveRequest { id: i64, @@ -50,8 +124,34 @@ impl ActiveRequest { } } +enum WsIpcCommand { + Call { + method: String, + arguments: JsonValue, + response_tx: mpsc::Sender<Result<JsonValue>>, + }, + + Subscribe { + id: SubscriptionId, + handler: SubscriptionHandler, + unsubscribe_method: String, + }, + + Response { + id: i64, + result: Result<JsonValue>, + }, + + Notification { + subscription: SubscriptionId, + event: JsonValue, + }, + + Error(Error), +} + struct Factory { - active_request: Arc<Mutex<Option<ActiveRequest>>>, + connection_tx: mpsc::Sender<WsIpcCommand>, sender_tx: mpsc::Sender<ws::Sender>, } @@ -64,35 +164,38 @@ impl ws::Factory for Factory { let _ = self.sender_tx.send(sender); Handler { - active_request: self.active_request.clone(), + connection_tx: self.connection_tx.clone(), } } } struct Handler { - active_request: Arc<Mutex<Option<ActiveRequest>>>, + connection_tx: mpsc::Sender<WsIpcCommand>, } impl Handler { fn process_message(&mut self, msg: ws::Message) -> Result<()> { trace!("WsIpcClient incoming message: {:?}", msg); - let mut response_json_object = self.parse_message_object(msg)?; - let response_id = self.parse_response_id(&mut response_json_object)?; - let rpc_result = self.parse_response_result(response_json_object); + let mut message_json_object = self.parse_message_object(msg)?; + let response_id = self.parse_response_id(&mut message_json_object)?; - let mut active_request = self.lock_active_request(); + let command = if let Some(id) = response_id { + let result = self.parse_response_result(message_json_object); - if let Some(mut request) = active_request.take() { - if response_id == request.id() { - let _ = request.send_response(rpc_result); - } else { - warn!("Received an unexpect JSON-RPC message"); - *active_request = Some(request); + WsIpcCommand::Response { id, result } + } else { + let (subscription, event) = self.parse_subscription_event(message_json_object)?; + + WsIpcCommand::Notification { + subscription, + event, } - } + }; - Ok(()) + self.connection_tx + .send(command) + .chain_err(|| ErrorKind::ConnectionHandlerStopped) } fn parse_message_object(&self, msg: ws::Message) -> Result<JsonMap> { @@ -119,13 +222,13 @@ impl Handler { Ok(json_object_map) } - fn parse_response_id(&self, json_object_map: &mut JsonMap) -> Result<i64> { + fn parse_response_id(&self, json_object_map: &mut JsonMap) -> Result<Option<i64>> { match json_object_map.remove("id") { - Some(JsonValue::Number(id)) => id.as_i64().ok_or_else(|| { + Some(JsonValue::Number(id)) => id.as_i64().map(Some).ok_or_else(|| { ErrorKind::InvalidJsonRpcResponse("Invalid request ID number").into() }), - None => Err(ErrorKind::InvalidJsonRpcResponse("Missing request ID").into()), - _ => Err(ErrorKind::InvalidJsonRpcResponse("Invalid request ID value").into()), + Some(_) => Err(ErrorKind::InvalidJsonRpcResponse("Invalid request ID value").into()), + None => Ok(None), } } @@ -148,10 +251,30 @@ impl Handler { } } - fn lock_active_request(&mut self) -> MutexGuard<Option<ActiveRequest>> { - self.active_request - .lock() - .expect("a thread panicked while using the active JSON-RPC request") + fn parse_subscription_event( + &mut self, + mut notification: JsonMap, + ) -> Result<(SubscriptionId, JsonValue)> { + match notification.remove("params") { + Some(JsonValue::Object(mut parameters)) => { + let raw_id = parameters + .remove("subscription") + .ok_or_else(|| ErrorKind::InvalidSubscriptionEvent("Missing subscription ID"))?; + let id = SubscriptionId::parse_value(&raw_id) + .ok_or_else(|| ErrorKind::InvalidSubscriptionEvent("Invalid subscription ID"))?; + let event = parameters + .remove("result") + .ok_or_else(|| ErrorKind::InvalidSubscriptionEvent("Missing event data"))?; + + Ok((id, event)) + } + Some(_) => bail!(ErrorKind::InvalidSubscriptionEvent( + "RPC parameters is not a JSON object map" + )), + None => bail!(ErrorKind::InvalidSubscriptionEvent( + "Missing RPC parameters" + )), + } } } @@ -166,75 +289,194 @@ impl ws::Handler for Handler { } fn on_error(&mut self, error: ws::Error) { - if let Some(active_request) = self.lock_active_request().as_mut() { - active_request.send_response(Err(error).chain_err(|| ErrorKind::WebSocketError)); - } + let error = Error::with_chain(error, ErrorKind::WebSocketError); + + let _ = self.connection_tx.send(WsIpcCommand::Error(error)); } } - pub struct WsIpcClient { - next_id: i64, - active_request: Arc<Mutex<Option<ActiveRequest>>>, - sender: ws::Sender, + connection_tx: mpsc::Sender<WsIpcCommand>, } impl WsIpcClient { pub fn connect(server_id: &::IpcServerId) -> Result<Self> { - let url = url::Url::parse(server_id).chain_err(|| "Unable to parse server_id as url")?; - let active_request = Arc::new(Mutex::new(None)); - let sender = Self::open_websocket(url, active_request.clone())?; + let url = Url::parse(&server_id) + .chain_err(|| ErrorKind::InvalidServerIdUrl(server_id.to_owned()))?; + let (connection_tx, connection_rx) = mpsc::channel(); + let sender = Self::open_websocket(url, connection_tx.clone())?; - Ok(WsIpcClient { - next_id: 1, - active_request, - sender, - }) + WsIpcClientConnection::spawn(sender, connection_rx); + + Ok(WsIpcClient { connection_tx }) } - fn open_websocket( - url: url::Url, - active_request: Arc<Mutex<Option<ActiveRequest>>>, - ) -> Result<ws::Sender> { + fn open_websocket(url: Url, connection_tx: mpsc::Sender<WsIpcCommand>) -> Result<ws::Sender> { let (sender_tx, sender_rx) = mpsc::channel(); let factory = Factory { - active_request, + connection_tx, sender_tx, }; - let mut websocket = ws::WebSocket::new(factory).chain_err(|| "Unable to create WebSocket")?; + let mut websocket = ws::WebSocket::new(factory) + .chain_err(|| ErrorKind::ConnectError("Unable to create WebSocket"))?; websocket .connect(url) - .chain_err(|| "Unable to connect WebSocket to URL")?; + .chain_err(|| ErrorKind::ConnectError("Unable to connect WebSocket to URL"))?; thread::spawn(move || { let result = websocket .run() - .chain_err(|| "Error while running WebSocket event loop"); + .chain_err(|| ErrorKind::ConnectError("Error while running WebSocket event loop")); if let Err(error) = result { error!("{}", error.display_chain()); } }); - sender_rx.recv().chain_err(|| "WebSocket connection failed") + sender_rx + .recv() + .chain_err(|| ErrorKind::ConnectError("WebSocket connection failed")) + } + + pub fn subscribe<V, M>( + &mut self, + subscribe_method: String, + unsubscribe_method: String, + sender: mpsc::Sender<M>, + ) -> Result<()> + where + V: for<'de> serde::Deserialize<'de>, + M: From<V> + Send + 'static, + { + let raw_subscription_id = self.call(&subscribe_method, &[] as &[u8; 0])?; + let subscription_id = SubscriptionId::parse_value(&raw_subscription_id) + .ok_or_else(|| ErrorKind::InvalidSubscriptionId(raw_subscription_id))?; + + let handler = move |json_value| match forward_subscription_event( + &subscribe_method, + json_value, + &sender, + ) { + Ok(()) => SubscriptionHandlerResult::Active, + Err(error) => { + error!("{}", error.display_chain()); + SubscriptionHandlerResult::Finished + } + }; + + self.register_subscription(subscription_id, handler, unsubscribe_method)?; + + Ok(()) + } + + fn register_subscription<H>( + &mut self, + id: SubscriptionId, + handler: H, + unsubscribe_method: String, + ) -> Result<()> + where + H: Fn(JsonValue) -> SubscriptionHandlerResult + Send + 'static, + { + self.connection_tx + .send(WsIpcCommand::Subscribe { + id, + handler: Box::new(handler), + unsubscribe_method, + }) + .chain_err(|| ErrorKind::ConnectionHandlerStopped) } - pub fn call<T, O>(&mut self, method: &str, params: &T) -> Result<O> + pub fn call<S, T, O>(&mut self, method: S, params: &T) -> Result<O> where + S: ToString, T: serde::Serialize, O: for<'de> serde::Deserialize<'de>, { - let id = self.new_id(); - let (result_tx, result_rx) = mpsc::channel(); + let arguments = + serde_json::to_value(params).chain_err(|| ErrorKind::SerializeArgumentsError)?; + let (response_tx, response_rx) = mpsc::channel(); + let command = WsIpcCommand::Call { + method: method.to_string(), + arguments, + response_tx, + }; + + self.connection_tx + .send(command) + .chain_err(|| ErrorKind::ConnectionHandlerStopped)?; - self.queue_request_response(id, result_tx); - self.send_request(id, method, params)?; + let json_result = response_rx.recv().chain_err(|| ErrorKind::MissingResponse)?; - let json_result = result_rx.recv().chain_err(|| "No response received")?; + Ok(serde_json::from_value(json_result?).chain_err(|| ErrorKind::DeserializeResponseError)?) + } +} + +struct WsIpcClientConnection { + next_id: i64, + active_request: Option<ActiveRequest>, + active_subscriptions: HashMap<SubscriptionId, (SubscriptionHandler, String)>, + sender: ws::Sender, +} - Ok(serde_json::from_value(json_result?).chain_err(|| "Failed to deserialize RPC result")?) +impl WsIpcClientConnection { + pub fn spawn(sender: ws::Sender, commands: mpsc::Receiver<WsIpcCommand>) { + let mut instance = WsIpcClientConnection { + next_id: 1, + active_request: None, + active_subscriptions: HashMap::new(), + sender, + }; + + thread::spawn(move || { + if let Err(error) = instance.run(commands) { + let chained_error = Error::with_chain(error, "WsIpcClient event loop error"); + error!("{}", chained_error.display_chain()); + } + }); + } + + fn run(&mut self, commands: mpsc::Receiver<WsIpcCommand>) -> Result<()> { + use self::WsIpcCommand::*; + + for command in commands { + match command { + Call { + method, + arguments, + response_tx, + } => self.call(method, arguments, response_tx)?, + Subscribe { + id, + handler, + unsubscribe_method, + } => { + self.active_subscriptions + .insert(id, (handler, unsubscribe_method)); + } + Response { id, result } => self.handle_response(id, result)?, + Notification { + subscription, + event, + } => self.handle_notification(subscription, event)?, + Error(error) => self.handle_error(error), + } + } + + Ok(()) + } + + fn call( + &mut self, + method: String, + arguments: JsonValue, + response_tx: mpsc::Sender<Result<JsonValue>>, + ) -> Result<()> { + let id = self.new_id(); + self.queue_request_response(id, response_tx); + self.send_request(id, method, arguments) } fn new_id(&mut self) -> i64 { @@ -243,29 +485,19 @@ impl WsIpcClient { id } - fn queue_request_response(&mut self, id: i64, result_tx: mpsc::Sender<Result<JsonValue>>) { - let mut active_request = self.active_request - .lock() - .expect("a thread panicked using the active RPC request map"); - - *active_request = Some(ActiveRequest::new(id, result_tx)); + fn queue_request_response(&mut self, id: i64, response_tx: mpsc::Sender<Result<JsonValue>>) { + self.active_request = Some(ActiveRequest::new(id, response_tx)); } - fn send_request<T>(&mut self, id: i64, method: &str, params: &T) -> Result<()> - where - T: serde::Serialize, - { - let json_request = self.build_json_request(id, method, params); + fn send_request(&mut self, id: i64, method: String, arguments: JsonValue) -> Result<()> { + let json_request = self.build_json_request(id, &method, arguments); self.sender .send(json_request.as_bytes()) - .chain_err(|| "Unable to send jsonrpc request") + .chain_err(|| ErrorKind::SendRequestError(method)) } - fn build_json_request<T>(&mut self, id: i64, method: &str, params: &T) -> String - where - T: serde::Serialize, - { + fn build_json_request(&mut self, id: i64, method: &str, params: JsonValue) -> String { let request_json = json!({ "jsonrpc": "2.0", "id": id, @@ -274,4 +506,77 @@ impl WsIpcClient { }); format!("{}", request_json) } + + fn handle_response(&mut self, id: i64, result: Result<JsonValue>) -> Result<()> { + if let Some(mut request) = self.active_request.take() { + if request.id() == id { + request.send_response(result); + } else { + self.active_request = Some(request); + warn!("Received an unexpected response with ID {}", id); + } + } else { + warn!("Received an unexpected response with ID {}", id); + } + + Ok(()) + } + + fn handle_notification(&mut self, id: SubscriptionId, event: JsonValue) -> Result<()> { + let unsubscribe_method = + if let Some((handler, unsubscribe_method)) = self.active_subscriptions.get(&id) { + match handler(event) { + SubscriptionHandlerResult::Active => None, + SubscriptionHandlerResult::Finished => Some(unsubscribe_method.clone()), + } + } else { + warn!("Received an unexpected notification"); + None + }; + + if let Some(method) = unsubscribe_method { + self.unsubscribe(method, id)?; + } + + Ok(()) + } + + fn unsubscribe(&mut self, method: String, id: SubscriptionId) -> Result<()> { + self.active_subscriptions.remove(&id); + + let (result_tx, _) = mpsc::channel(); + let arguments = match id { + SubscriptionId::Number(id) => serde_json::to_value(&[id]), + SubscriptionId::String(id) => serde_json::to_value(&[id]), + }.chain_err(|| ErrorKind::SerializeSubscriptionId); + + self.call(method, arguments?, result_tx) + .chain_err(|| ErrorKind::UnsubscribeError) + } + + fn handle_error(&mut self, error: Error) { + if let Some(ref mut request) = self.active_request { + let _ = request.response_tx.send(Err(error)); + } else { + error!("{}", error.display_chain()); + } + } +} + +fn forward_subscription_event<V, M>( + subscribe_method: &String, + json_value: JsonValue, + sender: &mpsc::Sender<M>, +) -> Result<()> +where + V: for<'de> serde::Deserialize<'de>, + M: From<V> + Send + 'static, +{ + let value: V = serde_json::from_value(json_value) + .chain_err(|| ErrorKind::DeserializeSubscriptionEvent(subscribe_method.clone()))?; + let message = M::from(value); + + sender + .send(message) + .chain_err(|| ErrorKind::ForwardSubscriptionEvent(subscribe_method.clone())) } diff --git a/talpid-ipc/src/lib.rs b/talpid-ipc/src/lib.rs index e47de09cba..6a8a523bea 100644 --- a/talpid-ipc/src/lib.rs +++ b/talpid-ipc/src/lib.rs @@ -6,6 +6,8 @@ //! GNU General Public License as published by the Free Software Foundation, either version 3 of //! the License, or (at your option) any later version. +#![recursion_limit = "128"] + #[macro_use] extern crate error_chain; #[macro_use] @@ -16,6 +18,7 @@ extern crate serde; extern crate serde_json; extern crate jsonrpc_core; +extern crate jsonrpc_pubsub; extern crate jsonrpc_ws_server; extern crate url; extern crate ws; |
