aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs178
1 files changed, 109 insertions, 69 deletions
diff --git a/src/main.rs b/src/main.rs
index f284b75..d99b672 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,13 +1,13 @@
use std::collections::HashMap;
use std::convert::Infallible;
-use std::fs::read_to_string;
-use std::net::SocketAddr;
+use std::fs;
+use std::net::{Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use std::sync::Arc;
use hyper::service::{make_service_fn, service_fn};
use hyper::{header, Body, Request, Response, Server, StatusCode};
-use serde::Deserialize;
+use miette::IntoDiagnostic;
use structopt::StructOpt;
mod utils;
@@ -15,99 +15,139 @@ use utils::HttpHandler;
#[derive(Debug, StructOpt)]
struct Opt {
- #[structopt(long, parse(from_os_str), default_value = "narchttpd.toml")]
+ #[structopt(long, parse(from_os_str), default_value = "narchttpd.kdl")]
config_file: PathBuf,
}
-#[derive(Deserialize)]
-#[serde(tag = "mode", rename_all = "kebab-case")]
-enum DomainConfig {
- Static {
- root: PathBuf,
- },
-
- ProxyChild {
- command: String,
- in_dir: Option<PathBuf>,
- port: u16,
- },
+#[derive(Debug, knuffel::Decode)]
+struct HttpConfig {
+ #[knuffel(argument)]
+ enabled: bool,
+ #[knuffel(property(name = "port"))]
+ port: u16,
}
-impl DomainConfig {
+#[derive(Debug, knuffel::Decode)]
+struct StaticDomain {
+ #[knuffel(child, unwrap(argument))]
+ root: PathBuf,
+}
+
+impl StaticDomain {
+ fn handler(self) -> Box<dyn HttpHandler> {
+ let Self { root } = self;
+ Box::new(utils::serve_static::Params::new(root))
+ }
+}
+
+#[derive(Debug, knuffel::Decode)]
+struct ProxyChildDomain {
+ #[knuffel(child, unwrap(argument))]
+ command: String,
+ #[knuffel(child, unwrap(argument))]
+ in_dir: Option<PathBuf>,
+ #[knuffel(child, unwrap(argument))]
+ port: u16,
+}
+
+impl ProxyChildDomain {
+ fn handler(self) -> Box<dyn HttpHandler> {
+ let Self {
+ command,
+ in_dir,
+ port,
+ } = self;
+ Box::new(utils::proxy_child::ProxyChild::new(command, in_dir, port))
+ }
+}
+
+#[derive(Debug, knuffel::Decode)]
+enum DomainType {
+ Static(StaticDomain),
+ ProxyChild(ProxyChildDomain),
+}
+
+impl DomainType {
fn handler(self) -> Box<dyn HttpHandler> {
match self {
- Self::Static { root } => Box::new(utils::serve_static::Params::new(root)),
- Self::ProxyChild {
- command,
- in_dir,
- port,
- } => Box::new(utils::proxy_child::ProxyChild::new(command, in_dir, port)),
+ Self::Static(domain) => domain.handler(),
+ Self::ProxyChild(domain) => domain.handler(),
}
}
}
-#[derive(Deserialize)]
+#[derive(Debug, knuffel::Decode)]
+struct DomainConfig {
+ #[knuffel(arguments)]
+ domains: Vec<String>,
+ #[knuffel(children)]
+ config: Vec<DomainType>,
+}
+
+#[derive(Debug, knuffel::Decode)]
struct Config {
- http_ports: Vec<u16>,
- https_ports: Vec<u16>,
- #[serde(flatten)]
- domains: HashMap<String, DomainConfig>,
+ #[knuffel(child)]
+ http: HttpConfig,
+ #[knuffel(children(name = "domain"))]
+ domains: Vec<DomainConfig>,
}
#[tokio::main]
-async fn main() {
- let opt = Opt::from_args();
+async fn main() -> miette::Result<()> {
+ let opt: Opt = Opt::from_args();
- let config_data = read_to_string(&opt.config_file).expect("Config file not found");
+ let config_data = fs::read_to_string(&opt.config_file).into_diagnostic()?;
- let Config {
- http_ports,
- https_ports,
- domains,
- } = toml::from_str(&config_data).expect("Config file not valid");
+ let Config { http, domains } =
+ knuffel::parse(&opt.config_file.to_string_lossy(), &config_data)?;
- assert!(https_ports.is_empty(), "HTTPS is complicated oops");
+ let http_port = http.port;
let domains: HashMap<_, _> = domains
.into_iter()
- .map(|(domain, config)| (domain, Arc::new(config.handler())))
+ .map(|DomainConfig { domains, config }| {
+ let [config]: [DomainType; 1] = config.try_into().unwrap();
+ (domains, Arc::new(config.handler()))
+ })
+ .flat_map(|(domains, handler)| {
+ domains
+ .into_iter()
+ .map(move |domain| (domain, Arc::clone(&handler)))
+ })
.collect();
- let do_response_domains = domains.clone();
+ let domains = Arc::new(domains);
// TODO learn hyper
- let do_response = move |req: Request<Body>| async move {
- for (domain, handler) in do_response_domains.clone() {
- let req_domain = req.headers().get(header::HOST).unwrap();
- if &domain == req_domain {
- eprintln!("request {:?} matched domain {}", req, domain);
- // matched!
- return handler.handle(req).await;
- }
+ let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), http_port);
+
+ let make_svc = make_service_fn(move |_conn| {
+ let domains = domains.clone();
+ async move {
+ let domains = domains.clone();
+ Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
+ let domains = domains.clone();
+ async move {
+ let domains = domains.clone();
+ let req_domain = req.headers().get(header::HOST).unwrap();
+ let req_domain = req_domain.to_str().unwrap();
+ match domains.get(req_domain) {
+ Some(handler) => {
+ eprintln!("request {:?} matched domain {}", req, req_domain);
+ Ok::<_, Infallible>(handler.handle(req).await)
+ }
+ None => Ok::<_, Infallible>(
+ Response::builder()
+ .status(StatusCode::NOT_FOUND)
+ .body(Body::from("Not Found"))
+ .unwrap(),
+ ),
+ }
+ }
+ }))
}
- Response::builder()
- .status(StatusCode::NOT_FOUND)
- .body(Body::from("Not Found"))
- .unwrap()
- };
- let addr = http_ports.into_iter().flat_map(|port| {
- ["127.0.0.1", "::1"]
- .iter()
- .map(move |ip| SocketAddr::new(ip.parse().unwrap(), port))
- });
-
- let make_svc = make_service_fn(move |_conn| async move {
- Ok::<_, Infallible>(service_fn(move |req| async move {
- let response = do_response(req).await;
- Ok::<_, Infallible>(response)
- }))
});
- // TODO uhh
- let mut addr = addr;
- let addr = addr.nth(0).unwrap();
let server = Server::bind(&addr).serve(make_svc);
- if let Err(e) = server.await {
- eprintln!("server error: {}", e);
- }
+ server.await.into_diagnostic()
}