Convert to async

This commit is contained in:
Maxime Augier 2024-08-29 15:00:27 +02:00
parent 0a5891b5be
commit 17a3cf1018
5 changed files with 138 additions and 115 deletions

View File

@ -15,10 +15,13 @@ categories = ["api-bindings"]
[dependencies] [dependencies]
chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] }
futures-util = { version = "0.3.30", features = ["futures-sink"] }
reqwest = { version = "0.12.7", features = ["json"] }
serde = { version = "1.0.204", features = ["derive"] } serde = { version = "1.0.204", features = ["derive"] }
serde_json = "1.0.121" serde_json = "1.0.121"
serde_repr = "0.1.19" serde_repr = "0.1.19"
thiserror = "1.0.63" thiserror = "1.0.63"
tokio = "1.39.3"
tokio-tungstenite = { version = "0.23.1", features = ["tokio-rustls", "rustls-tls-native-roots"] }
tracing = "0.1.40" tracing = "0.1.40"
tungstenite = { version = "0.23.0", optional = true, features = ["rustls-tls-native-roots"] } tungstenite = { version = "0.23.0", optional = true, features = ["rustls-tls-native-roots"] }
ureq = { version = "2.10.0", features = ["json"] }

View File

@ -1,7 +1,5 @@
use std::{ use std::{
io, io, ops::{Add, Mul, Sub}, time::{Duration, Instant, SystemTime, UNIX_EPOCH}
ops::{Add, Mul, Sub},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
}; };
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};
@ -9,11 +7,14 @@ use serde_repr::Deserialize_repr;
use thiserror::Error; use thiserror::Error;
use tracing::{debug, info, instrument}; use tracing::{debug, info, instrument};
pub use reqwest::{self, StatusCode};
pub struct Context { pub struct Context {
auth_header: String, auth_header: String,
refresh_token: String, refresh_token: String,
token_expiration: Instant, token_expiration: Instant,
on_refresh: Option<Box<dyn FnMut(&mut Self) + Send>>, on_refresh: Option<Box<dyn FnMut(&mut Self) + Send>>,
client: reqwest::Client,
} }
impl std::fmt::Debug for Context { impl std::fmt::Debug for Context {
@ -316,7 +317,7 @@ pub enum ApiError {
/// HTTP call failed (404, etc) /// HTTP call failed (404, etc)
#[error("ureq")] #[error("ureq")]
Ureq(#[source] Box<ureq::Error>), HTTP(#[source] Box<reqwest::Error>),
/// HTTP call succeeded but the returned JSON document didn't match the expected format /// HTTP call succeeded but the returned JSON document didn't match the expected format
#[error("unexpected data: {1} when processing {0}")] #[error("unexpected data: {1} when processing {0}")]
@ -334,20 +335,20 @@ pub enum ApiError {
InvalidID(String), InvalidID(String),
} }
impl From<ureq::Error> for ApiError { impl From<reqwest::Error> for ApiError {
fn from(value: ureq::Error) -> Self { fn from(value: reqwest::Error) -> Self {
ApiError::Ureq(Box::new(value)) ApiError::HTTP(Box::new(value))
} }
} }
trait JsonExplicitError { trait JsonExplicitError {
/// Explicitely report the received JSON object we failed to parse /// Explicitely report the received JSON object we failed to parse
fn into_json_with_error<T: DeserializeOwned>(self) -> Result<T, ApiError>; async fn into_json_with_error<T: DeserializeOwned>(self) -> Result<T, ApiError>;
} }
impl JsonExplicitError for ureq::Response { impl JsonExplicitError for reqwest::Response {
fn into_json_with_error<T: DeserializeOwned>(self) -> Result<T, ApiError> { async fn into_json_with_error<T: DeserializeOwned>(self) -> Result<T, ApiError> {
let resp: serde_json::Value = self.into_json()?; let resp: serde_json::Value = self.json().await?;
let parsed = T::deserialize(&resp); let parsed = T::deserialize(&resp);
parsed.map_err(|e| ApiError::UnexpectedData(resp, e)) parsed.map_err(|e| ApiError::UnexpectedData(resp, e))
} }
@ -363,12 +364,13 @@ pub enum TokenParseError {
} }
impl Context { impl Context {
fn from_login_response(resp: LoginResponse) -> Self { fn from_login_response(resp: LoginResponse, client: reqwest::Client) -> Self {
Self { Self {
auth_header: format!("Bearer {}", &resp.access_token), auth_header: format!("Bearer {}", &resp.access_token),
refresh_token: resp.refresh_token, refresh_token: resp.refresh_token,
token_expiration: (Instant::now() + Duration::from_secs(resp.expires_in as u64)), token_expiration: (Instant::now() + Duration::from_secs(resp.expires_in as u64)),
on_refresh: None, on_refresh: None,
client,
} }
} }
@ -389,6 +391,7 @@ impl Context {
refresh_token: refresh.to_owned(), refresh_token: refresh.to_owned(),
token_expiration, token_expiration,
on_refresh: None, on_refresh: None,
client: reqwest::Client::new(),
}) })
} }
@ -410,7 +413,7 @@ impl Context {
} }
/// Retrieve access tokens online, by logging in with the provided credentials /// Retrieve access tokens online, by logging in with the provided credentials
pub fn from_login(user: &str, password: &str) -> Result<Self, ApiError> { pub async fn from_login(user: &str, password: &str) -> Result<Self, ApiError> {
#[derive(Serialize)] #[derive(Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct Params<'t> { struct Params<'t> {
@ -418,23 +421,26 @@ impl Context {
password: &'t str, password: &'t str,
} }
let client = reqwest::Client::new();
info!("Logging into API"); info!("Logging into API");
let url: String = format!("{}accounts/login", API_BASE); let url: String = format!("{}accounts/login", API_BASE);
let resp: LoginResponse = ureq::post(&url) let resp: LoginResponse = client.post(&url)
.send_json(Params { .json(&Params {
user_name: user, user_name: user,
password, password,
})? })
.into_json_with_error()?; .send().await?
.into_json_with_error().await?;
Ok(Self::from_login_response(resp)) Ok(Self::from_login_response(resp, client))
} }
/// Check if the token has reached its expiration date /// Check if the token has reached its expiration date
fn check_expired(&mut self) -> Result<(), ApiError> { async fn check_expired(&mut self) -> Result<(), ApiError> {
if self.token_expiration < Instant::now() { if self.token_expiration < Instant::now() {
debug!("Token has expired"); debug!("Token has expired");
self.refresh_token()?; self.refresh_token().await?;
} }
Ok(()) Ok(())
} }
@ -444,7 +450,7 @@ impl Context {
} }
/// Use the refresh token to refresh credentials /// Use the refresh token to refresh credentials
pub fn refresh_token(&mut self) -> Result<(), ApiError> { pub async fn refresh_token(&mut self) -> Result<(), ApiError> {
#[derive(Serialize)] #[derive(Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct Params<'t> { struct Params<'t> {
@ -456,51 +462,52 @@ impl Context {
refresh_token: &self.refresh_token, refresh_token: &self.refresh_token,
}; };
let url = format!("{}accounts/refresh_token", API_BASE); let url = format!("{}accounts/refresh_token", API_BASE);
let resp: LoginResponse = ureq::post(&url) let resp: LoginResponse = self.client.post(&url)
.set("Content-type", "application/json") .header("Content-type", "application/json")
.send_json(params)? .json(&params)
.into_json_with_error()?; .send().await?
.into_json_with_error().await?;
*self = Self::from_login_response(resp); *self = Self::from_login_response(resp, self.client.clone());
Ok(()) Ok(())
} }
/// List all sites available to the user /// List all sites available to the user
pub fn sites(&mut self) -> Result<Vec<Site>, ApiError> { pub async fn sites(&mut self) -> Result<Vec<Site>, ApiError> {
self.get("sites") self.get("sites").await
} }
pub fn site(&mut self, id: i32) -> Result<SiteDetails, ApiError> { pub async fn site(&mut self, id: i32) -> Result<SiteDetails, ApiError> {
self.get(&format!("sites/{id}")) self.get(&format!("sites/{id}")).await
} }
/// List all chargers available to the user /// List all chargers available to the user
pub fn chargers(&mut self) -> Result<Vec<Charger>, ApiError> { pub async fn chargers(&mut self) -> Result<Vec<Charger>, ApiError> {
self.get("chargers") self.get("chargers").await
} }
pub fn charger(&mut self, id: &str) -> Result<Charger, ApiError> { pub async fn charger(&mut self, id: &str) -> Result<Charger, ApiError> {
if !id.chars().all(char::is_alphanumeric) { if !id.chars().all(char::is_alphanumeric) {
return Err(ApiError::InvalidID(id.to_owned())); return Err(ApiError::InvalidID(id.to_owned()));
} }
self.get(&format!("chargers/{}", id)) self.get(&format!("chargers/{}", id)).await
} }
pub fn circuit(&mut self, site_id: u32, circuit_id: u32) -> Result<Circuit, ApiError> { pub async fn circuit(&mut self, site_id: u32, circuit_id: u32) -> Result<Circuit, ApiError> {
self.get(&format!("site/{site_id}/circuit/{circuit_id}")) self.get(&format!("site/{site_id}/circuit/{circuit_id}")).await
} }
pub fn circuit_dynamic_current( pub async fn circuit_dynamic_current(
&mut self, &mut self,
site_id: u32, site_id: u32,
circuit_id: u32, circuit_id: u32,
) -> Result<Triphase, ApiError> { ) -> Result<Triphase, ApiError> {
self.get(&format!( self.get(&format!(
"sites/{site_id}/circuits/{circuit_id}/dynamicCurrent" "sites/{site_id}/circuits/{circuit_id}/dynamicCurrent"
)) )).await
} }
pub fn set_circuit_dynamic_current( pub async fn set_circuit_dynamic_current(
&mut self, &mut self,
site_id: u32, site_id: u32,
circuit_id: u32, circuit_id: u32,
@ -509,65 +516,67 @@ impl Context {
self.post( self.post(
&format!("sites/{site_id}/circuits/{circuit_id}/dynamicCurrent"), &format!("sites/{site_id}/circuits/{circuit_id}/dynamicCurrent"),
&current, &current,
) ).await
} }
#[instrument] #[instrument]
fn get<T: DeserializeOwned>(&mut self, path: &str) -> Result<T, ApiError> { async fn get<T: DeserializeOwned>(&mut self, path: &str) -> Result<T, ApiError> {
self.check_expired()?; self.check_expired().await?;
let url: String = format!("{}{}", API_BASE, path); let url: String = format!("{}{}", API_BASE, path);
let req = ureq::get(&url)
.set("Accept", "application/json")
.set("Authorization", &self.auth_header);
let mut resp = req.clone().call()?; let req = self.client.get(url)
.header("Accept", "application/json")
.header("Authorization", &self.auth_header)
.build()?;
let mut resp = self.client.execute(req.try_clone().unwrap()).await?;
if resp.status() == 401 { if resp.status() == 401 {
self.refresh_token()?; self.refresh_token().await?;
resp = req.call()? resp = self.client.execute(req).await?
} }
resp.into_json_with_error() resp.into_json_with_error().await
} }
fn maybe_get<T: DeserializeOwned>(&mut self, path: &str) -> Result<Option<T>, ApiError> { async fn maybe_get<T: DeserializeOwned>(&mut self, path: &str) -> Result<Option<T>, ApiError> {
match self.get(path) { match self.get(path).await {
Ok(r) => Ok(Some(r)), Ok(r) => Ok(Some(r)),
Err(ApiError::Ureq(e)) => match &*e { Err(ApiError::HTTP(e)) if e.status() == Some(StatusCode::NOT_FOUND)=> Ok(None),
ureq::Error::Status(404, _) => Ok(None),
_ => Err(ApiError::Ureq(e)),
},
Err(other) => Err(other), Err(other) => Err(other),
} }
} }
pub(crate) fn post<T: DeserializeOwned, P: Serialize>( pub(crate) async fn post<T: DeserializeOwned, P: Serialize>(
&mut self, &mut self,
path: &str, path: &str,
params: &P, params: &P,
) -> Result<T, ApiError> { ) -> Result<T, ApiError> {
let url: String = format!("{}{}", API_BASE, path); let url: String = format!("{}{}", API_BASE, path);
self.post_raw(&url, params) self.post_raw(&url, params).await
} }
pub(crate) fn post_raw<T: DeserializeOwned, P: Serialize>( pub(crate) async fn post_raw<T: DeserializeOwned, P: Serialize>(
&mut self, &mut self,
url: &str, url: &str,
params: &P, params: &P,
) -> Result<T, ApiError> { ) -> Result<T, ApiError> {
self.check_expired()?; self.check_expired().await?;
let req = ureq::post(url) let req = self.client.post(url)
.set("Accept", "application/json") .header("Accept", "application/json")
.set("Authorization", &self.auth_header); .header("Authorization", &self.auth_header)
.json(params);
let mut resp = req.clone().send_json(params)?; let mut resp = req
.try_clone().unwrap()
.send().await?;
if resp.status() == 401 { if resp.status() == 401 {
self.refresh_token()?; self.refresh_token().await?;
resp = req.send_json(params)? resp = req.send().await?
} }
resp.into_json_with_error() resp.into_json_with_error().await
} }
} }
@ -584,12 +593,12 @@ pub struct MeterReading {
impl Site { impl Site {
/// Read all energy meters from the given site /// Read all energy meters from the given site
pub fn lifetime_energy(&self, ctx: &mut Context) -> Result<Vec<MeterReading>, ApiError> { pub async fn lifetime_energy(&self, ctx: &mut Context) -> Result<Vec<MeterReading>, ApiError> {
ctx.get(&format!("sites/{}/energy", self.id)) ctx.get(&format!("sites/{}/energy", self.id)).await
} }
pub fn details(&self, ctx: &mut Context) -> Result<SiteDetails, ApiError> { pub async fn details(&self, ctx: &mut Context) -> Result<SiteDetails, ApiError> {
ctx.get(&format!("sites/{}", self.id)) ctx.get(&format!("sites/{}", self.id)).await
} }
} }
@ -598,63 +607,63 @@ impl Circuit {
format!("sites/{}/circuits/{}/dynamicCurrent", self.site_id, self.id) format!("sites/{}/circuits/{}/dynamicCurrent", self.site_id, self.id)
} }
pub fn dynamic_current(&self, ctx: &mut Context) -> Result<Triphase, ApiError> { pub async fn dynamic_current(&self, ctx: &mut Context) -> Result<Triphase, ApiError> {
ctx.circuit_dynamic_current(self.site_id, self.id) ctx.circuit_dynamic_current(self.site_id, self.id).await
} }
pub fn set_dynamic_current( pub async fn set_dynamic_current(
&self, &self,
ctx: &mut Context, ctx: &mut Context,
current: SetCurrent, current: SetCurrent,
) -> Result<(), ApiError> { ) -> Result<(), ApiError> {
ctx.post(&self.dynamic_current_path(), &current) ctx.post(&self.dynamic_current_path(), &current).await
} }
} }
impl Charger { impl Charger {
/// Enable "smart charging" on the charger. This just turns the LED blue, and disables basic charging plans. /// Enable "smart charging" on the charger. This just turns the LED blue, and disables basic charging plans.
pub fn enable_smart_charging(&self, ctx: &mut Context) -> Result<(), ApiError> { pub async fn enable_smart_charging(&self, ctx: &mut Context) -> Result<(), ApiError> {
let url = format!("chargers/{}/commands/smart_charging", &self.id); let url = format!("chargers/{}/commands/smart_charging", &self.id);
ctx.post(&url, &()) ctx.post(&url, &()).await
} }
/// Read the state of a charger /// Read the state of a charger
pub fn state(&self, ctx: &mut Context) -> Result<ChargerState, ApiError> { pub async fn state(&self, ctx: &mut Context) -> Result<ChargerState, ApiError> {
let url = format!("chargers/{}/state", self.id); let url = format!("chargers/{}/state", self.id);
ctx.get(&url) ctx.get(&url).await
} }
/// Read info about the ongoing charging session /// Read info about the ongoing charging session
pub fn ongoing_session(&self, ctx: &mut Context) -> Result<Option<ChargingSession>, ApiError> { pub async fn ongoing_session(&self, ctx: &mut Context) -> Result<Option<ChargingSession>, ApiError> {
ctx.maybe_get(&format!("chargers/{}/sessions/ongoing", &self.id)) ctx.maybe_get(&format!("chargers/{}/sessions/ongoing", &self.id)).await
} }
/// Read info about the last charging session (not including ongoing one) /// Read info about the last charging session (not including ongoing one)
pub fn latest_session(&self, ctx: &mut Context) -> Result<Option<ChargingSession>, ApiError> { pub async fn latest_session(&self, ctx: &mut Context) -> Result<Option<ChargingSession>, ApiError> {
ctx.maybe_get(&format!("chargers/{}/sessions/latest", &self.id)) ctx.maybe_get(&format!("chargers/{}/sessions/latest", &self.id)).await
} }
fn command(&self, ctx: &mut Context, command: &str) -> Result<CommandReply, ApiError> { async fn command(&self, ctx: &mut Context, command: &str) -> Result<CommandReply, ApiError> {
ctx.post(&format!("chargers/{}/commands/{}", self.id, command), &()) ctx.post(&format!("chargers/{}/commands/{}", self.id, command), &()).await
} }
pub fn start(&self, ctx: &mut Context) -> Result<(), ApiError> { pub async fn start(&self, ctx: &mut Context) -> Result<(), ApiError> {
self.command(ctx, "start_charging")?; self.command(ctx, "start_charging").await?;
Ok(()) Ok(())
} }
pub fn pause(&self, ctx: &mut Context) -> Result<(), ApiError> { pub async fn pause(&self, ctx: &mut Context) -> Result<(), ApiError> {
self.command(ctx, "pause_charging")?; self.command(ctx, "pause_charging").await?;
Ok(()) Ok(())
} }
pub fn resume(&self, ctx: &mut Context) -> Result<(), ApiError> { pub async fn resume(&self, ctx: &mut Context) -> Result<(), ApiError> {
self.command(ctx, "resume_charging")?; self.command(ctx, "resume_charging").await?;
Ok(()) Ok(())
} }
pub fn stop(&self, ctx: &mut Context) -> Result<(), ApiError> { pub async fn stop(&self, ctx: &mut Context) -> Result<(), ApiError> {
self.command(ctx, "stop_charging")?; self.command(ctx, "stop_charging").await?;
Ok(()) Ok(())
} }
} }
@ -666,11 +675,15 @@ mod test {
use super::Context; use super::Context;
#[test] #[test]
fn token_save() { fn token_save() {
let client = reqwest::Client::new();
let ctx = Context { let ctx = Context {
auth_header: "Bearer aaaaaaa0".to_owned(), auth_header: "Bearer aaaaaaa0".to_owned(),
refresh_token: "abcdef".to_owned(), refresh_token: "abcdef".to_owned(),
token_expiration: Instant::now() + Duration::from_secs(1234), token_expiration: Instant::now() + Duration::from_secs(1234),
on_refresh: None, on_refresh: None,
client: client.clone(),
}; };
let saved = ctx.save(); let saved = ctx.save();

View File

@ -1,9 +1,9 @@
use serde::{de::{DeserializeOwned, IntoDeserializer}, Deserialize}; use serde::{de::{DeserializeOwned, IntoDeserializer}, Deserialize};
use serde_json::json;
use serde_repr::Deserialize_repr; use serde_repr::Deserialize_repr;
use std::num::{ParseFloatError, ParseIntError}; use std::num::{ParseFloatError, ParseIntError};
use thiserror::Error; use thiserror::Error;
use tracing::info; use tracing::info;
use ureq::json;
use crate::{ use crate::{
api::{ChargerOpMode, Context, OutputPhase, UtcDateTime}, api::{ChargerOpMode, Context, OutputPhase, UtcDateTime},
@ -325,17 +325,17 @@ struct ProductUpdate {
} }
impl Stream { impl Stream {
pub fn from_context(ctx: &mut Context) -> Result<Self, NegotiateError> { pub async fn from_context(ctx: &mut Context) -> Result<Self, NegotiateError> {
Ok(Self { Ok(Self {
inner: signalr::Stream::from_ws(crate::stream::Stream::open(ctx)?), inner: signalr::Stream::from_ws(crate::stream::Stream::open(ctx).await?),
}) })
} }
pub fn recv(&mut self) -> Result<Event, ObservationError> { pub async fn recv(&mut self) -> Result<Event, ObservationError> {
use signalr::Message::*; use signalr::Message::*;
let de = |msg| -> Result<Event, ObservationError> { Err(ObservationError::Protocol(msg)) }; let de = |msg| -> Result<Event, ObservationError> { Err(ObservationError::Protocol(msg)) };
loop { loop {
let msg = self.inner.recv()?; let msg = self.inner.recv().await?;
match &msg { match &msg {
Ping => continue, Ping => continue,
Empty | InvocationResult { .. } => info!("Skipped message: {msg:?}"), Empty | InvocationResult { .. } => info!("Skipped message: {msg:?}"),
@ -351,9 +351,9 @@ impl Stream {
} }
} }
} }
pub fn subscribe(&mut self, id: &str) -> Result<(), tungstenite::Error> { pub async fn subscribe(&mut self, id: &str) -> Result<(), tungstenite::Error> {
self.inner self.inner
.invoke("SubscribeWithCurrentState", json!([id, true])) .invoke("SubscribeWithCurrentState", json!([id, true])).await
} }
} }

View File

@ -113,9 +113,9 @@ impl Stream {
Self { ws, buffer: vec![] } Self { ws, buffer: vec![] }
} }
pub fn recv(&mut self) -> Result<Message, StreamError> { pub async fn recv(&mut self) -> Result<Message, StreamError> {
while self.buffer.is_empty() { while self.buffer.is_empty() {
self.buffer = self.ws.recv()?; self.buffer = self.ws.recv().await?;
self.buffer.reverse(); self.buffer.reverse();
} }
@ -123,7 +123,7 @@ impl Stream {
Ok(Message::from_json(json)?) Ok(Message::from_json(json)?)
} }
pub fn invoke( pub async fn invoke(
&mut self, &mut self,
target: &str, target: &str,
args: serde_json::Value, args: serde_json::Value,
@ -131,6 +131,6 @@ impl Stream {
self.ws.send(json!( { "arguments": args, self.ws.send(json!( { "arguments": args,
"invocationId": "0", "invocationId": "0",
"target": target, "target": target,
"type": 1} )) "type": 1} )).await
} }
} }

View File

@ -1,9 +1,11 @@
use super::api::{ApiError, Context}; use super::api::{ApiError, Context};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use std::net::TcpStream; use tokio::net::TcpStream;
use thiserror::Error; use thiserror::Error;
use tungstenite::{stream::MaybeTlsStream, Message, WebSocket}; //use tungstenite::{stream::MaybeTlsStream, Message, WebSocket};
use tokio_tungstenite::{MaybeTlsStream, tungstenite::Message, WebSocketStream};
use futures_util::{SinkExt,StreamExt};
const STREAM_API_NEGOTIATION_URL: &str = const STREAM_API_NEGOTIATION_URL: &str =
"https://streams.easee.com/hubs/products/negotiate?negotiateVersion=1"; "https://streams.easee.com/hubs/products/negotiate?negotiateVersion=1";
@ -36,15 +38,18 @@ pub enum RecvError {
#[error("WS error: {0}")] #[error("WS error: {0}")]
TungsteniteError(#[from] tungstenite::Error), TungsteniteError(#[from] tungstenite::Error),
#[error("End of stream")]
EndOfStream,
} }
pub struct Stream { pub struct Stream {
sock: WebSocket<MaybeTlsStream<TcpStream>>, sock: WebSocketStream<MaybeTlsStream<TcpStream>>,
} }
impl Stream { impl Stream {
pub fn open(ctx: &mut Context) -> Result<Stream, NegotiateError> { pub async fn open(ctx: &mut Context) -> Result<Stream, NegotiateError> {
let r: NegotiateResponse = ctx.post_raw(STREAM_API_NEGOTIATION_URL, &())?; let r: NegotiateResponse = ctx.post_raw(STREAM_API_NEGOTIATION_URL, &()).await?;
let token = ctx.auth_token(); let token = ctx.auth_token();
let wss_url = format!( let wss_url = format!(
@ -52,7 +57,8 @@ impl Stream {
WSS_URL, r.connection_token, token WSS_URL, r.connection_token, token
); );
let resp = tungstenite::client::connect(&wss_url); let resp = tokio_tungstenite::connect_async(wss_url).await;
//let resp = tungstenite::client::connect(&wss_url);
if let Err(tungstenite::Error::Http(he)) = &resp { if let Err(tungstenite::Error::Http(he)) = &resp {
eprintln!( eprintln!(
@ -62,19 +68,20 @@ impl Stream {
} }
let mut stream = Stream { sock: resp?.0 }; let mut stream = Stream { sock: resp?.0 };
stream.send(json!({ "protocol": "json", "version": 1 }))?; stream.send(json!({ "protocol": "json", "version": 1 })).await?;
Ok(stream) Ok(stream)
} }
pub fn send<T: Serialize>(&mut self, msg: T) -> Result<(), tungstenite::Error> { pub async fn send<T: Serialize>(&mut self, msg: T) -> Result<(), tungstenite::Error> {
let mut msg = serde_json::to_string(&msg).unwrap(); let mut msg = serde_json::to_string(&msg).unwrap();
msg.push('\x1E'); msg.push('\x1E');
self.sock.send(Message::Text(msg)) self.sock.send(Message::Text(msg)).await
} }
pub fn recv(&mut self) -> Result<Vec<serde_json::Value>, RecvError> { pub async fn recv(&mut self) -> Result<Vec<serde_json::Value>, RecvError> {
let msg = self.sock.read()?; let msg = self.sock.next().await
.ok_or(RecvError::EndOfStream)??;
let Message::Text(txt) = msg else { let Message::Text(txt) = msg else {
return Err(RecvError::BadMessageType); return Err(RecvError::BadMessageType);
}; };