From 2fcd829b1d9c70d0981411b4f4adca9124985b54 Mon Sep 17 00:00:00 2001
From: Andrzej Janik <vosen@vosen.pl>
Date: Thu, 4 Jun 2015 20:23:46 +0200
Subject: Disallow table redefinitions

---
 src/decoder/rustc_serialize.rs | 14 ++++++------
 src/display.rs                 |  7 +++---
 src/encoder/mod.rs             |  8 +++----
 src/encoder/rustc_serialize.rs | 11 +++++----
 src/lib.rs                     | 20 +++++++++++++++--
 src/parser.rs                  | 51 ++++++++++++++++++++++++++----------------
 tests/valid.rs                 |  4 ++--
 7 files changed, 74 insertions(+), 41 deletions(-)

diff --git a/src/decoder/rustc_serialize.rs b/src/decoder/rustc_serialize.rs
index 6e8fe59..af38f9b 100644
--- a/src/decoder/rustc_serialize.rs
+++ b/src/decoder/rustc_serialize.rs
@@ -3,7 +3,7 @@ use std::mem;
 
 use super::{Decoder, DecodeError};
 use super::DecodeErrorKind::*;
-use Value;
+use {Value, Table};
 
 impl rustc_serialize::Decoder for Decoder {
     type Error = DecodeError;
@@ -141,7 +141,7 @@ impl rustc_serialize::Decoder for Decoder {
             Some(Value::Table(..)) => {
                 let ret = try!(f(self));
                 match self.toml {
-                    Some(Value::Table(ref t)) if t.len() == 0 => {}
+                    Some(Value::Table(Table(ref t, _,))) if t.len() == 0 => {}
                     _ => return Ok(ret)
                 }
                 self.toml.take();
@@ -156,7 +156,7 @@ impl rustc_serialize::Decoder for Decoder {
     {
         let field = format!("{}", f_name);
         let toml = match self.toml {
-            Some(Value::Table(ref mut table)) => {
+            Some(Value::Table(Table(ref mut table, _))) => {
                 table.remove(&field)
                     .or_else(|| table.remove(&f_name.replace("_", "-")))
             },
@@ -165,7 +165,7 @@ impl rustc_serialize::Decoder for Decoder {
         let mut d = self.sub_decoder(toml, f_name);
         let ret = try!(f(&mut d));
         if let Some(value) = d.toml {
-            if let Some(Value::Table(ref mut table)) = self.toml {
+            if let Some(Value::Table(Table(ref mut table, _))) = self.toml {
                 table.insert(field, value);
             }
         }
@@ -260,7 +260,7 @@ impl rustc_serialize::Decoder for Decoder {
         where F: FnOnce(&mut Decoder, usize) -> Result<T, DecodeError>
     {
         let len = match self.toml {
-            Some(Value::Table(ref table)) => table.len(),
+            Some(Value::Table(Table(ref table, _))) => table.len(),
             ref found => return Err(self.mismatch("table", found)),
         };
         let ret = try!(f(self, len));
@@ -273,7 +273,7 @@ impl rustc_serialize::Decoder for Decoder {
     {
         match self.toml {
             Some(Value::Table(ref table)) => {
-                match table.iter().skip(idx).next() {
+                match table.0.iter().skip(idx).next() {
                     Some((key, _)) => {
                         let val = Value::String(format!("{}", key));
                         f(&mut self.sub_decoder(Some(val), &**key))
@@ -290,7 +290,7 @@ impl rustc_serialize::Decoder for Decoder {
     {
         match self.toml {
             Some(Value::Table(ref table)) => {
-                match table.iter().skip(idx).next() {
+                match table.0.iter().skip(idx).next() {
                     Some((_, value)) => {
                         // XXX: this shouldn't clone
                         f(&mut self.sub_decoder(Some(value.clone()), ""))
diff --git a/src/display.rs b/src/display.rs
index 0c561e8..74ec424 100644
--- a/src/display.rs
+++ b/src/display.rs
@@ -57,7 +57,7 @@ fn write_str(f: &mut fmt::Formatter, s: &str) -> fmt::Result {
 
 impl<'a, 'b> Printer<'a, 'b> {
     fn print(&mut self, table: &'a TomlTable) -> fmt::Result {
-        for (k, v) in table.iter() {
+        for (k, v) in table.0.iter() {
             match *v {
                 Table(..) => continue,
                 Array(ref a) => {
@@ -70,7 +70,7 @@ impl<'a, 'b> Printer<'a, 'b> {
             }
             try!(writeln!(self.output, "{} = {}", Key(&[k]), v));
         }
-        for (k, v) in table.iter() {
+        for (k, v) in table.0.iter() {
             match *v {
                 Table(ref inner) => {
                     self.stack.push(k);
@@ -127,13 +127,14 @@ impl<'a> fmt::Display for Key<'a> {
 #[allow(warnings)]
 mod tests {
     use Value;
+    use Table as TomlTable;
     use Value::{String, Integer, Float, Boolean, Datetime, Array, Table};
     use std::collections::BTreeMap;
 
     macro_rules! map( ($($k:expr => $v:expr),*) => ({
         let mut _m = BTreeMap::new();
         $(_m.insert($k.to_string(), $v);)*
-        _m
+        TomlTable::new(_m)
     }) );
 
     #[test]
diff --git a/src/encoder/mod.rs b/src/encoder/mod.rs
index 21185f4..ea8ef6a 100644
--- a/src/encoder/mod.rs
+++ b/src/encoder/mod.rs
@@ -31,7 +31,7 @@ use {Value, Table};
 /// let mut e = Encoder::new();
 /// my_struct.encode(&mut e).unwrap();
 ///
-/// assert_eq!(e.toml.get(&"foo".to_string()), Some(&Value::Integer(4)))
+/// assert_eq!(e.toml.0.get(&"foo".to_string()), Some(&Value::Integer(4)))
 /// # }
 /// ```
 pub struct Encoder {
@@ -73,12 +73,12 @@ enum State {
 impl Encoder {
     /// Constructs a new encoder which will emit to the given output stream.
     pub fn new() -> Encoder {
-        Encoder { state: State::Start, toml: BTreeMap::new() }
+        Encoder { state: State::Start, toml: Table(BTreeMap::new(), false) }
     }
 
     fn emit_value(&mut self, v: Value) -> Result<(), Error> {
         match mem::replace(&mut self.state, State::Start) {
-            State::NextKey(key) => { self.toml.insert(key, v); Ok(()) }
+            State::NextKey(key) => { self.toml.0.insert(key, v); Ok(()) }
             State::NextArray(mut vec) => {
                 // TODO: validate types
                 vec.push(v);
@@ -122,7 +122,7 @@ impl Encoder {
             State::NextKey(key) => {
                 let mut nested = Encoder::new();
                 try!(f(&mut nested));
-                self.toml.insert(key, Value::Table(nested.toml));
+                self.toml.0.insert(key, Value::Table(nested.toml));
                 Ok(())
             }
             State::NextArray(mut arr) => {
diff --git a/src/encoder/rustc_serialize.rs b/src/encoder/rustc_serialize.rs
index ab5e90f..830eb5e 100644
--- a/src/encoder/rustc_serialize.rs
+++ b/src/encoder/rustc_serialize.rs
@@ -193,8 +193,8 @@ impl rustc_serialize::Encodable for Value {
                 })
             }
             Value::Table(ref t) => {
-                e.emit_map(t.len(), |e| {
-                    for (i, (key, value)) in t.iter().enumerate() {
+                e.emit_map(t.0.len(), |e| {
+                    for (i, (key, value)) in t.0.iter().enumerate() {
                         try!(e.emit_map_elt_key(i, |e| e.emit_str(key)));
                         try!(e.emit_map_elt_val(i, |e| value.encode(e)));
                     }
@@ -212,6 +212,7 @@ mod tests {
 
     use {Encoder, Decoder, DecodeError};
     use Value;
+    use Table as TomlTable;
     use Value::{Table, Integer, Array, Float};
 
     macro_rules! encode( ($t:expr) => ({
@@ -228,7 +229,7 @@ mod tests {
     macro_rules! map( ($($k:ident, $v:expr),*) => ({
         let mut _m = BTreeMap::new();
         $(_m.insert(stringify!($k).to_string(), $v);)*
-        _m
+        TomlTable::new(_m)
     }) );
 
     #[test]
@@ -577,7 +578,9 @@ mod tests {
         #[derive(RustcEncodable, RustcDecodable, PartialEq, Debug)]
         struct Foo { a: BTreeMap<String, String> }
 
-        let v = Foo { a: map! { a, "foo".to_string() } };
+        let mut v = Foo { a: BTreeMap::new() };
+        v.a.insert("a".to_string(), "foo".to_string());
+
         let mut d = Decoder::new(Table(map! {
             a, Table(map! {
                 a, Value::String("foo".to_string())
diff --git a/src/lib.rs b/src/lib.rs
index 0196fbc..547c407 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -75,8 +75,24 @@ pub enum Value {
 /// Type representing a TOML array, payload of the Value::Array variant
 pub type Array = Vec<Value>;
 
+// The bool field flag is used during parsing and construction.
+// Is true if the given table was explicitly defined, false otherwise
+// e.g. in a toml document: `[a.b] foo = "bar"`, Table `a` would be false,
+// where table `b` (contained inside `a`) would be true.
 /// Type representing a TOML table, payload of the Value::Table variant
-pub type Table = BTreeMap<string::String, Value>;
+#[derive(Debug, Clone)]
+pub struct Table (pub BTreeMap<string::String, Value>, bool);
+impl Table {
+    /// Creates new TOML table
+    pub fn new(map: BTreeMap<string::String, Value>) -> Table {
+        Table(map, false)
+    }
+}
+impl PartialEq for Table {
+    fn eq(&self, other: &Table) -> bool {
+        self.0.eq(&other.0)
+    }
+}
 
 impl Value {
     /// Tests whether this and another value have the same type.
@@ -182,7 +198,7 @@ impl Value {
         let mut cur_value = self;
         for key in path.split('.') {
             match cur_value {
-                &Value::Table(ref hm) => {
+                &Value::Table(Table(ref hm, _)) => {
                     match hm.get(key) {
                         Some(v) => cur_value = v,
                         None => return None
diff --git a/src/parser.rs b/src/parser.rs
index 9a15de8..ccf0d3a 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -162,7 +162,7 @@ impl<'a> Parser<'a> {
     /// If an error occurs, the `errors` field of this parser can be consulted
     /// to determine the cause of the parse failure.
     pub fn parse(&mut self) -> Option<TomlTable> {
-        let mut ret = BTreeMap::new();
+        let mut ret = TomlTable(BTreeMap::new(), false);
         while self.peek(0).is_some() {
             self.ws();
             if self.newline() { continue }
@@ -189,7 +189,7 @@ impl<'a> Parser<'a> {
                 if keys.len() == 0 { return None }
 
                 // Build the section table
-                let mut table = BTreeMap::new();
+                let mut table = TomlTable(BTreeMap::new(), false);
                 if !self.values(&mut table) { return None }
                 if array {
                     self.insert_array(&mut ret, &*keys, Table(table), start)
@@ -715,7 +715,7 @@ impl<'a> Parser<'a> {
     fn inline_table(&mut self, _start: usize) -> Option<Value> {
         if !self.expect('{') { return None }
         self.ws();
-        let mut ret = BTreeMap::new();
+        let mut ret = TomlTable(BTreeMap::new(), true);
         if self.eat('}') { return Some(Table(ret)) }
         loop {
             let lo = self.next_pos();
@@ -734,14 +734,14 @@ impl<'a> Parser<'a> {
 
     fn insert(&mut self, into: &mut TomlTable, key: String, value: Value,
               key_lo: usize) {
-        if into.contains_key(&key) {
+        if into.0.contains_key(&key) {
             self.errors.push(ParserError {
                 lo: key_lo,
                 hi: key_lo + key.len(),
                 desc: format!("duplicate key: `{}`", key),
             })
         } else {
-            into.insert(key, value);
+            into.0.insert(key, value);
         }
     }
 
@@ -751,8 +751,8 @@ impl<'a> Parser<'a> {
         for part in keys[..keys.len() - 1].iter() {
             let tmp = cur;
 
-            if tmp.contains_key(part) {
-                match *tmp.get_mut(part).unwrap() {
+            if tmp.0.contains_key(part) {
+                match *tmp.0.get_mut(part).unwrap() {
                     Table(ref mut table) => {
                         cur = table;
                         continue
@@ -785,8 +785,8 @@ impl<'a> Parser<'a> {
             }
 
             // Initialize an empty table as part of this sub-key
-            tmp.insert(part.clone(), Table(BTreeMap::new()));
-            match *tmp.get_mut(part).unwrap() {
+            tmp.0.insert(part.clone(), Table(TomlTable(BTreeMap::new(), false)));
+            match *tmp.0.get_mut(part).unwrap() {
                 Table(ref mut inner) => cur = inner,
                 _ => unreachable!(),
             }
@@ -802,22 +802,22 @@ impl<'a> Parser<'a> {
         };
         let key = format!("{}", key);
         let mut added = false;
-        if !into.contains_key(&key) {
-            into.insert(key.clone(), Table(BTreeMap::new()));
+        if !into.0.contains_key(&key) {
+            into.0.insert(key.clone(), Table(TomlTable(BTreeMap::new(), true)));
             added = true;
         }
-        match into.get_mut(&key) {
+        match into.0.get_mut(&key) {
             Some(&mut Table(ref mut table)) => {
-                let any_tables = table.values().any(|v| v.as_table().is_some());
-                if !any_tables && !added {
+                let any_tables = table.0.values().any(|v| v.as_table().is_some());
+                if !added && (!any_tables || table.1) {
                     self.errors.push(ParserError {
                         lo: key_lo,
                         hi: key_lo + key.len(),
                         desc: format!("redefinition of table `{}`", key),
                     });
                 }
-                for (k, v) in value.into_iter() {
-                    if table.insert(k.clone(), v).is_some() {
+                for (k, v) in value.0.into_iter() {
+                    if table.0.insert(k.clone(), v).is_some() {
                         self.errors.push(ParserError {
                             lo: key_lo,
                             hi: key_lo + key.len(),
@@ -844,10 +844,10 @@ impl<'a> Parser<'a> {
             None => return,
         };
         let key = format!("{}", key);
-        if !into.contains_key(&key) {
-            into.insert(key.clone(), Array(Vec::new()));
+        if !into.0.contains_key(&key) {
+            into.0.insert(key.clone(), Array(Vec::new()));
         }
-        match *into.get_mut(&key).unwrap() {
+        match *into.0.get_mut(&key).unwrap() {
             Array(ref mut vec) => {
                 match vec.first() {
                     Some(ref v) if !v.same_type(&value) => {
@@ -1333,4 +1333,17 @@ trimmed in raw strings.
             c = 2
         ", "duplicate key `c` in table");
     }
+
+    #[test]
+    fn bad_table_redefine() {
+        let mut p = Parser::new("
+            [a]
+            foo=\"bar\"
+            [a.b]
+            foo=\"bar\"
+            [a]
+            baz=\"bar\"
+        ");
+        assert!(p.parse().is_none());
+    }
 }
diff --git a/tests/valid.rs b/tests/valid.rs
index 568518b..18c21d6 100644
--- a/tests/valid.rs
+++ b/tests/valid.rs
@@ -32,7 +32,7 @@ fn to_json(toml: Value) -> Json {
             let json = Json::Array(arr.into_iter().map(to_json).collect());
             if is_table {json} else {doit("array", json)}
         }
-        Table(table) => Json::Object(table.into_iter().map(|(k, v)| {
+        Table(table) => Json::Object(table.0.into_iter().map(|(k, v)| {
             (k, to_json(v))
         }).collect()),
     }
@@ -58,7 +58,7 @@ fn run(toml: &str, json: &str) {
 
     let table2 = Parser::new(&toml_string).parse().unwrap();
     // floats are a little lossy
-    if table2.values().any(|v| v.as_float().is_some()) { return }
+    if table2.0.values().any(|v| v.as_float().is_some()) { return }
     assert_eq!(toml, Table(table2));
 }
 
-- 
cgit v1.2.3