use std::collections::HashMap; use std::convert::Infallible; use std::fs; use std::net::{Ipv4Addr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; use hyper::service::{make_service_fn, service_fn}; use hyper::{header, Body, Request, Response, Server, StatusCode}; use miette::IntoDiagnostic; use structopt::StructOpt; mod utils; use utils::HttpHandler; #[derive(Debug, StructOpt)] struct Opt { #[structopt(long, parse(from_os_str), default_value = "narchttpd.kdl")] config_file: PathBuf, } #[derive(Debug, knuffel::Decode)] struct HttpConfig { #[knuffel(argument)] enabled: bool, #[knuffel(property(name = "port"))] port: u16, } #[derive(Debug, knuffel::Decode)] struct StaticDomain { #[knuffel(child, unwrap(argument))] root: PathBuf, } impl StaticDomain { fn handler(self) -> Box { let Self { root } = self; Box::new(utils::serve_static::Params::new(root)) } } #[derive(Debug, knuffel::Decode)] struct ProxyChildDomain { #[knuffel(child, unwrap(argument))] command: String, #[knuffel(child, unwrap(argument))] in_dir: Option, #[knuffel(child, unwrap(argument))] port: u16, } impl ProxyChildDomain { fn handler(self) -> Box { let Self { command, in_dir, port, } = self; Box::new(utils::proxy_child::ProxyChild::new(command, in_dir, port)) } } #[derive(Debug, knuffel::Decode)] enum DomainType { Static(StaticDomain), ProxyChild(ProxyChildDomain), } impl DomainType { fn handler(self) -> Box { match self { Self::Static(domain) => domain.handler(), Self::ProxyChild(domain) => domain.handler(), } } } #[derive(Debug, knuffel::Decode)] struct DomainConfig { #[knuffel(arguments)] domains: Vec, #[knuffel(children)] config: Vec, } #[derive(Debug, knuffel::Decode)] struct Config { #[knuffel(child)] http: HttpConfig, #[knuffel(children(name = "domain"))] domains: Vec, } #[tokio::main] async fn main() -> miette::Result<()> { let opt: Opt = Opt::from_args(); let config_data = fs::read_to_string(&opt.config_file).into_diagnostic()?; let Config { http, domains } = knuffel::parse(&opt.config_file.to_string_lossy(), &config_data)?; let http_port = http.port; let domains: HashMap<_, _> = domains .into_iter() .map(|DomainConfig { domains, config }| { let [config]: [DomainType; 1] = config.try_into().unwrap(); (domains, Arc::new(config.handler())) }) .flat_map(|(domains, handler)| { domains .into_iter() .map(move |domain| (domain, Arc::clone(&handler))) }) .collect(); let domains = Arc::new(domains); // TODO learn hyper let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), http_port); let make_svc = make_service_fn(move |_conn| { let domains = domains.clone(); async move { let domains = domains.clone(); Ok::<_, Infallible>(service_fn(move |req: Request| { let domains = domains.clone(); async move { let domains = domains.clone(); let req_domain = req.headers().get(header::HOST).unwrap(); let req_domain = req_domain.to_str().unwrap(); match domains.get(req_domain) { Some(handler) => { eprintln!("request {:?} matched domain {}", req, req_domain); Ok::<_, Infallible>(handler.handle(req).await) } None => Ok::<_, Infallible>( Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("Not Found")) .unwrap(), ), } } })) } }); let server = Server::bind(&addr).serve(make_svc); server.await.into_diagnostic() }