Predicting period and mass ratio of hot subdwarfs

The aim of this tutorial is to show how NNaPS can be used to predict the orbital parameters of hot subdwarf binaries based on a relatively small sample of full MESA models.

We will use a sample of ~2000 binary MESA models focused on low-mass binaries that interact on the red giant branch. We will train a fully connected model to predict whether the interaction phase is stable, the orbital parameters (period and mass ratio) after the mass-loss phase, and the final product (hot subdwarf, horizontal branch star or He white dwarf).

Then we will check how different metallicity distributions in the initial population affect the final orbital parameters.

[1]:
import pandas as pd
import pylab as pl
import seaborn as sns

import yaml
from nnaps import predictors
Using TensorFlow backend.

Obtaining the data

Let’s download the extracted MESA models:

[2]:
data = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/period_mass_ratio_mesa_models_extracted.csv')
[3]:
data
[3]:
M1_init M2_init q_init P_init FeH_init stability product P_final q_final
0 1.791000 1.138 1.573813 3.570104 -0.109671 stable He-WD 52.924494 0.265416
1 0.954000 0.355 2.687324 410.120037 -0.780000 unstable UK 0.618103 1.245376
2 1.297000 0.865 1.499421 371.590147 -0.218502 stable sdB 1130.865820 0.541258
3 1.554999 0.579 2.685664 61.920032 -0.282191 unstable UK 0.045363 0.560722
4 1.291999 0.378 3.417988 23.240011 -0.493334 unstable UK 0.005948 0.740944
... ... ... ... ... ... ... ... ... ...
2340 0.922000 0.657 1.403348 139.840010 -0.780000 stable He-WD 326.953825 0.623666
2341 1.496000 1.323 1.130763 138.060030 -0.840593 stable HB 753.134322 0.402106
2342 1.154000 0.410 2.814634 311.360022 0.182481 unstable UK 0.187493 0.932706
2343 0.771000 0.714 1.079832 231.110008 -1.292690 stable He-WD 438.852463 0.658157
2344 0.895000 0.653 1.370597 352.640024 -0.780000 stable sdB 660.178499 0.708158

2345 rows × 9 columns

The features that we are interested in are the initial parameters of the models:

- M1_init: initial mass of the donor star
- q_init: initial mass ratio
- P_init: initial orbital period
- FeH_init: metallicity

The properties that we want to predict are:

- stability: whether the mass-loss phase is stable or not
- product: the type of star we get after the mass-loss phase (sdB, HB, He-WD, or UK for systems that are unstable)
- P_final: orbital period after the interaction phase
- q_final: mass ratio after the interaction phase

[5]:
features = ['M1_init', 'q_init', 'P_init', 'FeH_init']
classifiers = ['stability', 'product']
regressors = ['P_final', 'q_final']
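
As a quick sanity check (not part of the original tutorial), we can confirm that none of the selected columns contain missing values:

[ ]:
# count missing values per column; all counts should be zero
data[features + classifiers + regressors].isna().sum()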

Data analysis

Let’s start by having a look at the data we are working with.

[ ]:
stable = data[data['stability'] == 'stable']
[6]:
data['stability'].value_counts()
[6]:
stable      1586
unstable     759
Name: stability, dtype: int64
[7]:
data['product'].value_counts()
[7]:
He-WD    853
UK       759
HB       421
sdB      312
Name: product, dtype: int64
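
All unstable systems end up with product 'UK', so it can also be instructive to look at the product counts for the stable subsample only (a small addition to the tutorial, using the stable selection defined above):

[ ]:
# product counts restricted to systems with a stable mass-loss phase
stable['product'].value_counts()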
[8]:
sns.pairplot(stable, vars=features, hue='product')
[8]:
<seaborn.axisgrid.PairGrid at 0x7f2ae4661b10>
../_images/tutorials_period_mass_ratio_models_14_1.png
[9]:
sns.scatterplot(x='P_init', y='P_final', data=stable, hue='FeH_init')
[9]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f2ae3660850>
../_images/tutorials_period_mass_ratio_models_15_1.png
[10]:
sns.scatterplot(x='q_init', y='q_final', data=stable, hue='FeH_init')
[10]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f2ae1e0dc90>
../_images/tutorials_period_mass_ratio_models_16_1.png

Modeling

Now we can build the predictive model. We will use a shallow, fully connected neural network with three hidden layers. Note that the features are scaled automatically, but because the two regression targets (P_final and q_final) span very different ranges, we also need to apply a scaler to them to prevent one target from dominating the loss function.
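
To see why this scaling matters, we can compare the typical scale of the two regression targets (a quick check, not part of the original tutorial):

[ ]:
# P_final is of order hundreds of days while q_final is of order unity;
# without scaling, the period term would dominate the combined regression loss
data[['P_final', 'q_final']].describe().loc[['mean', 'std', 'min', 'max']]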

[11]:
setup = """
features:
    - M1_init
    - q_init
    - P_init
    - FeH_init
regressors:
    P_final:
        processor: StandardScaler
    q_final:
        processor: StandardScaler
classifiers:
    - stability
    - product
model:
   - {'layer':'Dense',   'args':[200], 'kwargs': {'activation':'relu', 'name':'FC_1'} }
   - {'layer':'Dense',   'args':[100], 'kwargs': {'activation':'relu', 'name':'FC_2'} }
   - {'layer':'Dense',   'args':[50],  'kwargs': {'activation':'relu', 'name':'FC_3'} }
optimizer: adam
"""
setup = yaml.safe_load(setup)
[12]:
predictor = predictors.FCPredictor(setup=setup, data=data)
[13]:
predictor.fit(epochs=100, batch_size=124, reduce_lr=True)
Train on 1876 samples, validate on 469 samples
Epoch 1/100
 - 1s - loss: 3.3076 - P_final_loss: 0.5563 - q_final_loss: 0.7013 - stability_loss: 0.6195 - product_loss: 1.3635 - P_final_mae: 0.6184 - q_final_mae: 0.3683 - stability_accuracy: 0.7777 - product_accuracy: 0.3241 - val_loss: 2.2777 - val_P_final_loss: 0.3099 - val_q_final_loss: 0.1414 - val_stability_loss: 0.5188 - val_product_loss: 1.3055 - val_P_final_mae: 0.4181 - val_q_final_mae: 0.2271 - val_stability_accuracy: 0.8380 - val_product_accuracy: 0.3198
Epoch 2/100
 - 0s - loss: 2.3581 - P_final_loss: 0.2519 - q_final_loss: 0.4045 - stability_loss: 0.4448 - product_loss: 1.2132 - P_final_mae: 0.3392 - q_final_mae: 0.2182 - stability_accuracy: 0.8417 - product_accuracy: 0.4488 - val_loss: 1.7413 - val_P_final_loss: 0.2005 - val_q_final_loss: 0.0971 - val_stability_loss: 0.3377 - val_product_loss: 1.1101 - val_P_final_mae: 0.3216 - val_q_final_mae: 0.1601 - val_stability_accuracy: 0.8721 - val_product_accuracy: 0.6141
Epoch 3/100
 - 0s - loss: 1.8151 - P_final_loss: 0.1397 - q_final_loss: 0.3311 - stability_loss: 0.3114 - product_loss: 0.9900 - P_final_mae: 0.2485 - q_final_mae: 0.1622 - stability_accuracy: 0.8849 - product_accuracy: 0.6588 - val_loss: 1.2836 - val_P_final_loss: 0.1059 - val_q_final_loss: 0.0753 - val_stability_loss: 0.2378 - val_product_loss: 0.8689 - val_P_final_mae: 0.2266 - val_q_final_mae: 0.1394 - val_stability_accuracy: 0.9275 - val_product_accuracy: 0.7207
Epoch 4/100
 - 0s - loss: 1.4653 - P_final_loss: 0.1084 - q_final_loss: 0.3225 - stability_loss: 0.2348 - product_loss: 0.7728 - P_final_mae: 0.2179 - q_final_mae: 0.1514 - stability_accuracy: 0.9184 - product_accuracy: 0.7303 - val_loss: 0.9856 - val_P_final_loss: 0.0873 - val_q_final_loss: 0.0602 - val_stability_loss: 0.1759 - val_product_loss: 0.6655 - val_P_final_mae: 0.2056 - val_q_final_mae: 0.1291 - val_stability_accuracy: 0.9403 - val_product_accuracy: 0.7612
Epoch 5/100
 - 0s - loss: 1.2421 - P_final_loss: 0.0994 - q_final_loss: 0.3190 - stability_loss: 0.1946 - product_loss: 0.6452 - P_final_mae: 0.2057 - q_final_mae: 0.1311 - stability_accuracy: 0.9254 - product_accuracy: 0.7862 - val_loss: 0.8992 - val_P_final_loss: 0.0743 - val_q_final_loss: 0.1368 - val_stability_loss: 0.1508 - val_product_loss: 0.5419 - val_P_final_mae: 0.1865 - val_q_final_mae: 0.1790 - val_stability_accuracy: 0.9403 - val_product_accuracy: 0.8145
Epoch 6/100
 - 0s - loss: 1.1037 - P_final_loss: 0.0810 - q_final_loss: 0.3026 - stability_loss: 0.1618 - product_loss: 0.5323 - P_final_mae: 0.1980 - q_final_mae: 0.1636 - stability_accuracy: 0.9302 - product_accuracy: 0.8124 - val_loss: 0.7422 - val_P_final_loss: 0.0668 - val_q_final_loss: 0.0983 - val_stability_loss: 0.1185 - val_product_loss: 0.4632 - val_P_final_mae: 0.1793 - val_q_final_mae: 0.1445 - val_stability_accuracy: 0.9552 - val_product_accuracy: 0.8209
Epoch 7/100
 - 0s - loss: 0.9975 - P_final_loss: 0.0707 - q_final_loss: 0.2901 - stability_loss: 0.1473 - product_loss: 0.4778 - P_final_mae: 0.1839 - q_final_mae: 0.1594 - stability_accuracy: 0.9302 - product_accuracy: 0.8358 - val_loss: 0.6717 - val_P_final_loss: 0.0628 - val_q_final_loss: 0.0918 - val_stability_loss: 0.1064 - val_product_loss: 0.4159 - val_P_final_mae: 0.1680 - val_q_final_mae: 0.1558 - val_stability_accuracy: 0.9574 - val_product_accuracy: 0.8294
Epoch 8/100
 - 0s - loss: 0.9291 - P_final_loss: 0.0652 - q_final_loss: 0.2819 - stability_loss: 0.1368 - product_loss: 0.4310 - P_final_mae: 0.1727 - q_final_mae: 0.1525 - stability_accuracy: 0.9344 - product_accuracy: 0.8417 - val_loss: 0.6311 - val_P_final_loss: 0.0612 - val_q_final_loss: 0.0867 - val_stability_loss: 0.1064 - val_product_loss: 0.3820 - val_P_final_mae: 0.1736 - val_q_final_mae: 0.1252 - val_stability_accuracy: 0.9552 - val_product_accuracy: 0.8465
Epoch 9/100
 - 0s - loss: 0.8920 - P_final_loss: 0.0605 - q_final_loss: 0.2840 - stability_loss: 0.1290 - product_loss: 0.3871 - P_final_mae: 0.1666 - q_final_mae: 0.1426 - stability_accuracy: 0.9350 - product_accuracy: 0.8555 - val_loss: 0.6306 - val_P_final_loss: 0.0567 - val_q_final_loss: 0.1263 - val_stability_loss: 0.0974 - val_product_loss: 0.3570 - val_P_final_mae: 0.1670 - val_q_final_mae: 0.1638 - val_stability_accuracy: 0.9552 - val_product_accuracy: 0.8507
Epoch 10/100
 - 0s - loss: 0.8439 - P_final_loss: 0.0602 - q_final_loss: 0.2721 - stability_loss: 0.1371 - product_loss: 0.3827 - P_final_mae: 0.1690 - q_final_mae: 0.1385 - stability_accuracy: 0.9382 - product_accuracy: 0.8561 - val_loss: 0.5496 - val_P_final_loss: 0.0546 - val_q_final_loss: 0.0721 - val_stability_loss: 0.0950 - val_product_loss: 0.3326 - val_P_final_mae: 0.1544 - val_q_final_mae: 0.1383 - val_stability_accuracy: 0.9638 - val_product_accuracy: 0.8635
Epoch 11/100
 - 0s - loss: 0.8921 - P_final_loss: 0.0612 - q_final_loss: 0.3935 - stability_loss: 0.1295 - product_loss: 0.3574 - P_final_mae: 0.1701 - q_final_mae: 0.1654 - stability_accuracy: 0.9387 - product_accuracy: 0.8577 - val_loss: 0.5804 - val_P_final_loss: 0.0490 - val_q_final_loss: 0.1264 - val_stability_loss: 0.0943 - val_product_loss: 0.3174 - val_P_final_mae: 0.1495 - val_q_final_mae: 0.1714 - val_stability_accuracy: 0.9552 - val_product_accuracy: 0.8785
Epoch 12/100
 - 0s - loss: 0.8487 - P_final_loss: 0.0586 - q_final_loss: 0.3174 - stability_loss: 0.1221 - product_loss: 0.3325 - P_final_mae: 0.1562 - q_final_mae: 0.1660 - stability_accuracy: 0.9419 - product_accuracy: 0.8614 - val_loss: 0.5460 - val_P_final_loss: 0.0508 - val_q_final_loss: 0.0680 - val_stability_loss: 0.1125 - val_product_loss: 0.3183 - val_P_final_mae: 0.1565 - val_q_final_mae: 0.1182 - val_stability_accuracy: 0.9467 - val_product_accuracy: 0.8721
Epoch 13/100
 - 0s - loss: 0.8180 - P_final_loss: 0.0555 - q_final_loss: 0.2856 - stability_loss: 0.1255 - product_loss: 0.3222 - P_final_mae: 0.1648 - q_final_mae: 0.1299 - stability_accuracy: 0.9392 - product_accuracy: 0.8721 - val_loss: 0.5329 - val_P_final_loss: 0.0491 - val_q_final_loss: 0.1100 - val_stability_loss: 0.0889 - val_product_loss: 0.2902 - val_P_final_mae: 0.1511 - val_q_final_mae: 0.1808 - val_stability_accuracy: 0.9638 - val_product_accuracy: 0.8955
Epoch 14/100
 - 0s - loss: 0.7635 - P_final_loss: 0.0490 - q_final_loss: 0.2639 - stability_loss: 0.1188 - product_loss: 0.3074 - P_final_mae: 0.1515 - q_final_mae: 0.1644 - stability_accuracy: 0.9440 - product_accuracy: 0.8769 - val_loss: 0.5311 - val_P_final_loss: 0.0473 - val_q_final_loss: 0.1135 - val_stability_loss: 0.0914 - val_product_loss: 0.2855 - val_P_final_mae: 0.1433 - val_q_final_mae: 0.1476 - val_stability_accuracy: 0.9616 - val_product_accuracy: 0.8870
Epoch 15/100
 - 0s - loss: 0.7786 - P_final_loss: 0.0478 - q_final_loss: 0.2968 - stability_loss: 0.1276 - product_loss: 0.3036 - P_final_mae: 0.1492 - q_final_mae: 0.1559 - stability_accuracy: 0.9440 - product_accuracy: 0.8817 - val_loss: 0.4969 - val_P_final_loss: 0.0488 - val_q_final_loss: 0.0882 - val_stability_loss: 0.0880 - val_product_loss: 0.2759 - val_P_final_mae: 0.1546 - val_q_final_mae: 0.1737 - val_stability_accuracy: 0.9595 - val_product_accuracy: 0.8913
Epoch 16/100
 - 0s - loss: 0.7438 - P_final_loss: 0.0523 - q_final_loss: 0.2630 - stability_loss: 0.1230 - product_loss: 0.2981 - P_final_mae: 0.1481 - q_final_mae: 0.1420 - stability_accuracy: 0.9430 - product_accuracy: 0.8913 - val_loss: 0.4654 - val_P_final_loss: 0.0469 - val_q_final_loss: 0.0682 - val_stability_loss: 0.0893 - val_product_loss: 0.2656 - val_P_final_mae: 0.1490 - val_q_final_mae: 0.1190 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9041
Epoch 17/100
 - 0s - loss: 0.7244 - P_final_loss: 0.0476 - q_final_loss: 0.9122 - stability_loss: 0.1176 - product_loss: 0.2901 - P_final_mae: 0.1460 - q_final_mae: 0.1270 - stability_accuracy: 0.9472 - product_accuracy: 0.8923 - val_loss: 0.5367 - val_P_final_loss: 0.0444 - val_q_final_loss: 0.1332 - val_stability_loss: 0.0998 - val_product_loss: 0.2656 - val_P_final_mae: 0.1381 - val_q_final_mae: 0.1568 - val_stability_accuracy: 0.9574 - val_product_accuracy: 0.9019
Epoch 18/100
 - 0s - loss: 0.9268 - P_final_loss: 0.0532 - q_final_loss: 0.4544 - stability_loss: 0.1233 - product_loss: 0.2855 - P_final_mae: 0.1537 - q_final_mae: 0.2999 - stability_accuracy: 0.9451 - product_accuracy: 0.8875 - val_loss: 0.7742 - val_P_final_loss: 0.0510 - val_q_final_loss: 0.3940 - val_stability_loss: 0.0872 - val_product_loss: 0.2553 - val_P_final_mae: 0.1577 - val_q_final_mae: 0.2756 - val_stability_accuracy: 0.9616 - val_product_accuracy: 0.9041
Epoch 19/100
 - 0s - loss: 0.7957 - P_final_loss: 0.0532 - q_final_loss: 0.3301 - stability_loss: 0.1156 - product_loss: 0.2775 - P_final_mae: 0.1597 - q_final_mae: 0.2370 - stability_accuracy: 0.9435 - product_accuracy: 0.8993 - val_loss: 0.4659 - val_P_final_loss: 0.0459 - val_q_final_loss: 0.0833 - val_stability_loss: 0.0910 - val_product_loss: 0.2492 - val_P_final_mae: 0.1502 - val_q_final_mae: 0.1570 - val_stability_accuracy: 0.9701 - val_product_accuracy: 0.9126
Epoch 20/100
 - 0s - loss: 0.6928 - P_final_loss: 0.0476 - q_final_loss: 0.2491 - stability_loss: 0.1146 - product_loss: 0.2650 - P_final_mae: 0.1527 - q_final_mae: 0.1499 - stability_accuracy: 0.9483 - product_accuracy: 0.9019 - val_loss: 0.4797 - val_P_final_loss: 0.0477 - val_q_final_loss: 0.1085 - val_stability_loss: 0.0889 - val_product_loss: 0.2406 - val_P_final_mae: 0.1407 - val_q_final_mae: 0.1524 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9147
Epoch 21/100
 - 0s - loss: 0.6860 - P_final_loss: 0.0466 - q_final_loss: 0.2513 - stability_loss: 0.1196 - product_loss: 0.2611 - P_final_mae: 0.1434 - q_final_mae: 0.1381 - stability_accuracy: 0.9494 - product_accuracy: 0.9057 - val_loss: 0.4891 - val_P_final_loss: 0.0455 - val_q_final_loss: 0.1236 - val_stability_loss: 0.0882 - val_product_loss: 0.2381 - val_P_final_mae: 0.1402 - val_q_final_mae: 0.1399 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9126

Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 22/100
 - 0s - loss: 0.6542 - P_final_loss: 0.0440 - q_final_loss: 0.2323 - stability_loss: 0.1193 - product_loss: 0.2566 - P_final_mae: 0.1370 - q_final_mae: 0.1261 - stability_accuracy: 0.9456 - product_accuracy: 0.9035 - val_loss: 0.4611 - val_P_final_loss: 0.0399 - val_q_final_loss: 0.1063 - val_stability_loss: 0.0874 - val_product_loss: 0.2333 - val_P_final_mae: 0.1273 - val_q_final_mae: 0.1336 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9168
Epoch 23/100
 - 0s - loss: 0.6471 - P_final_loss: 0.0403 - q_final_loss: 0.2308 - stability_loss: 0.1062 - product_loss: 0.2406 - P_final_mae: 0.1360 - q_final_mae: 0.1222 - stability_accuracy: 0.9488 - product_accuracy: 0.9035 - val_loss: 0.4541 - val_P_final_loss: 0.0392 - val_q_final_loss: 0.1077 - val_stability_loss: 0.0847 - val_product_loss: 0.2286 - val_P_final_mae: 0.1268 - val_q_final_mae: 0.1287 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9190
Epoch 24/100
 - 0s - loss: 0.6425 - P_final_loss: 0.0443 - q_final_loss: 0.2304 - stability_loss: 0.1147 - product_loss: 0.2519 - P_final_mae: 0.1314 - q_final_mae: 0.1175 - stability_accuracy: 0.9488 - product_accuracy: 0.9035 - val_loss: 0.4635 - val_P_final_loss: 0.0408 - val_q_final_loss: 0.1150 - val_stability_loss: 0.0858 - val_product_loss: 0.2283 - val_P_final_mae: 0.1259 - val_q_final_mae: 0.1271 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9211
Epoch 25/100
 - 0s - loss: 0.6417 - P_final_loss: 0.0436 - q_final_loss: 0.2397 - stability_loss: 0.1081 - product_loss: 0.2452 - P_final_mae: 0.1352 - q_final_mae: 0.1165 - stability_accuracy: 0.9483 - product_accuracy: 0.9046 - val_loss: 0.4661 - val_P_final_loss: 0.0411 - val_q_final_loss: 0.1115 - val_stability_loss: 0.0889 - val_product_loss: 0.2308 - val_P_final_mae: 0.1355 - val_q_final_mae: 0.1253 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9147
Epoch 26/100
 - 0s - loss: 0.6340 - P_final_loss: 0.0449 - q_final_loss: 0.2767 - stability_loss: 0.1278 - product_loss: 0.2579 - P_final_mae: 0.1336 - q_final_mae: 0.1162 - stability_accuracy: 0.9483 - product_accuracy: 0.9078 - val_loss: 0.4431 - val_P_final_loss: 0.0391 - val_q_final_loss: 0.0926 - val_stability_loss: 0.0884 - val_product_loss: 0.2283 - val_P_final_mae: 0.1264 - val_q_final_mae: 0.1197 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9168
Epoch 27/100
 - 0s - loss: 0.6367 - P_final_loss: 0.0387 - q_final_loss: 0.2296 - stability_loss: 0.1224 - product_loss: 0.2585 - P_final_mae: 0.1307 - q_final_mae: 0.1132 - stability_accuracy: 0.9478 - product_accuracy: 0.9104 - val_loss: 0.4314 - val_P_final_loss: 0.0392 - val_q_final_loss: 0.0776 - val_stability_loss: 0.0896 - val_product_loss: 0.2293 - val_P_final_mae: 0.1248 - val_q_final_mae: 0.1162 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9168
Epoch 28/100
 - 0s - loss: 0.6337 - P_final_loss: 0.0395 - q_final_loss: 0.2286 - stability_loss: 0.1086 - product_loss: 0.2450 - P_final_mae: 0.1286 - q_final_mae: 0.1118 - stability_accuracy: 0.9499 - product_accuracy: 0.9083 - val_loss: 0.4370 - val_P_final_loss: 0.0387 - val_q_final_loss: 0.0901 - val_stability_loss: 0.0879 - val_product_loss: 0.2255 - val_P_final_mae: 0.1236 - val_q_final_mae: 0.1165 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 29/100
 - 0s - loss: 0.6258 - P_final_loss: 0.0366 - q_final_loss: 0.2223 - stability_loss: 0.1061 - product_loss: 0.2422 - P_final_mae: 0.1280 - q_final_mae: 0.1114 - stability_accuracy: 0.9499 - product_accuracy: 0.9078 - val_loss: 0.4366 - val_P_final_loss: 0.0382 - val_q_final_loss: 0.0908 - val_stability_loss: 0.0880 - val_product_loss: 0.2246 - val_P_final_mae: 0.1226 - val_q_final_mae: 0.1159 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 30/100
 - 0s - loss: 0.6251 - P_final_loss: 0.0411 - q_final_loss: 0.2260 - stability_loss: 0.1064 - product_loss: 0.2360 - P_final_mae: 0.1269 - q_final_mae: 0.1115 - stability_accuracy: 0.9499 - product_accuracy: 0.9067 - val_loss: 0.4365 - val_P_final_loss: 0.0383 - val_q_final_loss: 0.0928 - val_stability_loss: 0.0876 - val_product_loss: 0.2230 - val_P_final_mae: 0.1222 - val_q_final_mae: 0.1169 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9190
Epoch 31/100
 - 0s - loss: 0.6210 - P_final_loss: 0.0365 - q_final_loss: 0.2213 - stability_loss: 0.1085 - product_loss: 0.2384 - P_final_mae: 0.1267 - q_final_mae: 0.1109 - stability_accuracy: 0.9504 - product_accuracy: 0.9094 - val_loss: 0.4293 - val_P_final_loss: 0.0381 - val_q_final_loss: 0.0894 - val_stability_loss: 0.0857 - val_product_loss: 0.2213 - val_P_final_mae: 0.1219 - val_q_final_mae: 0.1159 - val_stability_accuracy: 0.9701 - val_product_accuracy: 0.9211
Epoch 32/100
 - 0s - loss: 0.6218 - P_final_loss: 0.0387 - q_final_loss: 0.2231 - stability_loss: 0.1069 - product_loss: 0.2394 - P_final_mae: 0.1263 - q_final_mae: 0.1118 - stability_accuracy: 0.9499 - product_accuracy: 0.9078 - val_loss: 0.4392 - val_P_final_loss: 0.0381 - val_q_final_loss: 0.1015 - val_stability_loss: 0.0859 - val_product_loss: 0.2194 - val_P_final_mae: 0.1216 - val_q_final_mae: 0.1195 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 33/100
 - 0s - loss: 0.6201 - P_final_loss: 0.0410 - q_final_loss: 0.2231 - stability_loss: 0.1073 - product_loss: 0.2398 - P_final_mae: 0.1266 - q_final_mae: 0.1120 - stability_accuracy: 0.9515 - product_accuracy: 0.9110 - val_loss: 0.4280 - val_P_final_loss: 0.0381 - val_q_final_loss: 0.0877 - val_stability_loss: 0.0869 - val_product_loss: 0.2203 - val_P_final_mae: 0.1216 - val_q_final_mae: 0.1166 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 34/100
 - 0s - loss: 0.6131 - P_final_loss: 0.0353 - q_final_loss: 0.2219 - stability_loss: 0.1050 - product_loss: 0.2335 - P_final_mae: 0.1253 - q_final_mae: 0.1104 - stability_accuracy: 0.9499 - product_accuracy: 0.9110 - val_loss: 0.4388 - val_P_final_loss: 0.0376 - val_q_final_loss: 0.1009 - val_stability_loss: 0.0871 - val_product_loss: 0.2191 - val_P_final_mae: 0.1204 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9190
Epoch 35/100
 - 0s - loss: 0.6124 - P_final_loss: 0.0353 - q_final_loss: 0.2179 - stability_loss: 0.1066 - product_loss: 0.2300 - P_final_mae: 0.1233 - q_final_mae: 0.1107 - stability_accuracy: 0.9488 - product_accuracy: 0.9136 - val_loss: 0.4365 - val_P_final_loss: 0.0376 - val_q_final_loss: 0.0988 - val_stability_loss: 0.0877 - val_product_loss: 0.2180 - val_P_final_mae: 0.1231 - val_q_final_mae: 0.1171 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9275
Epoch 36/100
 - 0s - loss: 0.6131 - P_final_loss: 0.0366 - q_final_loss: 0.2187 - stability_loss: 0.1062 - product_loss: 0.2282 - P_final_mae: 0.1285 - q_final_mae: 0.1129 - stability_accuracy: 0.9494 - product_accuracy: 0.9094 - val_loss: 0.4372 - val_P_final_loss: 0.0377 - val_q_final_loss: 0.1044 - val_stability_loss: 0.0858 - val_product_loss: 0.2152 - val_P_final_mae: 0.1205 - val_q_final_mae: 0.1185 - val_stability_accuracy: 0.9701 - val_product_accuracy: 0.9232
Epoch 37/100
 - 0s - loss: 0.6091 - P_final_loss: 0.0375 - q_final_loss: 0.2183 - stability_loss: 0.1217 - product_loss: 0.2411 - P_final_mae: 0.1239 - q_final_mae: 0.1111 - stability_accuracy: 0.9494 - product_accuracy: 0.9120 - val_loss: 0.4389 - val_P_final_loss: 0.0382 - val_q_final_loss: 0.1046 - val_stability_loss: 0.0866 - val_product_loss: 0.2155 - val_P_final_mae: 0.1209 - val_q_final_mae: 0.1207 - val_stability_accuracy: 0.9701 - val_product_accuracy: 0.9211
Epoch 38/100
 - 0s - loss: 0.6102 - P_final_loss: 0.0363 - q_final_loss: 0.2194 - stability_loss: 0.1047 - product_loss: 0.2241 - P_final_mae: 0.1234 - q_final_mae: 0.1131 - stability_accuracy: 0.9504 - product_accuracy: 0.9142 - val_loss: 0.4444 - val_P_final_loss: 0.0379 - val_q_final_loss: 0.1080 - val_stability_loss: 0.0884 - val_product_loss: 0.2161 - val_P_final_mae: 0.1233 - val_q_final_mae: 0.1191 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9211

Epoch 00038: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 39/100
 - 0s - loss: 0.6020 - P_final_loss: 0.0370 - q_final_loss: 0.2138 - stability_loss: 0.1107 - product_loss: 0.2364 - P_final_mae: 0.1253 - q_final_mae: 0.1114 - stability_accuracy: 0.9499 - product_accuracy: 0.9120 - val_loss: 0.4409 - val_P_final_loss: 0.0376 - val_q_final_loss: 0.1055 - val_stability_loss: 0.0881 - val_product_loss: 0.2156 - val_P_final_mae: 0.1203 - val_q_final_mae: 0.1187 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 40/100
 - 0s - loss: 0.6009 - P_final_loss: 0.0350 - q_final_loss: 0.2136 - stability_loss: 0.1064 - product_loss: 0.2326 - P_final_mae: 0.1224 - q_final_mae: 0.1113 - stability_accuracy: 0.9504 - product_accuracy: 0.9115 - val_loss: 0.4402 - val_P_final_loss: 0.0377 - val_q_final_loss: 0.1044 - val_stability_loss: 0.0884 - val_product_loss: 0.2155 - val_P_final_mae: 0.1200 - val_q_final_mae: 0.1179 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9275
Epoch 41/100
 - 0s - loss: 0.6002 - P_final_loss: 0.0362 - q_final_loss: 0.2132 - stability_loss: 0.1052 - product_loss: 0.2273 - P_final_mae: 0.1219 - q_final_mae: 0.1105 - stability_accuracy: 0.9504 - product_accuracy: 0.9120 - val_loss: 0.4382 - val_P_final_loss: 0.0375 - val_q_final_loss: 0.1037 - val_stability_loss: 0.0880 - val_product_loss: 0.2148 - val_P_final_mae: 0.1194 - val_q_final_mae: 0.1180 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 42/100
 - 0s - loss: 0.5996 - P_final_loss: 0.0342 - q_final_loss: 0.2136 - stability_loss: 0.1099 - product_loss: 0.2287 - P_final_mae: 0.1234 - q_final_mae: 0.1104 - stability_accuracy: 0.9494 - product_accuracy: 0.9131 - val_loss: 0.4367 - val_P_final_loss: 0.0372 - val_q_final_loss: 0.1029 - val_stability_loss: 0.0879 - val_product_loss: 0.2145 - val_P_final_mae: 0.1199 - val_q_final_mae: 0.1175 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9275
Epoch 43/100
 - 0s - loss: 0.5994 - P_final_loss: 0.0347 - q_final_loss: 0.2137 - stability_loss: 0.1055 - product_loss: 0.2254 - P_final_mae: 0.1221 - q_final_mae: 0.1107 - stability_accuracy: 0.9499 - product_accuracy: 0.9136 - val_loss: 0.4375 - val_P_final_loss: 0.0374 - val_q_final_loss: 0.1055 - val_stability_loss: 0.0871 - val_product_loss: 0.2135 - val_P_final_mae: 0.1194 - val_q_final_mae: 0.1179 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254

Epoch 00043: ReduceLROnPlateau reducing learning rate to 8.000000525498762e-06.
Epoch 44/100
 - 0s - loss: 0.5981 - P_final_loss: 0.0360 - q_final_loss: 0.2127 - stability_loss: 0.1025 - product_loss: 0.2239 - P_final_mae: 0.1216 - q_final_mae: 0.1107 - stability_accuracy: 0.9504 - product_accuracy: 0.9158 - val_loss: 0.4367 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1050 - val_stability_loss: 0.0870 - val_product_loss: 0.2133 - val_P_final_mae: 0.1192 - val_q_final_mae: 0.1178 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 45/100
 - 0s - loss: 0.5980 - P_final_loss: 0.0359 - q_final_loss: 0.2130 - stability_loss: 0.1138 - product_loss: 0.2335 - P_final_mae: 0.1216 - q_final_mae: 0.1107 - stability_accuracy: 0.9504 - product_accuracy: 0.9158 - val_loss: 0.4364 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1047 - val_stability_loss: 0.0870 - val_product_loss: 0.2133 - val_P_final_mae: 0.1192 - val_q_final_mae: 0.1178 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9275
Epoch 46/100
 - 0s - loss: 0.5980 - P_final_loss: 0.0360 - q_final_loss: 0.2380 - stability_loss: 0.1033 - product_loss: 0.2208 - P_final_mae: 0.1215 - q_final_mae: 0.1106 - stability_accuracy: 0.9499 - product_accuracy: 0.9152 - val_loss: 0.4366 - val_P_final_loss: 0.0374 - val_q_final_loss: 0.1050 - val_stability_loss: 0.0869 - val_product_loss: 0.2133 - val_P_final_mae: 0.1191 - val_q_final_mae: 0.1179 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 47/100
 - 0s - loss: 0.5980 - P_final_loss: 0.0344 - q_final_loss: 0.2127 - stability_loss: 0.1104 - product_loss: 0.2419 - P_final_mae: 0.1214 - q_final_mae: 0.1106 - stability_accuracy: 0.9510 - product_accuracy: 0.9158 - val_loss: 0.4348 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1035 - val_stability_loss: 0.0868 - val_product_loss: 0.2131 - val_P_final_mae: 0.1190 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 48/100
 - 0s - loss: 0.5977 - P_final_loss: 0.0340 - q_final_loss: 0.2131 - stability_loss: 0.1093 - product_loss: 0.2295 - P_final_mae: 0.1213 - q_final_mae: 0.1106 - stability_accuracy: 0.9510 - product_accuracy: 0.9147 - val_loss: 0.4346 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1034 - val_stability_loss: 0.0868 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232

Epoch 00048: ReduceLROnPlateau reducing learning rate to 1.6000001778593287e-06.
Epoch 49/100
 - 0s - loss: 0.5976 - P_final_loss: 0.0353 - q_final_loss: 0.2126 - stability_loss: 0.1081 - product_loss: 0.2291 - P_final_mae: 0.1213 - q_final_mae: 0.1105 - stability_accuracy: 0.9504 - product_accuracy: 0.9147 - val_loss: 0.4345 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1034 - val_stability_loss: 0.0867 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 50/100
 - 0s - loss: 0.5976 - P_final_loss: 0.0340 - q_final_loss: 0.2126 - stability_loss: 0.1039 - product_loss: 0.2307 - P_final_mae: 0.1213 - q_final_mae: 0.1105 - stability_accuracy: 0.9504 - product_accuracy: 0.9147 - val_loss: 0.4347 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1035 - val_stability_loss: 0.0867 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 51/100
 - 0s - loss: 0.5975 - P_final_loss: 0.0346 - q_final_loss: 0.2125 - stability_loss: 0.1043 - product_loss: 0.2282 - P_final_mae: 0.1213 - q_final_mae: 0.1105 - stability_accuracy: 0.9510 - product_accuracy: 0.9147 - val_loss: 0.4347 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1035 - val_stability_loss: 0.0868 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 52/100
 - 0s - loss: 0.5975 - P_final_loss: 0.0342 - q_final_loss: 0.2128 - stability_loss: 0.1087 - product_loss: 0.2297 - P_final_mae: 0.1214 - q_final_mae: 0.1105 - stability_accuracy: 0.9510 - product_accuracy: 0.9147 - val_loss: 0.4346 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1035 - val_stability_loss: 0.0868 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 53/100
 - 0s - loss: 0.5975 - P_final_loss: 0.0369 - q_final_loss: 0.2133 - stability_loss: 0.1100 - product_loss: 0.2328 - P_final_mae: 0.1214 - q_final_mae: 0.1105 - stability_accuracy: 0.9510 - product_accuracy: 0.9152 - val_loss: 0.4346 - val_P_final_loss: 0.0372 - val_q_final_loss: 0.1034 - val_stability_loss: 0.0868 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232

Epoch 00053: ReduceLROnPlateau reducing learning rate to 3.200000264769187e-07.
Epoch 00053: early stopping
Training results
target           mean  training score   test score
--------------------------------------------------
P_final     :  444.678       60.521       59.283
q_final     :    0.661        0.067        0.071
stability   :        -        95.1%        96.8%
product     :        -        91.5%        92.3%

We can now have a look at how the model trained and at the final accuracy and mean absolute errors for each target:

[14]:
pl.figure(figsize=(15,7))
predictor.print_score()
predictor.plot_training_history()
Training results
target           mean  training score   test score
--------------------------------------------------
P_final     :  444.678       60.521       59.283
q_final     :    0.661        0.067        0.071
stability   :        -        95.1%        96.8%
product     :        -        91.5%        92.3%
../_images/tutorials_period_mass_ratio_models_23_1.png

For the classifiers we can also plot a confusion matrix:

[15]:
pl.figure(figsize=(16, 8))
predictor.plot_confusion_matrix()
../_images/tutorials_period_mass_ratio_models_25_0.png

Making predictions

Download four samples with different metallicity distributions:

[16]:
sample_m1 = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/Sample_FeH_m1.csv')
sample_solar = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/Sample_FeH_solar.csv')
sample_uniform = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/Sample_FeH_uniform.csv')
sample_besancon = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/Sample_FeH_besancon.csv')
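
Before predicting, we can compare the metallicity distributions of the four samples. This quick check is not part of the original tutorial and assumes the sample files use the same FeH_init column name as the training set:

[ ]:
# overplot the [Fe/H] distributions of the four input samples
for name, sample in [('[Fe/H] = -1', sample_m1), ('[Fe/H] = 0', sample_solar),
                     ('uniform', sample_uniform), ('Galactic', sample_besancon)]:
    pl.hist(sample['FeH_init'], bins=30, histtype='step', density=True, label=name)
pl.xlabel('[Fe/H]')
pl.legend()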
[17]:
observations = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/observations_p_q.csv')
[18]:
pred_m1 = predictor.predict(sample_m1)
pred_solar = predictor.predict(sample_solar)
pred_uniform = predictor.predict(sample_uniform)
pred_besancon = predictor.predict(sample_besancon)
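
As a quick look at the predictions (not part of the original tutorial), we can print the fraction of systems predicted to end up as an sdB in each sample. The output of predict can be indexed by target name, as the plotting code below also does:

[ ]:
# fraction of predicted sdB systems per sample
for name, pred in [('[Fe/H] = -1', pred_m1), ('[Fe/H] = 0', pred_solar),
                   ('uniform', pred_uniform), ('Galactic', pred_besancon)]:
    print('{:12s} {:.1f}% sdB'.format(name, 100 * (pred['product'] == 'sdB').mean()))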
[19]:
def plot_set(pred, ax):
    """Plot the predicted sdB systems as a density in the P-q plane and overlay the observations."""
    ax.set_facecolor('w')

    # select only the systems predicted to become an sdB
    s = pred[pred['product'] == 'sdB']
    ax.hexbin(s['P_final'], s['q_final'], gridsize=100, cmap='Greys', vmax=5)

    # overplot the observed sdB binaries with their error bars
    ax.errorbar(observations['P'], observations['q'], xerr=observations['P_e'], yerr=observations['q_e'],
                marker='+', color='b', ls='', ms=10, mew=3, label='Obs')

    ax.set_xlim([200, 1750])
    ax.set_ylim([0.3, 0.9])


pl.figure(1, figsize=(12, 10))
pl.subplots_adjust(left=0.07, right=0.98, bottom=0.07, top=0.95)

ax = pl.subplot(221)
plot_set(pred_solar, ax)
pl.title('[Fe/H] = 0.0')

ax = pl.subplot(222)
plot_set(pred_m1, ax)
pl.title('[Fe/H] = -1.0')
pl.plot([0], [0], 's', color='gray', label='Model')
pl.legend()

ax = pl.subplot(223)
plot_set(pred_uniform, ax)
pl.title('[Fe/H] = Uniform')

ax = pl.subplot(224)
plot_set(pred_besancon, ax)
pl.title('[Fe/H] = Galactic')

_ = pl.figtext(0.53, 0.01, 'Period (days)', ha='center', size=18)
_ = pl.figtext(0.01, 0.54, 'Mass ratio (M1 / M2)', va='center', size=18, rotation='vertical')
../_images/tutorials_period_mass_ratio_models_31_0.png
[ ]: