aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/decoder/serde.rs88
1 files changed, 88 insertions, 0 deletions
diff --git a/src/decoder/serde.rs b/src/decoder/serde.rs
index 01806d4..2f69eb9 100644
--- a/src/decoder/serde.rs
+++ b/src/decoder/serde.rs
@@ -89,6 +89,94 @@ impl de::Deserializer for Decoder {
self.visit(visitor)
}
}
+
+ fn visit_enum<V>(&mut self,
+ _enum: &str,
+ variants: &[&str],
+ mut visitor: V) -> Result<V::Value, DecodeError>
+ where V: de::EnumVisitor,
+ {
+ // When decoding enums, this crate takes the strategy of trying to
+ // decode the current TOML as all of the possible variants, returning
+ // success on the first one that succeeds.
+ //
+ // Note that fidelity of the errors returned here is a little nebulous,
+ // but we try to return the error that had the relevant field as the
+ // longest field. This way we hopefully match an error against what was
+ // most likely being written down without losing too much info.
+ let mut first_error = None::<DecodeError>;
+
+ for variant in 0 .. variants.len() {
+ let mut de = VariantVisitor {
+ de: self.sub_decoder(self.toml.clone(), ""),
+ variant: variant,
+ };
+
+ match visitor.visit(&mut de) {
+ Ok(value) => {
+ self.toml = de.de.toml;
+ return Ok(value);
+ }
+ Err(e) => {
+ if let Some(ref first) = first_error {
+ let my_len = e.field.as_ref().map(|s| s.len());
+ let first_len = first.field.as_ref().map(|s| s.len());
+ if my_len <= first_len {
+ continue
+ }
+ }
+ first_error = Some(e);
+ }
+ }
+ }
+
+ Err(first_error.unwrap_or_else(|| self.err(DecodeErrorKind::NoEnumVariants)))
+ }
+}
+
+struct VariantVisitor {
+ de: Decoder,
+ variant: usize,
+}
+
+impl de::VariantVisitor for VariantVisitor {
+ type Error = DecodeError;
+
+ fn visit_variant<V>(&mut self) -> Result<V, DecodeError>
+ where V: de::Deserialize
+ {
+ use serde::de::value::ValueDeserializer;
+
+ let mut de = self.variant.into_deserializer();
+
+ de::Deserialize::deserialize(&mut de).map_err(|e| se2toml(e, "variant"))
+ }
+
+ fn visit_unit(&mut self) -> Result<(), DecodeError> {
+ de::Deserialize::deserialize(&mut self.de)
+ }
+
+ fn visit_newtype<T>(&mut self) -> Result<T, DecodeError>
+ where T: de::Deserialize,
+ {
+ de::Deserialize::deserialize(&mut self.de)
+ }
+
+ fn visit_tuple<V>(&mut self,
+ _len: usize,
+ visitor: V) -> Result<V::Value, DecodeError>
+ where V: de::Visitor,
+ {
+ de::Deserializer::visit(&mut self.de, visitor)
+ }
+
+ fn visit_struct<V>(&mut self,
+ _fields: &'static [&'static str],
+ visitor: V) -> Result<V::Value, DecodeError>
+ where V: de::Visitor,
+ {
+ de::Deserializer::visit(&mut self.de, visitor)
+ }
}
struct SeqDeserializer<'a, I> {