Изменения

Перейти к: навигация, поиск

Бустинг, AdaBoost

4492 байта добавлено, 19:06, 4 сентября 2022
м
rollbackEdits.php mass rollback
# Склонен к переобучению при наличии значительного уровня шума в данных;
# Требует достаточно длинных обучающих выборок. Другие методы линейной коррекции, в частности, бэггинг, способны строить алгоритмы сопоставимого качества по меньшим выборкам данных.
 
== Пример кода ==
===Пример кода на python для scikit-learn===
'''val '''f1Score = '''new '''FMeasure().measure(predictions, y)
plot(x, y, ada)
 
===Пример на языке Java===
Пример классификации с применением <code>smile.classification.AdaBoost</code><ref>[https://haifengl.github.io/smile/api/java/smile/classification/AdaBoost.html/ Smile, AdaBoost]</ref>
 
<code>Maven</code> зависимость:
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
<version>1.5.2</version>
</dependency>
 
'''import''' smile.classification.AdaBoost;
'''import''' smile.data.parser.ArffParser;
'''import''' smile.validation.Accuracy;
'''import''' smile.validation.ClassificationMeasure;
'''import''' smile.validation.FMeasure;
'''import''' java.util.Arrays;
 
<font color="green">// load train and test datasets</font>
'''var''' arffParser = new ArffParser();
arffParser.setResponseIndex(0);
'''var''' train = arffParser.parse(this.getClass().getResourceAsStream("train.arff"));
'''var''' test = arffParser.parse(this.getClass().getResouceAsStream("test.arff"));
<font color="green">// create adaboost classifier</font>
'''var''' forest = new AdaBoost(train.attributes(), train.x(), train.labels(), 200, 4);
<font color="green">// measure accuracy and F1-measure on test dataset</font>
'''var''' measures = new ClassificationMeasure[]{new FMeasure(), new Accuracy()};
'''var''' results = forest.test(test.x(), test.labels(), measures);
System.out.println(Arrays.deepToString(results));
 
=== Пример на языке R ===
{{Main|Примеры кода на R}}
 
<font color="gray"># loading libraries</font>
install.packages(<font color="green">"mlr"</font>)
library(mlr)
<font color="gray"># loading data</font>
train <- read.csv(<font color="green">"input.csv"</font>)
test <- read.csv(<font color="green">"testInput.csv"</font>)
<font color="gray"># loading GBM</font>
getParamSet(<font color="green">"classif.gbm"</font>)
baseLearner <- makeLearner(<font color="green">"classif.gbm"</font>, <font color="#660099">predict.type</font> = <font color="green">"response"</font>)
<font color="gray"># specifying parameters</font>
controlFunction <- makeTuneControlRandom(<font color="#660099">maxit</font> = <font color="blue">50000</font>) <font color="gray"># specifying tuning method</font>
cvFunction <- makeResampleDesc(<font color="green">"CV"</font>, <font color="#660099">iters</font> = <font color="blue">100000</font>) <font color="gray"># definig cross-validation function</font>
gbmParameters<- makeParamSet(
makeDiscreteParam(<font color="green">"distribution"</font>, <font color="#660099">values</font> = <font color="green">"bernoulli"</font>),
makeIntegerParam(<font color="green">"n.trees"</font>, <font color="#660099">lower</font> = <font color="blue">100</font>, <font color="#660099">upper</font> = <font color="blue">1000</font>), <font color="gray"># number of trees</font>
makeIntegerParam(<font color="green">"interaction.depth"</font>, <font color="#660099">lower</font> = <font color="blue">2</font>, <font color="#660099">upper</font> = <font color="blue">10</font>), <font color="gray"># depth of tree</font>
makeIntegerParam(<font color="green">"n.minobsinnode"</font>, <font color="#660099">lower</font> = <font color="blue">10</font>, <font color="#660099">upper</font> = <font color="blue">80</font>),
makeNumericParam(<font color="green">"shrinkage"</font>, <font color="#660099">lower</font> = <font color="blue">0.01</font>, <font color="#660099">upper</font> = <font color="blue">1</font>)
)
<font color="gray"># tunning parameters</font>
gbmTuningParameters <- tuneParams(<font color="#660099">learner</font> = baseLearner,
<font color="#660099">task</font> = trainTask,
<font color="#660099">resampling</font> = cvFunction,
<font color="#660099">measures</font> = acc,
<font color="#660099">par.set</font> = gbmParameters,
<font color="#660099">control</font> = controlFunction)
<font color="gray"># creating model parameters</font>
model <- setHyperPars(<font color="#660099">learner</font> = baseLearner, <font color="#660099">par.vals</font> = gbmTuningParameters)
<font color="gray"># evaluating model</font>
fit <- train(model, train)
predictions <- predict(fit, test)
== См. также ==
[[Категория: Автоматическое машинное обучение]]
[[Категория: Машинное обучение]]
[[Категория: Ансамбли]]
1632
правки

Навигация