aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/de.rs122
-rw-r--r--src/spanned.rs2
-rw-r--r--test-suite/tests/spanned.rs73
3 files changed, 161 insertions, 36 deletions
diff --git a/src/de.rs b/src/de.rs
index 439a48a..9bb5204 100644
--- a/src/de.rs
+++ b/src/de.rs
@@ -22,7 +22,7 @@ use crate::spanned;
use crate::tokens::{Error as TokenError, Span, Token, Tokenizer};
/// Type Alias for a TOML Table pair
-type TablePair<'a> = (Cow<'a, str>, Value<'a>);
+type TablePair<'a> = ((Span, Cow<'a, str>), Value<'a>);
/// Deserializes a byte slice into a type.
///
@@ -318,9 +318,16 @@ impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> {
}
}
+fn headers_equal<'a, 'b>(hdr_a: &[(Span, Cow<'a, str>)], hdr_b: &[(Span, Cow<'b, str>)]) -> bool {
+ if hdr_a.len() != hdr_b.len() {
+ return false;
+ }
+ hdr_a.iter().zip(hdr_b.iter()).all(|(h1, h2)| h1.1 == h2.1)
+}
+
struct Table<'a> {
at: usize,
- header: Vec<Cow<'a, str>>,
+ header: Vec<(Span, Cow<'a, str>)>,
values: Option<Vec<TablePair<'a>>>,
array: bool,
}
@@ -351,7 +358,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
loop {
assert!(self.next_value.is_none());
if let Some((key, value)) = self.values.next() {
- let ret = seed.deserialize(StrDeserializer::new(key.clone()))?;
+ let ret = seed.deserialize(StrDeserializer::spanned(key.clone()))?;
self.next_value = Some((key, value));
return Ok(Some(ret));
}
@@ -366,7 +373,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
return false;
}
match t.header.get(..self.depth) {
- Some(header) => header == prefix,
+ Some(header) => headers_equal(&header, &prefix),
None => false,
}
})
@@ -382,9 +389,17 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
// Test to see if we're duplicating our parent's table, and if so
// then this is an error in the toml format
if self.cur_parent != pos {
- if self.tables[self.cur_parent].header == self.tables[pos].header {
+ if headers_equal(
+ &self.tables[self.cur_parent].header,
+ &self.tables[pos].header,
+ ) {
let at = self.tables[pos].at;
- let name = self.tables[pos].header.join(".");
+ let name = self.tables[pos]
+ .header
+ .iter()
+ .map(|k| k.1.to_owned())
+ .collect::<Vec<_>>()
+ .join(".");
return Err(self.de.error(at, ErrorKind::DuplicateTable(name)));
}
@@ -408,7 +423,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
// decoding.
if self.depth != table.header.len() {
let key = &table.header[self.depth];
- let key = seed.deserialize(StrDeserializer::new(key.clone()))?;
+ let key = seed.deserialize(StrDeserializer::spanned(key.clone()))?;
return Ok(Some(key));
}
@@ -437,7 +452,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
match seed.deserialize(ValueDeserializer::new(v)) {
Ok(v) => return Ok(v),
Err(mut e) => {
- e.add_key_context(&k);
+ e.add_key_context(&k.1);
return Err(e);
}
}
@@ -458,7 +473,7 @@ impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
de: &mut *self.de,
});
res.map_err(|mut e| {
- e.add_key_context(&self.tables[self.cur - 1].header[self.depth]);
+ e.add_key_context(&self.tables[self.cur - 1].header[self.depth].1);
e
})
}
@@ -482,7 +497,10 @@ impl<'de, 'b> de::SeqAccess<'de> for MapVisitor<'de, 'b> {
.iter()
.enumerate()
.skip(self.cur_parent + 1)
- .find(|&(_, table)| table.array && table.header == self.tables[self.cur_parent].header)
+ .find(|&(_, table)| {
+ let tables_eq = headers_equal(&table.header, &self.tables[self.cur_parent].header);
+ table.array && tables_eq
+ })
.map(|p| p.0)
.unwrap_or(self.max);
@@ -560,9 +578,9 @@ impl<'de, 'b> de::Deserializer<'de> for MapVisitor<'de, 'b> {
if table.header.len() == 0 {
return Err(self.de.error(self.cur, ErrorKind::EmptyTableKey));
}
- let name = table.header[table.header.len() - 1].to_owned();
+ let name = table.header[table.header.len() - 1].1.to_owned();
visitor.visit_enum(DottedTableDeserializer {
- name: name,
+ name,
value: Value {
e: E::DottedTable(values),
start: 0,
@@ -579,12 +597,27 @@ impl<'de, 'b> de::Deserializer<'de> for MapVisitor<'de, 'b> {
}
struct StrDeserializer<'a> {
+ span: Option<Span>,
key: Cow<'a, str>,
}
impl<'a> StrDeserializer<'a> {
+ fn spanned(inner: (Span, Cow<'a, str>)) -> StrDeserializer<'a> {
+ StrDeserializer {
+ span: Some(inner.0),
+ key: inner.1,
+ }
+ }
fn new(key: Cow<'a, str>) -> StrDeserializer<'a> {
- StrDeserializer { key }
+ StrDeserializer { span: None, key }
+ }
+}
+
+impl<'a, 'b> de::IntoDeserializer<'a, Error> for StrDeserializer<'a> {
+ type Deserializer = Self;
+
+ fn into_deserializer(self) -> Self::Deserializer {
+ self
}
}
@@ -601,9 +634,31 @@ impl<'de> de::Deserializer<'de> for StrDeserializer<'de> {
}
}
+ fn deserialize_struct<V>(
+ self,
+ name: &'static str,
+ fields: &'static [&'static str],
+ visitor: V,
+ ) -> Result<V::Value, Error>
+ where
+ V: de::Visitor<'de>,
+ {
+ if name == spanned::NAME && fields == [spanned::START, spanned::END, spanned::VALUE] {
+ if let Some(span) = self.span {
+ return visitor.visit_map(SpannedDeserializer {
+ phantom_data: PhantomData,
+ start: Some(span.start),
+ value: Some(StrDeserializer::new(self.key)),
+ end: Some(span.end),
+ });
+ }
+ }
+ self.deserialize_any(visitor)
+ }
+
serde::forward_to_deserialize_any! {
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq
- bytes byte_buf map struct option unit newtype_struct
+ bytes byte_buf map option unit newtype_struct
ignored_any unit_struct tuple_struct tuple enum identifier
}
}
@@ -690,13 +745,13 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer<'de> {
.iter()
.filter_map(|key_value| {
let (ref key, ref _val) = *key_value;
- if !fields.contains(&&(**key)) {
+ if !fields.contains(&&*(key.1)) {
Some(key.clone())
} else {
None
}
})
- .collect::<Vec<Cow<'de, str>>>();
+ .collect::<Vec<_>>();
if !extra_fields.is_empty() {
return Err(Error::from_kind(
@@ -704,7 +759,7 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer<'de> {
ErrorKind::UnexpectedKeys {
keys: extra_fields
.iter()
- .map(|k| k.to_string())
+ .map(|k| k.1.to_string())
.collect::<Vec<_>>(),
available: fields,
},
@@ -943,7 +998,7 @@ impl<'de> de::MapAccess<'de> for InlineTableDeserializer<'de> {
None => return Ok(None),
};
self.next_value = Some(value);
- seed.deserialize(StrDeserializer::new(key)).map(Some)
+ seed.deserialize(StrDeserializer::spanned(key)).map(Some)
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Error>
@@ -976,7 +1031,7 @@ impl<'de> de::EnumAccess<'de> for InlineTableDeserializer<'de> {
}
};
- seed.deserialize(StrDeserializer::new(key))
+ seed.deserialize(StrDeserializer::new(key.1))
.map(|val| (val, TableEnumDeserializer { value }))
}
}
@@ -1027,13 +1082,13 @@ impl<'de> de::VariantAccess<'de> for TableEnumDeserializer<'de> {
let tuple_values = values
.into_iter()
.enumerate()
- .map(|(index, (key, value))| match key.parse::<usize>() {
+ .map(|(index, (key, value))| match key.1.parse::<usize>() {
Ok(key_index) if key_index == index => Ok(value),
Ok(_) | Err(_) => Err(Error::from_kind(
- Some(value.start),
+ Some(key.0.start),
ErrorKind::ExpectedTupleIndex {
expected: index,
- found: key.to_string(),
+ found: key.1.to_string(),
},
)),
})
@@ -1350,14 +1405,14 @@ impl<'a> Deserializer<'a> {
.as_ref()
.and_then(|values| values.last())
.map(|&(_, ref val)| val.end)
- .unwrap_or_else(|| header.len());
+ .unwrap_or_else(|| header.1.len());
Ok((
Value {
e: E::DottedTable(table.values.unwrap_or_else(Vec::new)),
start,
end,
},
- Some(header.clone()),
+ Some(header.1.clone()),
))
}
Some(_) => self.value().map(|val| (val, None)),
@@ -1672,14 +1727,11 @@ impl<'a> Deserializer<'a> {
Ok((span, ret))
}
- fn table_key(&mut self) -> Result<Cow<'a, str>, Error> {
- self.tokens
- .table_key()
- .map(|t| t.1)
- .map_err(|e| self.token_error(e))
+ fn table_key(&mut self) -> Result<(Span, Cow<'a, str>), Error> {
+ self.tokens.table_key().map_err(|e| self.token_error(e))
}
- fn dotted_key(&mut self) -> Result<Vec<Cow<'a, str>>, Error> {
+ fn dotted_key(&mut self) -> Result<Vec<(Span, Cow<'a, str>)>, Error> {
let mut result = Vec::new();
result.push(self.table_key()?);
self.eat_whitespace()?;
@@ -1705,7 +1757,7 @@ impl<'a> Deserializer<'a> {
/// * `values`: The `Vec` to store the value in.
fn add_dotted_key(
&self,
- mut key_parts: Vec<Cow<'a, str>>,
+ mut key_parts: Vec<(Span, Cow<'a, str>)>,
value: Value<'a>,
values: &mut Vec<TablePair<'a>>,
) -> Result<(), Error> {
@@ -1714,7 +1766,7 @@ impl<'a> Deserializer<'a> {
values.push((key, value));
return Ok(());
}
- match values.iter_mut().find(|&&mut (ref k, _)| *k == key) {
+ match values.iter_mut().find(|&&mut (ref k, _)| *k.1 == key.1) {
Some(&mut (
_,
Value {
@@ -2038,7 +2090,7 @@ enum Line<'a> {
header: Header<'a>,
array: bool,
},
- KeyValue(Vec<Cow<'a, str>>, Value<'a>),
+ KeyValue(Vec<(Span, Cow<'a, str>)>, Value<'a>),
}
struct Header<'a> {
@@ -2058,13 +2110,13 @@ impl<'a> Header<'a> {
}
}
- fn next(&mut self) -> Result<Option<Cow<'a, str>>, TokenError> {
+ fn next(&mut self) -> Result<Option<(Span, Cow<'a, str>)>, TokenError> {
self.tokens.eat_whitespace()?;
if self.first || self.tokens.eat(Token::Period)? {
self.first = false;
self.tokens.eat_whitespace()?;
- self.tokens.table_key().map(|t| t.1).map(Some)
+ self.tokens.table_key().map(|t| t).map(Some)
} else {
self.tokens.expect(Token::RightBracket)?;
if self.array {
diff --git a/src/spanned.rs b/src/spanned.rs
index 1538f96..3318d28 100644
--- a/src/spanned.rs
+++ b/src/spanned.rs
@@ -28,7 +28,7 @@ pub(crate) const VALUE: &str = "$__toml_private_value";
/// assert_eq!(u.s.into_inner(), String::from("value"));
/// }
/// ```
-#[derive(Clone, Debug, PartialEq)]
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Spanned<T> {
/// The start range.
start: usize,
diff --git a/test-suite/tests/spanned.rs b/test-suite/tests/spanned.rs
index 5130a72..1186645 100644
--- a/test-suite/tests/spanned.rs
+++ b/test-suite/tests/spanned.rs
@@ -85,3 +85,76 @@ fn test_spanned_field() {
// ending at something other than the absolute end
good::<u32>("foo = 42\nnoise = true", "42", Some(8));
}
+
+#[test]
+fn test_spanned_table() {
+ #[derive(Deserialize)]
+ struct Foo {
+ foo: HashMap<Spanned<String>, Spanned<String>>,
+ }
+
+ fn good(s: &str) {
+ let foo: Foo = toml::from_str(s).unwrap();
+
+ for (k, v) in foo.foo.iter() {
+ assert_eq!(&s[k.start()..k.end()], k.get_ref());
+ assert_eq!(&s[(v.start() + 1)..(v.end() - 1)], v.get_ref());
+ }
+ }
+
+ good(
+ "
+ [foo]
+ a = 'b'
+ bar = 'baz'
+ c = 'd'
+ e = \"f\"
+ ",
+ );
+
+ good(
+ "
+ foo = { a = 'b', bar = 'baz', c = 'd', e = \"f\" }
+ ",
+ );
+}
+
+#[test]
+fn test_spanned_nested() {
+ #[derive(Deserialize)]
+ struct Foo {
+ foo: HashMap<Spanned<String>, HashMap<Spanned<String>, Spanned<String>>>,
+ }
+
+ fn good(s: &str) {
+ let foo: Foo = toml::from_str(s).unwrap();
+
+ for (k, v) in foo.foo.iter() {
+ assert_eq!(&s[k.start()..k.end()], k.get_ref());
+ for (n_k, n_v) in v.iter() {
+ assert_eq!(&s[n_k.start()..n_k.end()], n_k.get_ref());
+ assert_eq!(&s[(n_v.start() + 1)..(n_v.end() - 1)], n_v.get_ref());
+ }
+ }
+ }
+
+ good(
+ "
+ [foo.a]
+ a = 'b'
+ c = 'd'
+ e = \"f\"
+ [foo.bar]
+ baz = 'true'
+ ",
+ );
+
+ good(
+ "
+ [foo]
+ foo = { a = 'b', bar = 'baz', c = 'd', e = \"f\" }
+ bazz = {}
+ g = { h = 'i' }
+ ",
+ );
+}