// tauri_plugin_ota_updater/lib.rs
use std::{
    borrow::Cow,
    cell::OnceCell,
    collections::HashMap,
    io::{Cursor, Read},
    path::{Path, PathBuf},
    sync::{Arc, Mutex},
};

use base64::Engine;
use serde::{Deserialize, Deserializer, Serialize};
use tar::Archive;
use tauri::{
    plugin::{Builder, TauriPlugin},
    utils::assets::{AssetKey, CspHash},
    App, AppHandle, Assets, Context, Manager, Runtime, State, Url,
};
mod error;
pub use error::{Error, Result};
/// Channel requested when no explicit channel is configured.
const DEFAULT_CHANNEL: &str = "over-the-air";
/// Prefix stripped from channel names read from the config (`over-the-air-beta` -> `beta`).
const CHANNEL_PREFIX: &str = "over-the-air-";
/// Managed state holding the update found by the last successful check, waiting to be
/// applied by the `apply_update` command.
struct PendingUpdate<R: Runtime>(tauri::async_runtime::Mutex<Option<Update<R>>>);
/// Core updater state.
///
/// Every field is either cheap to clone or shared behind an `Arc`, so clones handed to
/// pending [`Update`]s observe and modify the same state as the managed instance.
pub struct OTAUpdater<R: Runtime> {
    /// Directory holding the persisted manifest and the last applied update archive.
    cache_path: PathBuf,
    /// Plugin configuration; the channel can be changed at runtime by `set_channel`.
    config: Arc<tauri::async_runtime::Mutex<Config>>,
    /// Manifest of the currently applied update (id `"0"` when none has been applied).
    manifest: Arc<tauri::async_runtime::Mutex<Manifest>>,
    /// Verified OTA assets keyed by asset path; the same map is read by [`OTAAssets`].
    assets: Arc<Mutex<HashMap<AssetKey, Vec<u8>>>>,
    /// The app's original embedded assets, until `OTAAssets::setup` takes them out.
    embedded_assets: Arc<Mutex<Option<Box<dyn Assets<R>>>>>,
}

// Written by hand because `#[derive(Clone)]` would add an unnecessary `R: Clone` bound.
impl<R: Runtime> Clone for OTAUpdater<R> {
    fn clone(&self) -> Self {
        let Self {
            cache_path,
            config,
            manifest,
            assets,
            embedded_assets,
        } = self;
        Self {
            cache_path: cache_path.clone(),
            config: Arc::clone(config),
            manifest: Arc::clone(manifest),
            assets: Arc::clone(assets),
            embedded_assets: Arc::clone(embedded_assets),
        }
    }
}
/// Location of the persisted manifest of the currently applied update.
fn manifest_path(cache_path: &Path) -> PathBuf {
    const MANIFEST_FILE_NAME: &str = "ota-updates-manifest.json";
    cache_path.join(MANIFEST_FILE_NAME)
}
/// Location of the persisted archive of the most recently applied update.
fn latest_update_archive_path(cache_path: &Path) -> PathBuf {
    const ARCHIVE_FILE_NAME: &str = "latest-update.tar.gz";
    let mut archive_path = cache_path.to_path_buf();
    archive_path.push(ARCHIVE_FILE_NAME);
    archive_path
}
/// Body returned by the update endpoint (`Config::update_url_for(.., "manifest")`).
///
/// Field order is kept as-is: serde's derived impl also uses it when the payload is a
/// sequence rather than a map.
#[derive(Deserialize)]
struct UpdateResponse {
    /// Version string reported for the update.
    version: String,
    /// Release notes reported for the update.
    notes: String,
    /// Publication date as sent by the server (kept as an unparsed string).
    pub_date: String,
    /// URL from which the update's `Manifest` JSON is downloaded.
    url: Url,
}
impl<R: Runtime> OTAUpdater<R> {
    /// Checks the CDN for an update whose manifest id differs from the applied one.
    ///
    /// When `Some` is returned, the update has already been downloaded and verified (the
    /// archive signature and every file's signature) but is not active yet; call
    /// [`Update::apply`] to activate it.
    ///
    /// # Errors
    ///
    /// Fails if a request errors or returns an error status, if a response cannot be
    /// parsed, if the configured public key or a signature is malformed, or if the archive
    /// or one of its files does not verify.
    pub async fn check_for_updates(&self) -> Result<Option<Update<R>>> {
        // Snapshot the config once so every request of this check uses the same channel,
        // and so the config lock is not held across network calls (that would block
        // `set_channel` for the duration of a download).
        let config = self.config.lock().await.clone();
        let current_manifest_id = self.manifest.lock().await.id.clone();

        let update = reqwest::get(config.update_url_for(&current_manifest_id, "manifest"))
            .await?
            .error_for_status()?
            .json::<UpdateResponse>()
            .await?;
        let manifest = reqwest::get(update.url)
            .await?
            .error_for_status()?
            .json::<Manifest>()
            .await?;
        if manifest.id == current_manifest_id {
            return Ok(None);
        }

        let latest_dist = reqwest::get(config.latest_url_for_filename("ota-dist.tar.gz"))
            .await?
            .error_for_status()?
            .bytes()
            .await?
            .to_vec();

        // Verify the archive as a whole before unpacking anything from it.
        let pub_key_decoded = base64_to_string(&config.pubkey)?;
        let public_key = minisign_verify::PublicKey::decode(&pub_key_decoded)
            .map_err(Error::InvalidPublicKey)?;
        let dist_signature_base64_decoded = base64_to_string(&manifest.archive_signature)?;
        let dist_signature = minisign_verify::Signature::decode(&dist_signature_base64_decoded)
            .map_err(Error::InvalidSignature)?;
        public_key
            .verify(&latest_dist, &dist_signature, false)
            .map_err(Error::InvalidSignature)?;

        // Additionally verifies every file against its per-file signature in the manifest.
        let archive = tar::Archive::new(Cursor::new(&latest_dist));
        let update_assets =
            load_assets(archive, &manifest, &public_key, &self.cache_path).await?;

        Ok(Some(Update {
            version: update.version,
            notes: update.notes,
            pub_date: update.pub_date,
            archive: latest_dist,
            manifest,
            assets: update_assets,
            updater: self.clone(),
        }))
    }
}
pub trait OTAUpdaterExt<R: Runtime> {
fn ota_updater(&self) -> &OTAUpdater<R>;
}
impl<R: Runtime, T: Manager<R>> crate::OTAUpdaterExt<R> for T {
fn ota_updater(&self) -> &OTAUpdater<R> {
self.state::<OTAUpdater<R>>().inner()
}
}
/// An update that has been downloaded and verified but not applied yet.
pub struct Update<R: Runtime> {
    /// Version string reported by the update server.
    pub version: String,
    /// Release notes reported by the update server.
    pub notes: String,
    /// Publication date reported by the update server.
    pub pub_date: String,
    /// Handle to the shared updater state that `apply` modifies.
    updater: OTAUpdater<R>,
    /// Raw archive bytes as downloaded; persisted so assets can be reloaded at startup.
    archive: Vec<u8>,
    /// Manifest describing `archive`.
    manifest: Manifest,
    /// Verified assets extracted from `archive`.
    assets: HashMap<AssetKey, Vec<u8>>,
}
impl<R: Runtime> Update<R> {
    /// Activates this update.
    ///
    /// Persists the archive and the manifest to the cache directory, so the update is
    /// reloaded on the next launch. Then it swaps the shared manifest and assets, so later
    /// asset requests are served from the update.
    ///
    /// # Errors
    ///
    /// Fails if the manifest cannot be serialized or either file cannot be written. The
    /// in-memory state is left unchanged in that case.
    pub async fn apply(&self) -> Result<()> {
        let cache_path = &self.updater.cache_path;
        // Serialize before touching the disk so a serialization error writes nothing.
        let manifest_json = serde_json::to_string(&self.manifest)?;
        // Write the archive before the manifest. The persisted manifest id is what update
        // checks compare against, so it must never be written unless its archive is. If
        // the manifest write fails, the old id stays on disk and the update is offered again.
        std::fs::write(latest_update_archive_path(cache_path), &self.archive)?;
        std::fs::write(manifest_path(cache_path), manifest_json)?;
        // Only switch in-memory state once everything has been persisted.
        *self.updater.manifest.lock().await = self.manifest.clone();
        *self.updater.assets.lock().unwrap() = self.assets.clone();
        Ok(())
    }
}
/// Signed description of an OTA update.
///
/// It is fetched from the URL given in `UpdateResponse`, and persisted via `manifest_path`
/// once applied. Field order is kept as-is because it determines the derived `Serialize`
/// output.
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Manifest {
    /// Identifier compared with the applied manifest's id to detect a new update.
    id: String,
    /// Base64-encoded minisign signature of the whole update archive.
    archive_signature: String,
    /// Per-file signatures, keyed by path relative to the archive's `dist` directory.
    files: HashMap<PathBuf, ManifestFile>,
}
/// Manifest entry for a single file of the update archive.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ManifestFile {
    /// Base64-encoded minisign signature of the file's contents.
    signature: String,
}
/// [`Assets`] provider that serves verified OTA assets.
///
/// Anything the applied update does not contain falls back to the app's original embedded
/// assets.
pub struct OTAAssets<R: Runtime> {
    /// The app's original assets, moved in from the [`OTAUpdater`] during `setup`.
    ///
    /// `OnceLock` (unlike `OnceCell`) is `Sync`, so no `unsafe impl Sync` is needed. It also
    /// makes concurrent `get` calls racing with `setup`'s initialization sound.
    embedded_assets: std::sync::OnceLock<Box<dyn Assets<R>>>,
    /// Verified OTA assets; shared with the [`OTAUpdater`] so applying an update swaps them.
    assets: Arc<Mutex<HashMap<AssetKey, Vec<u8>>>>,
    /// CSP hashes reported for every HTML asset in addition to the embedded ones.
    csp_hashes: Vec<CspHash<'static>>,
}
impl<R: Runtime> Assets<R> for OTAAssets<R> {
    /// Takes the embedded assets from the managed [`OTAUpdater`] and sets them up, then
    /// loads and verifies the persisted update archive, if there is one.
    ///
    /// A missing archive, or one that fails to load or verify, is not fatal. The app then
    /// runs on its embedded assets only (the error is printed in debug builds).
    fn setup(&self, app: &App<R>) {
        let ota = app.state::<OTAUpdater<R>>();
        self.embedded_assets.get_or_init(|| {
            let assets = ota.embedded_assets.lock().unwrap().take().unwrap();
            assets.setup(app);
            assets
        });
        let loaded = match std::fs::read(latest_update_archive_path(&ota.cache_path)) {
            // No persisted archive (e.g. no update applied yet).
            Err(_) => HashMap::new(),
            Ok(archive_bytes) => {
                let load_assets_fut = async {
                    let pub_key_decoded = base64_to_string(&ota.config.lock().await.pubkey)?;
                    let public_key = minisign_verify::PublicKey::decode(&pub_key_decoded)
                        .map_err(Error::InvalidPublicKey)?;
                    load_assets(
                        tar::Archive::new(Cursor::new(archive_bytes)),
                        &*ota.manifest.lock().await,
                        &public_key,
                        &ota.cache_path,
                    )
                    .await
                };
                tauri::async_runtime::block_on(load_assets_fut).unwrap_or_else(|_e| {
                    #[cfg(debug_assertions)]
                    eprintln!("failed to load assets: {_e}");
                    HashMap::new()
                })
            }
        };
        *self.assets.lock().unwrap() = loaded;
    }

    /// Reports this provider's own CSP hashes followed by the embedded assets' hashes for
    /// `html_path`. Dropping the embedded hashes would block the inline scripts they allow.
    fn csp_hashes(&self, html_path: &AssetKey) -> Box<dyn Iterator<Item = CspHash<'_>> + '_> {
        // Resolve the embedded iterator eagerly: it only borrows `self`, not `html_path`.
        let embedded: Box<dyn Iterator<Item = CspHash<'_>> + '_> =
            match self.embedded_assets.get() {
                Some(embedded) => embedded.csp_hashes(html_path),
                None => Box::new(std::iter::empty()),
            };
        Box::new(self.csp_hashes.iter().copied().chain(embedded))
    }

    /// Returns the OTA version of `key` if present, otherwise the embedded asset.
    ///
    /// Returns `None` (instead of panicking) for non-OTA keys before `setup` has run.
    fn get(&self, key: &AssetKey) -> Option<Cow<'_, [u8]>> {
        // Copy out and release the OTA lock before consulting the embedded assets.
        let ota_asset = self.assets.lock().unwrap().get(key).cloned();
        match ota_asset {
            Some(bytes) => Some(Cow::Owned(bytes)),
            None => self
                .embedded_assets
                .get()
                .and_then(|embedded| embedded.get(key)),
        }
    }

    /// Iterates every servable asset, consistent with `get`.
    ///
    /// Yields all OTA assets, then the embedded assets that no OTA asset overrides.
    fn iter(&self) -> Box<dyn Iterator<Item = (Cow<'_, str>, Cow<'_, [u8]>)> + '_> {
        let ota_assets = self.assets.lock().unwrap().clone();
        let ota_keys: std::collections::HashSet<String> = ota_assets
            .keys()
            .map(|k| k.as_ref().to_string())
            .collect();
        let ota_iter = ota_assets
            .into_iter()
            .map(|(k, v)| (Cow::Owned(k.as_ref().to_string()), Cow::Owned(v)));
        let embedded_iter = self
            .embedded_assets
            .get()
            .into_iter()
            .flat_map(|embedded| embedded.iter())
            .filter(move |(key, _)| !ota_keys.contains(&**key));
        Box::new(ota_iter.chain(embedded_iter))
    }
}
/// Plugin configuration with camelCase keys.
///
/// It is read through the plugin API during setup; presumably it comes from the
/// `ota-updater` plugin section of the Tauri config (TODO confirm).
#[derive(Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Config {
    /// CDN host used for update URLs; `cdn.crabnebula.app` when unset.
    pub cdn_host: Option<String>,
    /// Organization slug, used as a path segment of update URLs.
    pub org_slug: String,
    /// Application slug, used as a path segment of update URLs.
    pub app_slug: String,
    /// Base64-encoded minisign public key that update signatures are verified against.
    pub pubkey: String,
    /// Release channel; `None` selects the default channel.
    ///
    /// When read from config, `over-the-air` becomes `None` and an `over-the-air-` prefix
    /// is stripped (see `channel_deserializer`).
    #[serde(deserialize_with = "channel_deserializer", default)]
    pub channel: Option<String>,
}
fn channel_deserializer<'de, D>(deserializer: D) -> std::result::Result<Option<String>, D::Error>
where
D: Deserializer<'de>,
{
let s = Option::<String>::deserialize(deserializer)?;
Ok(match s.as_deref() {
Some(DEFAULT_CHANNEL) => None,
Some(c) => c
.strip_prefix(CHANNEL_PREFIX)
.map(ToString::to_string)
.or(s),
None => None,
})
}
impl Config {
    /// Builds `https://<cdn host>/<segments...>?channel=<channel>`.
    ///
    /// Segments and the channel are percent-encoded by `Url`. A channel set at runtime
    /// (e.g. via `set_channel`) therefore cannot inject extra query parameters or path
    /// components.
    ///
    /// # Panics
    ///
    /// Panics if the configured `cdn_host` does not form a valid base URL.
    fn cdn_url(&self, segments: &[&str]) -> Url {
        let host = self.cdn_host.as_deref().unwrap_or("cdn.crabnebula.app");
        let mut url: Url = format!("https://{host}").parse().expect("invalid URL");
        url.path_segments_mut()
            .expect("https URLs always have a hierarchical path")
            .extend(segments);
        url.query_pairs_mut().append_pair(
            "channel",
            self.channel.as_deref().unwrap_or(DEFAULT_CHANNEL),
        );
        url
    }

    /// URL of the update endpoint:
    /// `/update/<org>/<app>/<update_platform>/<current_version>?channel=<channel>`.
    fn update_url_for(&self, current_version: &str, update_platform: &str) -> Url {
        self.cdn_url(&[
            "update",
            self.org_slug.as_str(),
            self.app_slug.as_str(),
            update_platform,
            current_version,
        ])
    }

    /// URL of the latest build of `filename`:
    /// `/download/<org>/<app>/latest/<filename>?channel=<channel>`.
    fn latest_url_for_filename(&self, filename: &str) -> Url {
        self.cdn_url(&[
            "download",
            self.org_slug.as_str(),
            self.app_slug.as_str(),
            "latest",
            filename,
        ])
    }
}
/// Command: switches the release channel used by subsequent update checks.
///
/// `None` selects the default channel. Unlike values read from the config, the name is
/// stored verbatim (no prefix stripping).
#[tauri::command]
async fn set_channel<R: Runtime>(
    _app: AppHandle<R>,
    ota: State<'_, OTAUpdater<R>>,
    channel: Option<String>,
) -> Result<()> {
    let mut config = ota.config.lock().await;
    config.channel = channel;
    Ok(())
}
/// Command: checks for an update and stores it as pending for `apply_update`.
///
/// Returns `true` when an update was found; it replaces any previously pending one.
#[tauri::command]
async fn check_for_updates<R: Runtime>(
    app: AppHandle<R>,
    ota: State<'_, OTAUpdater<R>>,
) -> Result<bool> {
    let Some(update) = ota.check_for_updates().await? else {
        return Ok(false);
    };
    match app.try_state::<PendingUpdate<R>>() {
        Some(pending) => {
            pending.0.lock().await.replace(update);
        }
        None => {
            app.manage(PendingUpdate(tauri::async_runtime::Mutex::new(Some(
                update,
            ))));
        }
    }
    Ok(true)
}
/// Command: applies the pending update, if any; does nothing when none is pending.
///
/// The pending update is cleared only after it applied successfully, so a failed apply
/// can be retried.
#[tauri::command]
async fn apply_update<R: Runtime>(app: AppHandle<R>) -> Result<()> {
    let Some(state) = app.try_state::<PendingUpdate<R>>() else {
        return Ok(());
    };
    let mut slot = state.0.lock().await;
    let Some(pending_update) = Option::as_ref(&*slot) else {
        return Ok(());
    };
    pending_update.apply().await?;
    *slot = None;
    Ok(())
}
pub fn init<R: Runtime>(mut context: Context<R>) -> (TauriPlugin<R, Config>, Context<R>) {
let assets = Arc::new(Mutex::new(HashMap::new()));
let embedded_assets = context.set_assets(Box::new(OTAAssets {
assets: assets.clone(),
csp_hashes: Default::default(),
embedded_assets: Default::default(),
}));
let plugin = Builder::<R, Config>::new("ota-updater")
.invoke_handler(tauri::generate_handler![
check_for_updates,
apply_update,
set_channel
])
.setup(|app, api| {
let cache_path = app.path().app_cache_dir()?;
std::fs::create_dir_all(&cache_path)?;
let current_manifest = std::fs::read_to_string(manifest_path(&cache_path))
.and_then(|json| {
serde_json::from_str::<Manifest>(&json)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))
})
.unwrap_or_else(|_| Manifest {
id: "0".to_string(),
archive_signature: "".to_string(),
files: HashMap::new(),
});
let ota_updater = OTAUpdater {
cache_path,
manifest: Arc::new(tauri::async_runtime::Mutex::new(current_manifest)),
config: Arc::new(tauri::async_runtime::Mutex::new(api.config().clone())),
assets,
embedded_assets: Arc::new(Mutex::new(Some(embedded_assets))),
};
let _ = tauri::async_runtime::block_on(ota_updater.check_for_updates());
app.manage(ota_updater);
Ok(())
})
.build();
(plugin, context)
}
/// Decodes `base64_string` with the standard base64 alphabet and returns the bytes as
/// UTF-8 text.
///
/// Used for the minisign public key and signatures, which are distributed base64-encoded.
///
/// # Errors
///
/// Fails if the input is not valid base64 or the decoded bytes are not valid UTF-8.
fn base64_to_string(base64_string: &str) -> Result<String> {
    let bytes = base64::engine::general_purpose::STANDARD.decode(base64_string)?;
    std::str::from_utf8(&bytes)
        .map(str::to_owned)
        .map_err(Into::into)
}
async fn load_assets<R: Read>(
mut archive: Archive<R>,
manifest: &Manifest,
public_key: &minisign_verify::PublicKey,
cache_path: &Path,
) -> Result<HashMap<AssetKey, Vec<u8>>> {
let mut assets = HashMap::new();
let archive_out_dir = tempfile::tempdir_in(cache_path)?;
archive.unpack(archive_out_dir.path())?;
let dist_dir = archive_out_dir.path().join("dist");
for entry in walkdir::WalkDir::new(&dist_dir) {
let entry = entry?;
let path = entry.path();
if entry.file_type().is_file() {
let relative_path = path.strip_prefix(&dist_dir).unwrap();
let manifest_file =
manifest
.files
.get(relative_path)
.ok_or(Error::FileNotInManifest {
path: relative_path.to_path_buf(),
})?;
let signature_base64_decoded = base64_to_string(&manifest_file.signature)?;
let signature = minisign_verify::Signature::decode(&signature_base64_decoded)
.map_err(Error::InvalidSignature)?;
let data = std::fs::read(path)?;
public_key
.verify(&data, &signature, false)
.map_err(Error::InvalidSignature)?;
assets.insert(relative_path.into(), data);
}
}
Ok(assets)
}