The Naive Bayes Classifier – Machine Learning, Engine Implementation Part 2


Category Components

  • Tests of Category
    using System;
    using System.Collections.Generic;
    using NaiveBayes.Category;
    using NaiveBayes.Variables;
    using NUnit.Framework;
    
    namespace TestNaiveBayes
    {
        /// <summary>
        /// NUnit tests for the category components (<see cref="ICategory"/> and its engine).
        /// Each test runs against a category registered in <see cref="SetUp"/> and removed in
        /// <see cref="TearDown"/>; <c>_targetObject</c> belongs to the beginner category type.
        /// </summary>
        [TestFixture]
        public class TestCategories
        {
            // Tolerance used when comparing computed floating-point values.
            private const double Tolerance = 1e-9;

            private ICategory _category;
            private ITargetObject _targetObject;

            [SetUp]
            public void SetUp()
            {
                CategoryFactory.AddCategoryType(TestCategoryFactory.CategoryName, TestCategoryFactory.CategoryTypes, TestCategoryFactory.CategoryAttributes);

                this._category = CategoryFactory.GetCategory(TestCategoryFactory.CategoryName);
                this._targetObject = new TargetObject(this._category, this._category.CategoryTypes[TestCategoryFactory.BeginnerCategory]);
            }

            [TearDown]
            public void TearDown()
            {
                CategoryFactory.RemoveCategory(TestCategoryFactory.CategoryName);
            }

            [Test]
            public void TestAfterInit()
            {
                Assert.AreEqual(TestCategoryFactory.CategoryName, this._category.Name);
                Assert.AreEqual(TestCategoryFactory.CategoryTypes.Length, this._category.CategoryTypes.Count);
                Assert.AreEqual(TestCategoryFactory.CategoryAttributes.Length, this._category.Attributes.Count);
            }

            [Test]
            public void TestCountOfVariablesAfterInit()
            {
                this.AssertAllCountersAreOne();
            }

            [Test]
            public void TestTeachCategoryByTargetObject()
            {
                string categoryType = this._targetObject.CategoryType;

                this._targetObject.SetAttributeExist(AttributeName(TestCategoryFactory.Cheap));
                this._targetObject.SetAttributeExist(AttributeName(TestCategoryFactory.AllMountain));

                this._category.Engine.TeachCategory(this._targetObject);

                // Counters start at 1, so a single occurrence yields 2.
                Assert.AreEqual(2, this.CountOf(categoryType, TestCategoryFactory.Cheap));
                Assert.AreEqual(2, this.CountOf(categoryType, TestCategoryFactory.AllMountain));

                // The second object must share the target object's (beginner) category type,
                // otherwise its attributes would not be counted in the assertions below.
                ITargetObject targetObjectOther = new TargetObject(this._category, this._category.CategoryTypes[TestCategoryFactory.BeginnerCategory]);
                targetObjectOther.SetAttributeExist(AttributeName(TestCategoryFactory.Cheap));
                targetObjectOther.SetAttributeExist(AttributeName(TestCategoryFactory.Freestyle));
                targetObjectOther.SetAttributeExist(AttributeName(TestCategoryFactory.Freeride));

                this._category.Engine.TeachCategory(new List<ITargetObject> { targetObjectOther, this._targetObject });

                Assert.AreEqual(2, this.CountOf(categoryType, TestCategoryFactory.Freestyle));
                Assert.AreEqual(2, this.CountOf(categoryType, TestCategoryFactory.Freeride));
                Assert.AreEqual(3, this.CountOf(categoryType, TestCategoryFactory.AllMountain));
                Assert.AreEqual(4, this.CountOf(categoryType, TestCategoryFactory.Cheap));
                Assert.AreEqual(1, this.CountOf(categoryType, TestCategoryFactory.Expensive));
            }

            [Test]
            public void TestReset()
            {
                string categoryType = this._targetObject.CategoryType;

                this._targetObject.SetAttributeExist(AttributeName(TestCategoryFactory.Cheap));
                this._targetObject.SetAttributeExist(AttributeName(TestCategoryFactory.AllMountain));

                this._category.Engine.TeachCategory(this._targetObject);

                Assert.AreEqual(2, this.CountOf(categoryType, TestCategoryFactory.Cheap));
                Assert.AreEqual(2, this.CountOf(categoryType, TestCategoryFactory.AllMountain));

                this._category.Engine.Reset();

                this.AssertAllCountersAreOne();
            }

            [Test]
            public void TestApriori()
            {
                this._targetObject.SetAttributeExist(AttributeName(TestCategoryFactory.Cheap));
                this._targetObject.SetAttributeExist(AttributeName(TestCategoryFactory.AllMountain));

                this._category.Engine.TeachCategory(this._targetObject);

                ITargetObject targetObjectOther = new TargetObject(this._category, this._category.CategoryTypes[TestCategoryFactory.MediumCategory]);

                // Beginner is taught 3 times in total, medium once, advanced never.
                this._category.Engine.TeachCategory(new List<ITargetObject> { targetObjectOther, this._targetObject, this._targetObject });

                this._category.Engine.PrepareToClassification();

                Assert.AreEqual(3.0 / 4.0, this._category.Engine.GetApriori(TestCategoryFactory.CategoryTypes[TestCategoryFactory.BeginnerCategory]), Tolerance);
                Assert.AreEqual(1.0 / 4.0, this._category.Engine.GetApriori(TestCategoryFactory.CategoryTypes[TestCategoryFactory.MediumCategory]), Tolerance);
                Assert.AreEqual(0.0 / 4.0, this._category.Engine.GetApriori(TestCategoryFactory.CategoryTypes[TestCategoryFactory.AdvancedCategory]), Tolerance);
            }

            [Test]
            public void TestProbability()
            {
                this._targetObject.SetAttributeExist(AttributeName(TestCategoryFactory.Cheap));
                this._targetObject.SetAttributeExist(AttributeName(TestCategoryFactory.AllMountain));

                this._category.Engine.TeachCategory(this._targetObject);

                ITargetObject targetObjectOther = new TargetObject(this._category, this._category.CategoryTypes[TestCategoryFactory.BeginnerCategory]);
                targetObjectOther.SetAttributeExist(AttributeName(TestCategoryFactory.Cheap));
                targetObjectOther.SetAttributeExist(AttributeName(TestCategoryFactory.Freestyle));

                this._category.Engine.TeachCategory(new List<ITargetObject> { targetObjectOther, this._targetObject });

                this._category.Engine.PrepareToClassification();

                // Three beginner objects were taught; each probability is counter / evidence.
                const double numberOfTargetObject = 3.0;
                Assert.AreEqual(4.0 / numberOfTargetObject, this.ProbabilityOf(TestCategoryFactory.BeginnerCategory, TestCategoryFactory.Cheap), Tolerance);
                Assert.AreEqual(3.0 / numberOfTargetObject, this.ProbabilityOf(TestCategoryFactory.BeginnerCategory, TestCategoryFactory.AllMountain), Tolerance);
                Assert.AreEqual(2.0 / numberOfTargetObject, this.ProbabilityOf(TestCategoryFactory.BeginnerCategory, TestCategoryFactory.Freestyle), Tolerance);

                // Medium has no evidence, so its probabilities are 0.
                Assert.AreEqual(0.0 / numberOfTargetObject, this.ProbabilityOf(TestCategoryFactory.MediumCategory, TestCategoryFactory.Cheap), Tolerance);
            }

            [Test]
            [ExpectedException(typeof(CannotTeachByTargetObjectToClassifyException))]
            public void TestTeachingByObjectToClassify()
            {
                // An object with an empty category type is an object to classify, not a training object.
                ITargetObject targetObjectOther = new TargetObject(this._category, String.Empty);
                this._category.Engine.TeachCategory(targetObjectOther);
            }

            // Name of the attribute at the given index of TestCategoryFactory.CategoryAttributes.
            private static string AttributeName(int attributeIndex)
            {
                return TestCategoryFactory.CategoryAttributes[attributeIndex];
            }

            // Current counter of the indexed attribute within the named category type.
            private int CountOf(string categoryType, int attributeIndex)
            {
                return this._category.Engine.GetCategoryType(categoryType)[AttributeName(attributeIndex)];
            }

            // Prepared probability of the indexed attribute within the indexed category type.
            private double ProbabilityOf(int categoryTypeIndex, int attributeIndex)
            {
                string categoryType = this._category.CategoryTypes[categoryTypeIndex];
                return this._category.Engine.GetCategoryType(categoryType).GetProbability(AttributeName(attributeIndex));
            }

            // Asserts that every attribute counter of every category type holds its initial value of 1.
            private void AssertAllCountersAreOne()
            {
                foreach (string attribute in TestCategoryFactory.CategoryAttributes)
                {
                    foreach (string categoryType in this._category.CategoryTypes)
                    {
                        Assert.AreEqual(1, this._category.Engine.GetCategoryType(categoryType)[attribute]);
                    }
                }
            }
        }
    }
    
  • Interface – ICategory
    using System.Collections.Generic;
    
    namespace NaiveBayes.Category
    {
        /// <summary>
        /// A named classification category: its category types, its attributes and the
        /// engine that is trained on them.
        /// </summary>
        public interface ICategory
        {
            /// <summary>Name of the category.</summary>
            string Name { get; }

            /// <summary>
            /// Names of the category types (classes) an object can belong to.
            /// NOTE(review): in <c>Category</c> this is the engine's own list, not a copy.
            /// </summary>
            List<string> CategoryTypes { get; }

            /// <summary>
            /// Names of the attributes an object may have.
            /// NOTE(review): in <c>Category</c> this is the engine's own list, not a copy.
            /// </summary>
            List<string> Attributes { get; }

            /// <summary>Engine used to teach this category and prepare it for classification.</summary>
            ICategoryEngine Engine { get; }
        }
    }
    
  • Implementation – Category
    using System.Collections.Generic;
    
    namespace NaiveBayes.Category
    {
        /// <summary>
        /// Default <see cref="ICategory"/>: a name plus a <see cref="CategoryEngine"/> built from
        /// copies of the supplied category type and attribute sequences.
        /// </summary>
        public class Category : ICategory
        {
            private readonly string _name;

            // Kept with its concrete type so the list accessors below need no downcast.
            private readonly CategoryEngine _engine;

            internal Category(string name, IEnumerable<string> categoryTypes, IEnumerable<string> attributes)
            {
                this._name = name;

                List<string> categoryTypeList = new List<string>(categoryTypes);
                List<string> attributeList = new List<string>(attributes);
                this._engine = new CategoryEngine(categoryTypeList, attributeList);
            }

            /// <summary>Name given at construction.</summary>
            public string Name
            {
                get { return this._name; }
            }

            /// <summary>The engine that owns this category's lists and training data.</summary>
            public ICategoryEngine Engine
            {
                get { return this._engine; }
            }

            /// <summary>The engine's category type list (the live list, not a copy).</summary>
            public List<string> CategoryTypes
            {
                get { return this._engine.CategoryTypes; }
            }

            /// <summary>The engine's attribute list (the live list, not a copy).</summary>
            public List<string> Attributes
            {
                get { return this._engine.Attributes; }
            }
        }
    }
    
  • Interface – ICategoryEngine
    using System.Collections.Generic;
    using NaiveBayes.Variables;
    
    namespace NaiveBayes.Category
    {
        /// <summary>
        /// Teaching and classification-preparation operations for one category.
        /// </summary>
        public interface ICategoryEngine
        {
            /// <summary>Returns the counters and probabilities kept for the named category type.</summary>
            ICategoryType GetCategoryType(string key);

            /// <summary>
            /// Returns the a-priori value of the named category type computed by the last
            /// <see cref="PrepareToClassification"/> call.
            /// </summary>
            double GetApriori(string categoryTypeName);

            /// <summary>
            /// Whether the engine has been prepared for classification since the last
            /// teaching or reset.
            /// </summary>
            bool CanBeClassify { get; }

            /// <summary>Teaches the engine with a single object whose category type is known.</summary>
            void TeachCategory(ITargetObject targetObject);

            /// <summary>Teaches the engine with each object of the list, in order.</summary>
            void TeachCategory(List<ITargetObject> targetObjects);

            /// <summary>Computes a-priori values and attribute probabilities from the taught data.</summary>
            void PrepareToClassification();

            /// <summary>Returns the learned attribute counters to their initial state.</summary>
            void Reset();
        }
    }
    
  • Implementation – CategoryEngine
    using System;
    using System.Collections.Generic;
    using NaiveBayes.Variables;
    
    namespace NaiveBayes.Category
    {
        /// <summary>
        /// Training engine for one category. It counts how many training objects were taught per
        /// category type (evidence) and, per category type, how often each attribute occurred.
        /// <see cref="PrepareToClassification"/> turns those counts into a-priori values and
        /// per-attribute probabilities.
        /// </summary>
        public class CategoryEngine : ICategoryEngine
        {
            private const string ObjectToClassifyMessage = "Object to classify cannot be used by teaching method.";

            // A-priori value per category type; populated only by PrepareApriori().
            private readonly Dictionary<string, double> _apriories;

            // Attribute names; the same list instance is handed to every CategoryType.
            private readonly List<string> _attributes;

            // Per category type: attribute counters and the probabilities derived from them.
            private readonly Dictionary<string, ICategoryType> _categoryTypeAttributes;

            // Number of training objects taught per category type.
            private readonly Dictionary<string, int> _categoryTypeEvidence;

            private readonly List<string> _categoryTypes;

            /// <summary>
            /// Creates an engine with zero evidence and a fresh <see cref="CategoryType"/> for
            /// every category type.
            /// </summary>
            /// <param name="categoryTypes">Category type names; stored by reference, not copied.</param>
            /// <param name="attributes">Attribute names; stored by reference, not copied.</param>
            /// <exception cref="ArgumentNullException">Either argument is null.</exception>
            public CategoryEngine(List<string> categoryTypes, List<string> attributes)
            {
                if (categoryTypes == null)
                {
                    throw new ArgumentNullException("categoryTypes");
                }

                if (attributes == null)
                {
                    throw new ArgumentNullException("attributes");
                }

                this._categoryTypes = categoryTypes;
                this._attributes = attributes;

                this.CanBeClassify = false;

                this._apriories = new Dictionary<string, double>(this._categoryTypes.Count);

                this._categoryTypeEvidence = new Dictionary<string, int>(this._categoryTypes.Count);
                this.PrepareCategoryTypeEvidence();

                this._categoryTypeAttributes = new Dictionary<string, ICategoryType>(this._categoryTypes.Count);
                this.PrepareCategoryTypeAttributes();
            }

            #region ICategoryEngine Members

            /// <summary>
            /// Computes a-priori values and attribute probabilities from the taught data and marks
            /// the engine as ready to classify.
            /// </summary>
            public void PrepareToClassification()
            {
                this.PrepareApriori();

                this.PrepareProbabilities();

                this.CanBeClassify = true;
            }

            /// <summary>
            /// True after <see cref="PrepareToClassification"/>; reset to false by any teaching or
            /// <see cref="Reset"/>.
            /// </summary>
            public bool CanBeClassify { get; private set; }

            /// <summary>
            /// Teaches one training object: increments the evidence of its category type and the
            /// counter of every attribute it reports as existing.
            /// </summary>
            /// <exception cref="ArgumentNullException"><paramref name="targetObject"/> is null.</exception>
            /// <exception cref="CannotTeachByTargetObjectToClassifyException">
            /// The object has no category type, i.e. it is an object to classify.
            /// </exception>
            public void TeachCategory(ITargetObject targetObject)
            {
                ValidateTrainingObject(targetObject, "targetObject");

                ICategoryType categoryType = this._categoryTypeAttributes[targetObject.CategoryType];
                this._categoryTypeEvidence[targetObject.CategoryType] += 1;

                foreach (string attribute in targetObject.ExistsAttributes.Keys)
                {
                    categoryType.AddAttribute(attribute);
                }

                this.SetCanBeClassifyToFalse();
            }

            /// <summary>Teaches every object of the list, in order.</summary>
            /// <exception cref="ArgumentNullException">The list, or one of its elements, is null.</exception>
            /// <exception cref="CannotTeachByTargetObjectToClassifyException">
            /// An element has no category type; in that case nothing from the list is taught.
            /// </exception>
            public void TeachCategory(List<ITargetObject> targetObjects)
            {
                if (targetObjects == null)
                {
                    throw new ArgumentNullException("targetObjects");
                }

                // Validate the whole batch first so a null element or an object to classify does
                // not leave the engine with only the preceding objects taught.
                foreach (ITargetObject targetObject in targetObjects)
                {
                    ValidateTrainingObject(targetObject, "targetObjects");
                }

                foreach (ITargetObject targetObject in targetObjects)
                {
                    this.TeachCategory(targetObject);
                }
            }

            /// <summary>
            /// Forgets all taught data: evidence per category type goes back to 0 and every
            /// attribute counter back to its initial value. Marks the engine as not ready to classify.
            /// </summary>
            public void Reset()
            {
                foreach (string categoryType in this._categoryTypes)
                {
                    // Without this, a-priori values computed after re-teaching would still
                    // include the objects taught before the reset.
                    this._categoryTypeEvidence[categoryType] = 0;

                    CategoryType counters = (CategoryType) this._categoryTypeAttributes[categoryType];
                    foreach (string attribute in this._attributes)
                    {
                        counters.ResetAttribute(attribute);
                    }
                }

                this.SetCanBeClassifyToFalse();
            }

            /// <summary>
            /// Returns the a-priori value computed for the named category type by the last
            /// <see cref="PrepareToClassification"/> call.
            /// </summary>
            /// <exception cref="KeyNotFoundException">
            /// The name is unknown or the engine was never prepared.
            /// </exception>
            public double GetApriori(string categoryTypeName)
            {
                return this._apriories[categoryTypeName];
            }

            /// <summary>Returns the counters/probabilities of the named category type.</summary>
            /// <exception cref="KeyNotFoundException">The name is unknown.</exception>
            public ICategoryType GetCategoryType(string key)
            {
                return this._categoryTypeAttributes[key];
            }

            #endregion

            internal List<string> CategoryTypes
            {
                get
                {
                    return this._categoryTypes;
                }
            }

            internal List<string> Attributes
            {
                get
                {
                    return this._attributes;
                }
            }

            // Throws if the object cannot be used for teaching (null, or has no category type).
            private static void ValidateTrainingObject(ITargetObject targetObject, string paramName)
            {
                if (targetObject == null)
                {
                    throw new ArgumentNullException(paramName);
                }

                if (String.IsNullOrEmpty(targetObject.CategoryType))
                {
                    throw new CannotTeachByTargetObjectToClassifyException(ObjectToClassifyMessage);
                }
            }

            private void SetCanBeClassifyToFalse()
            {
                this.CanBeClassify = false;
            }

            // Total number of training objects taught across all category types.
            private int GetSummaryEvidence()
            {
                int summaryEvidence = 0;

                foreach (int value in this._categoryTypeEvidence.Values)
                {
                    summaryEvidence += value;
                }

                return summaryEvidence;
            }

            private void PrepareCategoryTypeEvidence()
            {
                foreach (string categoryType in this._categoryTypes)
                {
                    this._categoryTypeEvidence.Add(categoryType, 0);
                }
            }

            private void PrepareCategoryTypeAttributes()
            {
                foreach (string categoryType in this._categoryTypes)
                {
                    this._categoryTypeAttributes.Add(categoryType, new CategoryType(this._attributes));
                }
            }

            private void PrepareProbabilities()
            {
                foreach (string attribute in this._attributes)
                {
                    this.PrepareProbabilityForAttribute(attribute);
                }
            }

            // Sets, for every category type, probability = attribute counter / evidence, or 0.0 when
            // the category type has no evidence.
            // NOTE(review): counters are seeded with 1 by CategoryType, so this ratio is not bounded
            // by 1.0 (the tests expect e.g. 4/3); confirm this is the intended estimator.
            private void PrepareProbabilityForAttribute(string attribute)
            {
                foreach (string categoryType in this._categoryTypes)
                {
                    int categoryTypeAttributeCount = this._categoryTypeAttributes[categoryType][attribute];
                    double categoryTypeEvidence = this._categoryTypeEvidence[categoryType];

                    if (categoryTypeEvidence == 0)
                    {
                        this._categoryTypeAttributes[categoryType].SetProbability(attribute, 0.0);
                    }
                    else
                    {
                        this._categoryTypeAttributes[categoryType].SetProbability(attribute,
                                                                                  categoryTypeAttributeCount/
                                                                                  categoryTypeEvidence);
                    }
                }
            }

            // Sets, for every category type, apriori = its evidence / total evidence.
            private void PrepareApriori()
            {
                int summaryEvidence = this.GetSummaryEvidence();

                foreach (string categoryType in this._categoryTypes)
                {
                    // With nothing taught the ratio would be 0/0 = NaN; report 0.0 instead, matching
                    // how PrepareProbabilityForAttribute treats zero evidence.
                    this._apriories[categoryType] = summaryEvidence == 0
                        ? 0.0
                        : this._categoryTypeEvidence[categoryType]/(double) summaryEvidence;
                }
            }
        }
    }
    
  • Interface – ICategoryType
    namespace NaiveBayes.Category
    {
        /// <summary>
        /// Per-category-type statistics: a counter and a probability for every attribute.
        /// </summary>
        public interface ICategoryType
        {
            /// <summary>Current counter of the named attribute.</summary>
            int this[string attributeName] { get; }

            /// <summary>Increments the counter of the named attribute by one.</summary>
            void AddAttribute(string attribute);

            /// <summary>Returns the probability stored for the named attribute.</summary>
            double GetProbability(string attribute);

            /// <summary>Stores the probability of the named attribute.</summary>
            void SetProbability(string attribute, double value);
        }
    }
    
  • Implementation – CategoryType
    using System.Collections.Generic;
    
    namespace NaiveBayes.Category
    {
        /// <summary>
        /// Counters and probabilities for one category type. Every known attribute starts with a
        /// counter of 1 and a probability of 0.0. Unknown attribute names make
        /// <see cref="AddAttribute"/>, the indexer and <see cref="GetProbability"/> throw
        /// <see cref="KeyNotFoundException"/>.
        /// </summary>
        public class CategoryType : ICategoryType
        {
            // Value every attribute counter starts from (and returns to on reset).
            private const int InitialCount = 1;

            // Probability every attribute starts with before it is computed.
            private const double InitialProbability = 0.0;

            private readonly Dictionary<string, int> _attributesMap;
            private readonly Dictionary<string, double> _probabilityMap;

            /// <summary>Creates counters and probabilities for each of the given attribute names.</summary>
            /// <param name="attributes">Attribute names; duplicates cause an <see cref="System.ArgumentException"/>.</param>
            public CategoryType(List<string> attributes)
            {
                this._attributesMap = new Dictionary<string, int>(attributes.Count);
                this._probabilityMap = new Dictionary<string, double>(attributes.Count);

                foreach (string name in attributes)
                {
                    this._attributesMap.Add(name, InitialCount);
                    this._probabilityMap.Add(name, InitialProbability);
                }
            }

            /// <summary>Current counter of the named attribute.</summary>
            public int this[string attributeName]
            {
                get { return this._attributesMap[attributeName]; }
            }

            /// <summary>Increments the counter of the named attribute by one.</summary>
            public void AddAttribute(string attribute)
            {
                int current = this._attributesMap[attribute];
                this._attributesMap[attribute] = current + 1;
            }

            /// <summary>Returns the probability stored for the named attribute.</summary>
            public double GetProbability(string attribute)
            {
                return this._probabilityMap[attribute];
            }

            /// <summary>Stores the probability of the named attribute.</summary>
            public void SetProbability(string attribute, double value)
            {
                this._probabilityMap[attribute] = value;
            }

            // Puts the named attribute's counter back to its initial value.
            internal void ResetAttribute(string attribute)
            {
                this._attributesMap[attribute] = InitialCount;
            }
        }
    }
    
Advertisements

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out / Change )

Twitter picture

You are commenting using your Twitter account. Log Out / Change )

Facebook photo

You are commenting using your Facebook account. Log Out / Change )

Google+ photo

You are commenting using your Google+ account. Log Out / Change )

Connecting to %s