
Trait bounds of RandomForest::fit conflict with those of predict and other metrics #322

@DaGaMs

Description


I am trying to cross-validate a RandomForestClassifier like this:

let cv_score = cross_validate(
    RandomForestClassifier::new(),
    x_train,
    y_train,
    RandomForestClassifierParameters::default()
        .with_criterion(SplitCriterion::ClassificationError)
        .with_n_trees(*n_tree)
        .with_m(*m_feat)
        .with_min_samples_split(*m_split)
        .with_min_samples_leaf(*m_leaf),
    &KFold::default().with_n_splits(5),
    &precision
).unwrap();

x_train is a DenseMatrix<f32>, and y_train is a Vec<u16> containing only 0s and 1s.

The problem is that RandomForestClassifier::fit requires the labels y to be Number + Ord, whereas precision requires them to be Number + RealNumber + FloatNumber. As far as I can see, no RealNumber type can also be Ord (f32 and f64 implement only PartialOrd, because NaN is not comparable to anything), so precision, f1 and roc_auc_score cannot be used directly as the score in cross_validate with RandomForestClassifier.
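A minimal, self-contained illustration of why a float label type can never satisfy an Ord bound. It does not use smartcore: `sorted_labels` is a hypothetical helper standing in for any API (such as fit) that requires `T: Ord` on its labels.

```rust
use std::cmp::Ordering;

/// Hypothetical helper: sorts labels whose type has a total order (`Ord`),
/// mimicking the kind of bound RandomForestClassifier::fit places on y.
fn sorted_labels<T: Ord + Copy>(labels: &[T]) -> Vec<T> {
    let mut v = labels.to_vec();
    v.sort();
    v
}

fn main() {
    // u16 implements Ord, so it satisfies the bound.
    let ints: Vec<u16> = vec![1, 0, 1];
    println!("{:?}", sorted_labels(&ints));

    // f32 implements only PartialOrd: comparing with NaN yields None,
    // so f32 cannot implement Ord, and a call such as
    // `sorted_labels(&[1.0f32, 0.0])` would be rejected at compile time.
    assert_eq!(f32::NAN.partial_cmp(&1.0), None);
    assert_eq!(1.0f32.partial_cmp(&0.0), Some(Ordering::Greater));
}
```

Any metric bounded on a float trait and any estimator bounded on Ord therefore have no label type in common, which is the conflict described above.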

For now, I have worked around this by defining my own precision function that converts the u16 labels to f32 before calling precision, but I think this should be fixed in the framework. IMO, the logical fix would be to stop requiring Ord for y in RandomForestClassifier::fit?
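A sketch of the conversion workaround, under stated assumptions. To stay self-contained it does not depend on smartcore: `precision_f32` is a hypothetical stand-in for the float-only precision metric (binary, positive class 1, returning 0.0 when nothing is predicted positive), and `precision_u16` is the wrapper that converts labels and delegates. The `&Vec<u16>` parameters assume cross_validate takes a score callback of the form Fn(&Y, &Y) -> f64 with Y = Vec<u16>; check this against the smartcore version in use.

```rust
/// Hypothetical stand-in for a float-only metric such as smartcore's precision:
/// precision = TP / (TP + FP), with 1.0 treated as the positive class.
fn precision_f32(y_true: &[f32], y_pred: &[f32]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len(), "label vectors must have equal length");
    let (mut tp, mut fp) = (0u32, 0u32);
    for (&t, &p) in y_true.iter().zip(y_pred) {
        // Only predicted positives contribute to precision.
        if p == 1.0 {
            if t == 1.0 {
                tp += 1;
            } else {
                fp += 1;
            }
        }
    }
    if tp + fp == 0 {
        0.0 // assumption: no positive predictions scores 0.0
    } else {
        tp as f64 / (tp + fp) as f64
    }
}

/// Wrapper using the integer label type accepted by RandomForestClassifier::fit:
/// converts u16 labels to f32 (lossless) and delegates to the float-only metric.
/// Takes &Vec<u16> (not &[u16]) to match the assumed callback type Fn(&Y, &Y).
fn precision_u16(y_true: &Vec<u16>, y_pred: &Vec<u16>) -> f64 {
    let t: Vec<f32> = y_true.iter().map(|&v| f32::from(v)).collect();
    let p: Vec<f32> = y_pred.iter().map(|&v| f32::from(v)).collect();
    precision_f32(&t, &p)
}

fn main() {
    let y_true: Vec<u16> = vec![1, 0, 1, 1, 0];
    let y_pred: Vec<u16> = vec![1, 1, 1, 0, 0];
    println!("precision = {}", precision_u16(&y_true, &y_pred));
}
```

With the real crate, the wrapper body would presumably call smartcore's precision on the converted vectors instead of `precision_f32`, and `&precision_u16` would be passed where `&precision` appears in the cross_validate call. The conversion costs one allocation per fold, which is why relaxing the bound in the library is the cleaner fix.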
