  • Decision Tree Model

    Example Usage

        from mltrain.supervised.DecisionTree import DecisionTree

        # Initialize the model
        model = DecisionTree(max_depth=10, min_samples_split=2, criteria='gini')

        # Train the model (X_train, y_train and X_test are NumPy arrays you have prepared)
        model.fit(X_train, y_train)

        # Make predictions
        predictions = model.predict(X_test)

    Overview

    The DecisionTree class implements a decision tree model for classification tasks. It builds a tree by splitting the data on feature values and then uses that tree to make predictions. Two splitting criteria, Gini impurity and entropy, are supported for choosing the best split at each node.


    Hyperparameters

    • max_depth (default=10): The maximum depth the tree can grow to.
    • min_samples_split (default=2): The minimum number of samples required to split a node.
    • criteria (default='gini'): The criterion used for splitting ('gini' or 'entropy').
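    Taken together, max_depth and min_samples_split act as stopping rules while the tree grows. A minimal sketch of how they might be combined is below; the function `should_stop` and the additional "node is pure" condition are illustrative assumptions, not documented `mltrain` behavior.

```python
import numpy as np


def should_stop(depth, y, max_depth=10, min_samples_split=2):
    """Hypothetical stopping rule: return True if a node should become a leaf.

    Assumed conditions (not taken from mltrain's source):
    - the node has reached max_depth,
    - it holds fewer than min_samples_split samples, or
    - all of its samples share one label (the node is pure).
    """
    n_samples = len(y)
    n_labels = len(np.unique(y))
    return depth >= max_depth or n_samples < min_samples_split or n_labels == 1


# A node at depth 3 with mixed labels keeps splitting
print(should_stop(3, np.array([0, 1, 1])))   # False
# The same node at max_depth becomes a leaf
print(should_stop(10, np.array([0, 1, 1])))  # True
```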

    Attributes

    • tree (Node): The root node of the decision tree.

    Methods

    __init__(self, max_depth=10, min_samples_split=2, criteria='gini')

    Initializes the Decision Tree model with the specified hyperparameters.

    • Args:
      • max_depth (int): Maximum depth of the tree.
      • min_samples_split (int): Minimum number of samples required to split a node.
      • criteria (str): Criterion used for splitting ('gini' or 'entropy').

    traverse_tree(self, x, node)

    Traverses the tree to make a prediction for a single sample.

    • Args:
      • x (numpy.ndarray): The input sample.
      • node (Node): The current node in the tree.
    • Returns:
      • Any: The predicted class label for the input sample.
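    A minimal sketch of the traversal, written as a loop rather than recursion. The `Node` dataclass and its `feature`, `threshold`, `left`, `right`, and `value` fields are hypothetical, because this page does not document the fields of `Node`; the sketch also assumes that samples whose feature value is less than or equal to the threshold go to the left child.

```python
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np


@dataclass
class Node:
    """Hypothetical tree node: internal nodes hold a split, leaves hold a label."""
    feature: Optional[int] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    value: Any = None  # set only on leaf nodes

    def is_leaf(self) -> bool:
        return self.value is not None


def traverse_tree(x: np.ndarray, node: Node) -> Any:
    """Follow the split decisions for sample x until a leaf is reached."""
    while not node.is_leaf():
        # Assumed convention: values <= threshold go to the left child
        if x[node.feature] <= node.threshold:
            node = node.left
        else:
            node = node.right
    return node.value


# A one-split tree: feature 0 <= 2.5 -> "a", otherwise "b"
root = Node(feature=0, threshold=2.5, left=Node(value="a"), right=Node(value="b"))
print(traverse_tree(np.array([1.0, 9.0]), root))  # a
print(traverse_tree(np.array([3.0, 0.0]), root))  # b
```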

    most_common_label(self, y)

    Determines the most common label in the target array.

    • Args:
      • y (numpy.ndarray): The array of target labels.
    • Returns:
      • Any: The most common label in the target array.
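    A short sketch of the majority vote using `collections.Counter`. The page does not say how `mltrain` breaks ties, so the tie behavior here (the label encountered first wins, which is how `Counter.most_common` orders equal counts) is an assumption of this sketch.

```python
from collections import Counter

import numpy as np


def most_common_label(y: np.ndarray):
    """Return the label that occurs most often in y (the majority vote at a leaf)."""
    # most_common(1) returns [(label, count)]; equal counts keep first-encountered order
    return Counter(y.tolist()).most_common(1)[0][0]


print(most_common_label(np.array([2, 0, 2, 1, 2])))  # 2
```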

    predict(self, X)

    Predicts class labels for the given dataset.

    • Args:
      • X (numpy.ndarray): The dataset for which to make predictions.
    • Returns:
      • numpy.ndarray: An array of predicted class labels.
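    `predict` presumably runs the single-sample traversal once per row, starting from the root stored in `tree`. The sketch below passes the traversal function in as a parameter so that it stands alone; the tuple-based `stump` and the `traverse_stump` helper are hypothetical placeholders for a fitted tree and for `traverse_tree`.

```python
from typing import Any, Callable

import numpy as np


def predict(X: np.ndarray, root: Any, traverse_tree: Callable[[np.ndarray, Any], Any]) -> np.ndarray:
    """Predict one label per row of X by walking the tree from the root for each sample."""
    return np.array([traverse_tree(x, root) for x in X])


# Hypothetical stand-ins: a "tree" that is a single (feature, threshold, left, right) split
stump = (0, 0.5, "neg", "pos")


def traverse_stump(x, node):
    feature, threshold, left_label, right_label = node
    return left_label if x[feature] <= threshold else right_label


X_test = np.array([[0.2, 1.0], [0.9, 0.0], [0.5, 3.0]])
print(predict(X_test, stump, traverse_stump))  # ['neg' 'pos' 'neg']
```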

    split(self, X, y, feature, threshold)

    Splits the dataset based on the specified feature and threshold.

    • Args:
      • X (numpy.ndarray): The dataset to be split.
      • y (numpy.ndarray): The target labels.
      • feature (int): The feature index used for splitting.
      • threshold (float): The threshold value for splitting.
    • Returns:
      • Tuple[numpy.ndarray, numpy.ndarray]: Indices of the left and right subsets after the split.
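    Because the method returns indices rather than data, one way to implement it is with a boolean mask over a single feature column. The `<=`/`>` convention is an assumption of this sketch; the page does not say which side receives samples equal to the threshold.

```python
import numpy as np


def split(X: np.ndarray, y: np.ndarray, feature: int, threshold: float):
    """Return index arrays for the left (<= threshold) and right (> threshold) subsets."""
    # y is accepted to mirror the documented signature but is not needed to compute indices
    mask = X[:, feature] <= threshold
    left_idx = np.flatnonzero(mask)
    right_idx = np.flatnonzero(~mask)
    return left_idx, right_idx


X = np.array([[1.0, 7.0], [4.0, 2.0], [2.0, 5.0], [6.0, 1.0]])
y = np.array([0, 1, 0, 1])
left, right = split(X, y, feature=0, threshold=3.0)
print(left, right)        # [0 2] [1 3]
print(y[left], y[right])  # [0 0] [1 1]
```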

    entropy(self, y)

    Calculates the entropy of the target labels.

    • Args:
      • y (numpy.ndarray): The array of target labels.
    • Returns:
      • float: The entropy value.
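    Entropy is H(y) = Σ_k p_k · log2(1/p_k), equivalently -Σ_k p_k · log2(p_k), where p_k is the fraction of labels in y that belong to class k. A NumPy sketch follows; base-2 logarithms are an assumption, since the page does not state the base. Changing the base only rescales every entropy by the same constant, so it does not change which split scores best.

```python
import numpy as np


def entropy(y: np.ndarray) -> float:
    """Shannon entropy (base 2) of the label distribution in y."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    # Every p is > 0 (np.unique only returns labels that occur), so 1 / p is safe
    return float(np.sum(p * np.log2(1 / p)))


print(entropy(np.array([0, 0, 1, 1])))  # 1.0
print(entropy(np.array([1, 1, 1])))     # 0.0
```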

    gini(self, y)

    Calculates the Gini impurity of the target labels.

    • Args:
      • y (numpy.ndarray): The array of target labels.
    • Returns:
      • float: The Gini impurity value.
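    Gini impurity is G(y) = 1 - Σ_k p_k², with p_k the fraction of labels in class k; it equals the probability that two labels drawn at random, with replacement, from y differ. A NumPy sketch:

```python
import numpy as np


def gini(y: np.ndarray) -> float:
    """Gini impurity: the chance two labels drawn at random (with replacement) from y differ."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


print(gini(np.array([0, 0, 1, 1])))  # 0.5
print(gini(np.array([1, 1, 1])))     # 0.0
```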

    information_gain(self, X, y, feature, threshold)

    Calculates the information gain from splitting the data on a specific feature and threshold.

    • Args:
      • X (numpy.ndarray): The dataset being split.
      • y (numpy.ndarray): The target labels.
      • feature (int): The feature index used for splitting.
      • threshold (float): The threshold value for splitting.
    • Returns:
      • float: The information gain from the split.
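    Information gain is typically the parent node's impurity minus the size-weighted impurity of the two children: IG = I(y) - (n_left/n)·I(y_left) - (n_right/n)·I(y_right), where I is entropy or Gini impurity (with Gini the quantity is often called Gini gain). The sketch below assumes that `criteria` selects I, that the `<=` convention defines the left child, and that a split leaving one side empty has zero gain; the extra `criteria` argument is a hypothetical stand-in for the hyperparameter stored on the model.

```python
import numpy as np


def entropy(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p * np.log2(1 / p)))


def gini(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


def information_gain(X, y, feature, threshold, criteria="gini"):
    """Parent impurity minus the size-weighted impurity of the two children."""
    impurity = gini if criteria == "gini" else entropy
    left = X[:, feature] <= threshold  # assumed split convention
    n, n_left = len(y), int(left.sum())
    if n_left == 0 or n_left == n:
        return 0.0  # a split that leaves one side empty gains nothing
    child = (n_left / n) * impurity(y[left]) + ((n - n_left) / n) * impurity(y[~left])
    return impurity(y) - child


X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([0, 0, 1, 1])
print(information_gain(X, y, 0, 2.5, criteria="entropy"))  # 1.0
print(information_gain(X, y, 0, 2.5, criteria="gini"))     # 0.5
```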

    best_split(self, X, y, features)

    Finds the feature and threshold whose split maximizes information gain.

    • Args:
      • X (numpy.ndarray): The dataset to be split.
      • y (numpy.ndarray): The target labels.
      • features (list[int]): The list of feature indices to consider for splitting.
    • Returns:
      • Tuple[int, float]: The best feature index and threshold value for the split.
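    A straightforward strategy, assumed here rather than confirmed for `mltrain`, is exhaustive search: try each observed value of each candidate feature as a threshold and keep the pair with the highest gain. The sketch uses Gini impurity only, and its `gain` helper is hypothetical; some implementations test midpoints between sorted values instead, which this sketch does not do.

```python
import numpy as np


def gini(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


def gain(X, y, feature, threshold):
    """Gini decrease for splitting on X[:, feature] <= threshold (0.0 if a side is empty)."""
    left = X[:, feature] <= threshold
    n, n_left = len(y), int(left.sum())
    if n_left in (0, n):
        return 0.0
    child = (n_left / n) * gini(y[left]) + ((n - n_left) / n) * gini(y[~left])
    return gini(y) - child


def best_split(X, y, features):
    """Exhaustive search: try every observed value of every candidate feature as a threshold."""
    best_gain, best_feature, best_threshold = -1.0, None, None
    for feature in features:
        for threshold in np.unique(X[:, feature]):
            g = gain(X, y, feature, threshold)
            if g > best_gain:
                best_gain, best_feature, best_threshold = g, feature, float(threshold)
    return best_feature, best_threshold


X = np.array([[5.0, 1.0], [3.0, 2.0], [8.0, 3.0], [1.0, 4.0]])
y = np.array([0, 0, 1, 1])
print(best_split(X, y, features=[0, 1]))  # (1, 2.0)
```

    Here feature 1 separates the two classes perfectly at 2.0, so it beats every threshold on feature 0.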

    fit(self, X, y)

    Fits the decision tree model to the given dataset.

    • Args:
      • X (numpy.ndarray): The input features.
      • y (numpy.ndarray): The target labels.
    • Returns:
      • None
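    One way fit can work is top-down recursive partitioning: pick the best split, partition the rows, and recurse on each side until a stopping rule fires, at which point the node becomes a leaf holding the majority label. The sketch below uses a hypothetical `SketchTree` class and `Node` layout, Gini impurity only, and the `<=` split convention; none of these details are confirmed for `mltrain`, and the `criteria` option is omitted for brevity.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np


@dataclass
class Node:
    """Hypothetical node layout; the page does not document the fields of Node."""
    feature: Optional[int] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    value: Any = None  # set only on leaves


def gini(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


class SketchTree:
    """Hypothetical stand-in for DecisionTree, showing one way fit() can grow the tree."""

    def __init__(self, max_depth=10, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None

    def fit(self, X, y):
        # Grow the tree recursively from the root; returns None, matching the documented signature
        self.tree = self._grow(X, y, depth=0)

    def _grow(self, X, y, depth):
        # Stopping rules: depth limit, too few samples, or a pure node
        if depth >= self.max_depth or len(y) < self.min_samples_split or len(np.unique(y)) == 1:
            return Node(value=Counter(y.tolist()).most_common(1)[0][0])
        feature, threshold, best_gain = self._best_split(X, y)
        if best_gain <= 0.0:
            # No split lowers the impurity, so this node becomes a leaf
            return Node(value=Counter(y.tolist()).most_common(1)[0][0])
        left = X[:, feature] <= threshold
        return Node(
            feature=feature,
            threshold=threshold,
            left=self._grow(X[left], y[left], depth + 1),
            right=self._grow(X[~left], y[~left], depth + 1),
        )

    def _best_split(self, X, y):
        # Exhaustive search over all features and observed values; returns (feature, threshold, gain)
        best = (None, None, 0.0)
        n = len(y)
        for feature in range(X.shape[1]):
            for threshold in np.unique(X[:, feature]):
                left = X[:, feature] <= threshold
                n_left = int(left.sum())
                if n_left in (0, n):
                    continue
                child = (n_left / n) * gini(y[left]) + ((n - n_left) / n) * gini(y[~left])
                g = gini(y) - child
                if g > best[2]:
                    best = (feature, float(threshold), g)
        return best

    def predict(self, X):
        # Walk each sample from the root to a leaf
        out = []
        for x in X:
            node = self.tree
            while node.value is None:
                node = node.left if x[node.feature] <= node.threshold else node.right
            out.append(node.value)
        return np.array(out)


X_train = np.array([[1.0, 0.0], [2.0, 1.0], [3.0, 0.0], [4.0, 1.0], [5.0, 0.0], [6.0, 1.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])
model = SketchTree(max_depth=3)
model.fit(X_train, y_train)
print(model.predict(np.array([[1.5, 0.0], [5.5, 1.0]])))  # [0 1]
```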