Predicting period and mass ratio of hot subdwarfs
The aim of this tutorial is to show how NNaPS can be used to predict the orbital parameters of hot subdwarf binaries based on a relatively small sample of full MESA models.
We will use a sample of ~2300 binary MESA models of low-mass binaries that interact on the red giant branch. We will train a fully connected neural network to predict whether the interaction phase is stable, the orbital parameters (period and mass ratio) after the mass-loss phase, and the final product (hot subdwarf, horizontal branch star or He white dwarf).
Then we will check how different metallicity distributions in the initial population affect those final orbital parameters.
[1]:
import pandas as pd
import pylab as pl
import seaborn as sns
import yaml
from nnaps import predictors
Using TensorFlow backend.
Obtaining the data
Let’s download the extracted MESA models:
[2]:
data = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/period_mass_ratio_mesa_models_extracted.csv')
[3]:
data
[3]:
 | M1_init | M2_init | q_init | P_init | FeH_init | stability | product | P_final | q_final |
---|---|---|---|---|---|---|---|---|---|
0 | 1.791000 | 1.138 | 1.573813 | 3.570104 | -0.109671 | stable | He-WD | 52.924494 | 0.265416 |
1 | 0.954000 | 0.355 | 2.687324 | 410.120037 | -0.780000 | unstable | UK | 0.618103 | 1.245376 |
2 | 1.297000 | 0.865 | 1.499421 | 371.590147 | -0.218502 | stable | sdB | 1130.865820 | 0.541258 |
3 | 1.554999 | 0.579 | 2.685664 | 61.920032 | -0.282191 | unstable | UK | 0.045363 | 0.560722 |
4 | 1.291999 | 0.378 | 3.417988 | 23.240011 | -0.493334 | unstable | UK | 0.005948 | 0.740944 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
2340 | 0.922000 | 0.657 | 1.403348 | 139.840010 | -0.780000 | stable | He-WD | 326.953825 | 0.623666 |
2341 | 1.496000 | 1.323 | 1.130763 | 138.060030 | -0.840593 | stable | HB | 753.134322 | 0.402106 |
2342 | 1.154000 | 0.410 | 2.814634 | 311.360022 | 0.182481 | unstable | UK | 0.187493 | 0.932706 |
2343 | 0.771000 | 0.714 | 1.079832 | 231.110008 | -1.292690 | stable | He-WD | 438.852463 | 0.658157 |
2344 | 0.895000 | 0.653 | 1.370597 | 352.640024 | -0.780000 | stable | sdB | 660.178499 | 0.708158 |
2345 rows × 9 columns
The features that we are interested in are the initial parameters of the models:
- M1_init: initial mass of the donor star
- q_init: initial mass ratio
- P_init: initial orbital period (days)
- FeH_init: initial metallicity ([Fe/H])

The properties that we want to predict are:
- stability: whether the mass-loss phase is stable
- product: the type of star left after the mass-loss phase (sdB, HB or He-WD; UK for unstable systems)
- P_final: orbital period after the interaction phase (days)
- q_final: mass ratio after the interaction phase
[5]:
features = ['M1_init', 'q_init', 'P_init', 'FeH_init']
classifiers = ['stability', 'product']
regressors = ['P_final', 'q_final']
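The table also contains M2_init, which is not in the feature list. For the rows shown above, q_init equals M1_init / M2_init to the printed precision, so M2_init adds no information once M1_init and q_init are known. Here is a quick check on the first five rows of the table. On the full `data` frame, the same comparison can be run on every row:

```python
import numpy as np
import pandas as pd

# First five rows copied from the table above; only the mass columns are needed
rows = pd.DataFrame({
    'M1_init': [1.791000, 0.954000, 1.297000, 1.554999, 1.291999],
    'M2_init': [1.138, 0.355, 0.865, 0.579, 0.378],
    'q_init':  [1.573813, 2.687324, 1.499421, 2.685664, 3.417988],
})

# Mass ratio as donor mass over companion mass
q_from_masses = rows['M1_init'] / rows['M2_init']

# Agreement within the rounding of the printed values
consistent = np.allclose(q_from_masses, rows['q_init'], atol=1e-5)
print(consistent)
```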
Data analysis
Let’s start by looking at the data we are working with.
[ ]:
# keep only the models with stable mass transfer; they are used in the plots below
stable = data[data['stability'] == 'stable']
[6]:
data['stability'].value_counts()
[6]:
stable 1586
unstable 759
Name: stability, dtype: int64
[7]:
data['product'].value_counts()
[7]:
He-WD 853
UK 759
HB 421
sdB 312
Name: product, dtype: int64
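Both classification targets are imbalanced. A trivial model that always predicts the most frequent class gives a useful reference point for the classifier accuracies. This small sketch computes that baseline from the counts above:

```python
import pandas as pd

# Class counts copied from the value_counts() outputs above
stability_counts = pd.Series({'stable': 1586, 'unstable': 759})
product_counts = pd.Series({'He-WD': 853, 'UK': 759, 'HB': 421, 'sdB': 312})

def majority_baseline(counts):
    """Accuracy of a classifier that always predicts the most common class."""
    return counts.max() / counts.sum()

for name, counts in [('stability', stability_counts), ('product', product_counts)]:
    print(f"{name}: {majority_baseline(counts):.1%} (always '{counts.idxmax()}')")
```

This gives about 67.6% for stability (always 'stable') and 36.4% for product (always 'He-WD'). Compare these numbers with the accuracies reported after training.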
[8]:
sns.pairplot(stable, vars=features, hue='product')
[9]:
sns.scatterplot(data=stable, x='P_init', y='P_final', hue='FeH_init')
[10]:
sns.scatterplot(data=stable, x='q_init', y='q_final', hue='FeH_init')
Modeling
Now we can build a predictive model: a shallow, fully connected neural network with three Dense layers of 200, 100 and 50 neurons. The features are scaled automatically. The two regression targets (P_final and q_final), however, differ greatly in range, so we also apply a scaler to each of them. Otherwise one target would dominate the loss function.
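Here is why the targets need scaling. P_final is of order hundreds of days, while q_final is of order unity. Without scaling, errors on P_final would swamp those on q_final. This minimal sketch assumes that the `StandardScaler` named in the setup behaves like scikit-learn's `sklearn.preprocessing.StandardScaler`, which gives each column zero mean and unit variance and can be inverted. The values are the stable models from the table above:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# (P_final, q_final) of the stable models shown in the table above
targets = np.array([
    [52.924494, 0.265416],
    [1130.865820, 0.541258],
    [326.953825, 0.623666],
    [753.134322, 0.402106],
    [438.852463, 0.658157],
    [660.178499, 0.708158],
])

scaler = StandardScaler()
scaled = scaler.fit_transform(targets)   # each column now has mean 0 and std 1

print(targets.std(axis=0))   # very different spreads (days vs. dimensionless)
print(scaled.std(axis=0))    # both 1, so both targets enter the loss on a comparable scale

# Predictions made in scaled units are mapped back to physical units
restored = scaler.inverse_transform(scaled)
```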
[11]:
setup = """
features:
- M1_init
- q_init
- P_init
- FeH_init
regressors:
P_final:
processor: StandardScaler
q_final:
processor: StandardScaler
classifiers:
- stability
- product
model:
- {'layer':'Dense', 'args':[200], 'kwargs': {'activation':'relu', 'name':'FC_1'} }
- {'layer':'Dense', 'args':[100], 'kwargs': {'activation':'relu', 'name':'FC_2'} }
- {'layer':'Dense', 'args':[50], 'kwargs': {'activation':'relu', 'name':'FC_3'} }
optimizer: adam
"""
setup = yaml.safe_load(setup)
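`yaml.safe_load` turns the setup string into plain Python dictionaries and lists. This sketch inspects that structure and counts the weights and biases of the three Dense layers. It assumes, without confirmation from this tutorial, that the layers form one shared stack fed by the four features. The output heads for the four targets are not counted, because the tutorial does not show how NNaPS attaches them:

```python
import yaml

setup_text = """
features:
  - M1_init
  - q_init
  - P_init
  - FeH_init
regressors:
  P_final:
    processor: StandardScaler
  q_final:
    processor: StandardScaler
classifiers:
  - stability
  - product
model:
  - {'layer':'Dense', 'args':[200], 'kwargs': {'activation':'relu', 'name':'FC_1'} }
  - {'layer':'Dense', 'args':[100], 'kwargs': {'activation':'relu', 'name':'FC_2'} }
  - {'layer':'Dense', 'args':[50], 'kwargs': {'activation':'relu', 'name':'FC_3'} }
optimizer: adam
"""
parsed = yaml.safe_load(setup_text)

def dense_param_count(n_inputs, layers):
    """Weights + biases of a stack of Dense layers (n_in * n_out + n_out each)."""
    total = 0
    for layer in layers:
        n_out = layer['args'][0]          # number of units in this layer
        total += n_inputs * n_out + n_out
        n_inputs = n_out                  # this layer's output feeds the next one
    return total

n_features = len(parsed['features'])
print(parsed['regressors'])
print(dense_param_count(n_features, parsed['model']))  # 26150
```

Under that assumption, the shared layers hold 26150 trainable parameters.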
[12]:
predictor = predictors.FCPredictor(setup=setup, data=data)
[13]:
predictor.fit(epochs=100, batch_size=124, reduce_lr=True)
Train on 1876 samples, validate on 469 samples
Epoch 1/100
- 1s - loss: 3.3076 - P_final_loss: 0.5563 - q_final_loss: 0.7013 - stability_loss: 0.6195 - product_loss: 1.3635 - P_final_mae: 0.6184 - q_final_mae: 0.3683 - stability_accuracy: 0.7777 - product_accuracy: 0.3241 - val_loss: 2.2777 - val_P_final_loss: 0.3099 - val_q_final_loss: 0.1414 - val_stability_loss: 0.5188 - val_product_loss: 1.3055 - val_P_final_mae: 0.4181 - val_q_final_mae: 0.2271 - val_stability_accuracy: 0.8380 - val_product_accuracy: 0.3198
Epoch 2/100
- 0s - loss: 2.3581 - P_final_loss: 0.2519 - q_final_loss: 0.4045 - stability_loss: 0.4448 - product_loss: 1.2132 - P_final_mae: 0.3392 - q_final_mae: 0.2182 - stability_accuracy: 0.8417 - product_accuracy: 0.4488 - val_loss: 1.7413 - val_P_final_loss: 0.2005 - val_q_final_loss: 0.0971 - val_stability_loss: 0.3377 - val_product_loss: 1.1101 - val_P_final_mae: 0.3216 - val_q_final_mae: 0.1601 - val_stability_accuracy: 0.8721 - val_product_accuracy: 0.6141
Epoch 3/100
- 0s - loss: 1.8151 - P_final_loss: 0.1397 - q_final_loss: 0.3311 - stability_loss: 0.3114 - product_loss: 0.9900 - P_final_mae: 0.2485 - q_final_mae: 0.1622 - stability_accuracy: 0.8849 - product_accuracy: 0.6588 - val_loss: 1.2836 - val_P_final_loss: 0.1059 - val_q_final_loss: 0.0753 - val_stability_loss: 0.2378 - val_product_loss: 0.8689 - val_P_final_mae: 0.2266 - val_q_final_mae: 0.1394 - val_stability_accuracy: 0.9275 - val_product_accuracy: 0.7207
Epoch 4/100
- 0s - loss: 1.4653 - P_final_loss: 0.1084 - q_final_loss: 0.3225 - stability_loss: 0.2348 - product_loss: 0.7728 - P_final_mae: 0.2179 - q_final_mae: 0.1514 - stability_accuracy: 0.9184 - product_accuracy: 0.7303 - val_loss: 0.9856 - val_P_final_loss: 0.0873 - val_q_final_loss: 0.0602 - val_stability_loss: 0.1759 - val_product_loss: 0.6655 - val_P_final_mae: 0.2056 - val_q_final_mae: 0.1291 - val_stability_accuracy: 0.9403 - val_product_accuracy: 0.7612
Epoch 5/100
- 0s - loss: 1.2421 - P_final_loss: 0.0994 - q_final_loss: 0.3190 - stability_loss: 0.1946 - product_loss: 0.6452 - P_final_mae: 0.2057 - q_final_mae: 0.1311 - stability_accuracy: 0.9254 - product_accuracy: 0.7862 - val_loss: 0.8992 - val_P_final_loss: 0.0743 - val_q_final_loss: 0.1368 - val_stability_loss: 0.1508 - val_product_loss: 0.5419 - val_P_final_mae: 0.1865 - val_q_final_mae: 0.1790 - val_stability_accuracy: 0.9403 - val_product_accuracy: 0.8145
Epoch 6/100
- 0s - loss: 1.1037 - P_final_loss: 0.0810 - q_final_loss: 0.3026 - stability_loss: 0.1618 - product_loss: 0.5323 - P_final_mae: 0.1980 - q_final_mae: 0.1636 - stability_accuracy: 0.9302 - product_accuracy: 0.8124 - val_loss: 0.7422 - val_P_final_loss: 0.0668 - val_q_final_loss: 0.0983 - val_stability_loss: 0.1185 - val_product_loss: 0.4632 - val_P_final_mae: 0.1793 - val_q_final_mae: 0.1445 - val_stability_accuracy: 0.9552 - val_product_accuracy: 0.8209
Epoch 7/100
- 0s - loss: 0.9975 - P_final_loss: 0.0707 - q_final_loss: 0.2901 - stability_loss: 0.1473 - product_loss: 0.4778 - P_final_mae: 0.1839 - q_final_mae: 0.1594 - stability_accuracy: 0.9302 - product_accuracy: 0.8358 - val_loss: 0.6717 - val_P_final_loss: 0.0628 - val_q_final_loss: 0.0918 - val_stability_loss: 0.1064 - val_product_loss: 0.4159 - val_P_final_mae: 0.1680 - val_q_final_mae: 0.1558 - val_stability_accuracy: 0.9574 - val_product_accuracy: 0.8294
Epoch 8/100
- 0s - loss: 0.9291 - P_final_loss: 0.0652 - q_final_loss: 0.2819 - stability_loss: 0.1368 - product_loss: 0.4310 - P_final_mae: 0.1727 - q_final_mae: 0.1525 - stability_accuracy: 0.9344 - product_accuracy: 0.8417 - val_loss: 0.6311 - val_P_final_loss: 0.0612 - val_q_final_loss: 0.0867 - val_stability_loss: 0.1064 - val_product_loss: 0.3820 - val_P_final_mae: 0.1736 - val_q_final_mae: 0.1252 - val_stability_accuracy: 0.9552 - val_product_accuracy: 0.8465
Epoch 9/100
- 0s - loss: 0.8920 - P_final_loss: 0.0605 - q_final_loss: 0.2840 - stability_loss: 0.1290 - product_loss: 0.3871 - P_final_mae: 0.1666 - q_final_mae: 0.1426 - stability_accuracy: 0.9350 - product_accuracy: 0.8555 - val_loss: 0.6306 - val_P_final_loss: 0.0567 - val_q_final_loss: 0.1263 - val_stability_loss: 0.0974 - val_product_loss: 0.3570 - val_P_final_mae: 0.1670 - val_q_final_mae: 0.1638 - val_stability_accuracy: 0.9552 - val_product_accuracy: 0.8507
Epoch 10/100
- 0s - loss: 0.8439 - P_final_loss: 0.0602 - q_final_loss: 0.2721 - stability_loss: 0.1371 - product_loss: 0.3827 - P_final_mae: 0.1690 - q_final_mae: 0.1385 - stability_accuracy: 0.9382 - product_accuracy: 0.8561 - val_loss: 0.5496 - val_P_final_loss: 0.0546 - val_q_final_loss: 0.0721 - val_stability_loss: 0.0950 - val_product_loss: 0.3326 - val_P_final_mae: 0.1544 - val_q_final_mae: 0.1383 - val_stability_accuracy: 0.9638 - val_product_accuracy: 0.8635
Epoch 11/100
- 0s - loss: 0.8921 - P_final_loss: 0.0612 - q_final_loss: 0.3935 - stability_loss: 0.1295 - product_loss: 0.3574 - P_final_mae: 0.1701 - q_final_mae: 0.1654 - stability_accuracy: 0.9387 - product_accuracy: 0.8577 - val_loss: 0.5804 - val_P_final_loss: 0.0490 - val_q_final_loss: 0.1264 - val_stability_loss: 0.0943 - val_product_loss: 0.3174 - val_P_final_mae: 0.1495 - val_q_final_mae: 0.1714 - val_stability_accuracy: 0.9552 - val_product_accuracy: 0.8785
Epoch 12/100
- 0s - loss: 0.8487 - P_final_loss: 0.0586 - q_final_loss: 0.3174 - stability_loss: 0.1221 - product_loss: 0.3325 - P_final_mae: 0.1562 - q_final_mae: 0.1660 - stability_accuracy: 0.9419 - product_accuracy: 0.8614 - val_loss: 0.5460 - val_P_final_loss: 0.0508 - val_q_final_loss: 0.0680 - val_stability_loss: 0.1125 - val_product_loss: 0.3183 - val_P_final_mae: 0.1565 - val_q_final_mae: 0.1182 - val_stability_accuracy: 0.9467 - val_product_accuracy: 0.8721
Epoch 13/100
- 0s - loss: 0.8180 - P_final_loss: 0.0555 - q_final_loss: 0.2856 - stability_loss: 0.1255 - product_loss: 0.3222 - P_final_mae: 0.1648 - q_final_mae: 0.1299 - stability_accuracy: 0.9392 - product_accuracy: 0.8721 - val_loss: 0.5329 - val_P_final_loss: 0.0491 - val_q_final_loss: 0.1100 - val_stability_loss: 0.0889 - val_product_loss: 0.2902 - val_P_final_mae: 0.1511 - val_q_final_mae: 0.1808 - val_stability_accuracy: 0.9638 - val_product_accuracy: 0.8955
Epoch 14/100
- 0s - loss: 0.7635 - P_final_loss: 0.0490 - q_final_loss: 0.2639 - stability_loss: 0.1188 - product_loss: 0.3074 - P_final_mae: 0.1515 - q_final_mae: 0.1644 - stability_accuracy: 0.9440 - product_accuracy: 0.8769 - val_loss: 0.5311 - val_P_final_loss: 0.0473 - val_q_final_loss: 0.1135 - val_stability_loss: 0.0914 - val_product_loss: 0.2855 - val_P_final_mae: 0.1433 - val_q_final_mae: 0.1476 - val_stability_accuracy: 0.9616 - val_product_accuracy: 0.8870
Epoch 15/100
- 0s - loss: 0.7786 - P_final_loss: 0.0478 - q_final_loss: 0.2968 - stability_loss: 0.1276 - product_loss: 0.3036 - P_final_mae: 0.1492 - q_final_mae: 0.1559 - stability_accuracy: 0.9440 - product_accuracy: 0.8817 - val_loss: 0.4969 - val_P_final_loss: 0.0488 - val_q_final_loss: 0.0882 - val_stability_loss: 0.0880 - val_product_loss: 0.2759 - val_P_final_mae: 0.1546 - val_q_final_mae: 0.1737 - val_stability_accuracy: 0.9595 - val_product_accuracy: 0.8913
Epoch 16/100
- 0s - loss: 0.7438 - P_final_loss: 0.0523 - q_final_loss: 0.2630 - stability_loss: 0.1230 - product_loss: 0.2981 - P_final_mae: 0.1481 - q_final_mae: 0.1420 - stability_accuracy: 0.9430 - product_accuracy: 0.8913 - val_loss: 0.4654 - val_P_final_loss: 0.0469 - val_q_final_loss: 0.0682 - val_stability_loss: 0.0893 - val_product_loss: 0.2656 - val_P_final_mae: 0.1490 - val_q_final_mae: 0.1190 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9041
Epoch 17/100
- 0s - loss: 0.7244 - P_final_loss: 0.0476 - q_final_loss: 0.9122 - stability_loss: 0.1176 - product_loss: 0.2901 - P_final_mae: 0.1460 - q_final_mae: 0.1270 - stability_accuracy: 0.9472 - product_accuracy: 0.8923 - val_loss: 0.5367 - val_P_final_loss: 0.0444 - val_q_final_loss: 0.1332 - val_stability_loss: 0.0998 - val_product_loss: 0.2656 - val_P_final_mae: 0.1381 - val_q_final_mae: 0.1568 - val_stability_accuracy: 0.9574 - val_product_accuracy: 0.9019
Epoch 18/100
- 0s - loss: 0.9268 - P_final_loss: 0.0532 - q_final_loss: 0.4544 - stability_loss: 0.1233 - product_loss: 0.2855 - P_final_mae: 0.1537 - q_final_mae: 0.2999 - stability_accuracy: 0.9451 - product_accuracy: 0.8875 - val_loss: 0.7742 - val_P_final_loss: 0.0510 - val_q_final_loss: 0.3940 - val_stability_loss: 0.0872 - val_product_loss: 0.2553 - val_P_final_mae: 0.1577 - val_q_final_mae: 0.2756 - val_stability_accuracy: 0.9616 - val_product_accuracy: 0.9041
Epoch 19/100
- 0s - loss: 0.7957 - P_final_loss: 0.0532 - q_final_loss: 0.3301 - stability_loss: 0.1156 - product_loss: 0.2775 - P_final_mae: 0.1597 - q_final_mae: 0.2370 - stability_accuracy: 0.9435 - product_accuracy: 0.8993 - val_loss: 0.4659 - val_P_final_loss: 0.0459 - val_q_final_loss: 0.0833 - val_stability_loss: 0.0910 - val_product_loss: 0.2492 - val_P_final_mae: 0.1502 - val_q_final_mae: 0.1570 - val_stability_accuracy: 0.9701 - val_product_accuracy: 0.9126
Epoch 20/100
- 0s - loss: 0.6928 - P_final_loss: 0.0476 - q_final_loss: 0.2491 - stability_loss: 0.1146 - product_loss: 0.2650 - P_final_mae: 0.1527 - q_final_mae: 0.1499 - stability_accuracy: 0.9483 - product_accuracy: 0.9019 - val_loss: 0.4797 - val_P_final_loss: 0.0477 - val_q_final_loss: 0.1085 - val_stability_loss: 0.0889 - val_product_loss: 0.2406 - val_P_final_mae: 0.1407 - val_q_final_mae: 0.1524 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9147
Epoch 21/100
- 0s - loss: 0.6860 - P_final_loss: 0.0466 - q_final_loss: 0.2513 - stability_loss: 0.1196 - product_loss: 0.2611 - P_final_mae: 0.1434 - q_final_mae: 0.1381 - stability_accuracy: 0.9494 - product_accuracy: 0.9057 - val_loss: 0.4891 - val_P_final_loss: 0.0455 - val_q_final_loss: 0.1236 - val_stability_loss: 0.0882 - val_product_loss: 0.2381 - val_P_final_mae: 0.1402 - val_q_final_mae: 0.1399 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9126
Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 22/100
- 0s - loss: 0.6542 - P_final_loss: 0.0440 - q_final_loss: 0.2323 - stability_loss: 0.1193 - product_loss: 0.2566 - P_final_mae: 0.1370 - q_final_mae: 0.1261 - stability_accuracy: 0.9456 - product_accuracy: 0.9035 - val_loss: 0.4611 - val_P_final_loss: 0.0399 - val_q_final_loss: 0.1063 - val_stability_loss: 0.0874 - val_product_loss: 0.2333 - val_P_final_mae: 0.1273 - val_q_final_mae: 0.1336 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9168
Epoch 23/100
- 0s - loss: 0.6471 - P_final_loss: 0.0403 - q_final_loss: 0.2308 - stability_loss: 0.1062 - product_loss: 0.2406 - P_final_mae: 0.1360 - q_final_mae: 0.1222 - stability_accuracy: 0.9488 - product_accuracy: 0.9035 - val_loss: 0.4541 - val_P_final_loss: 0.0392 - val_q_final_loss: 0.1077 - val_stability_loss: 0.0847 - val_product_loss: 0.2286 - val_P_final_mae: 0.1268 - val_q_final_mae: 0.1287 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9190
Epoch 24/100
- 0s - loss: 0.6425 - P_final_loss: 0.0443 - q_final_loss: 0.2304 - stability_loss: 0.1147 - product_loss: 0.2519 - P_final_mae: 0.1314 - q_final_mae: 0.1175 - stability_accuracy: 0.9488 - product_accuracy: 0.9035 - val_loss: 0.4635 - val_P_final_loss: 0.0408 - val_q_final_loss: 0.1150 - val_stability_loss: 0.0858 - val_product_loss: 0.2283 - val_P_final_mae: 0.1259 - val_q_final_mae: 0.1271 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9211
Epoch 25/100
- 0s - loss: 0.6417 - P_final_loss: 0.0436 - q_final_loss: 0.2397 - stability_loss: 0.1081 - product_loss: 0.2452 - P_final_mae: 0.1352 - q_final_mae: 0.1165 - stability_accuracy: 0.9483 - product_accuracy: 0.9046 - val_loss: 0.4661 - val_P_final_loss: 0.0411 - val_q_final_loss: 0.1115 - val_stability_loss: 0.0889 - val_product_loss: 0.2308 - val_P_final_mae: 0.1355 - val_q_final_mae: 0.1253 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9147
Epoch 26/100
- 0s - loss: 0.6340 - P_final_loss: 0.0449 - q_final_loss: 0.2767 - stability_loss: 0.1278 - product_loss: 0.2579 - P_final_mae: 0.1336 - q_final_mae: 0.1162 - stability_accuracy: 0.9483 - product_accuracy: 0.9078 - val_loss: 0.4431 - val_P_final_loss: 0.0391 - val_q_final_loss: 0.0926 - val_stability_loss: 0.0884 - val_product_loss: 0.2283 - val_P_final_mae: 0.1264 - val_q_final_mae: 0.1197 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9168
Epoch 27/100
- 0s - loss: 0.6367 - P_final_loss: 0.0387 - q_final_loss: 0.2296 - stability_loss: 0.1224 - product_loss: 0.2585 - P_final_mae: 0.1307 - q_final_mae: 0.1132 - stability_accuracy: 0.9478 - product_accuracy: 0.9104 - val_loss: 0.4314 - val_P_final_loss: 0.0392 - val_q_final_loss: 0.0776 - val_stability_loss: 0.0896 - val_product_loss: 0.2293 - val_P_final_mae: 0.1248 - val_q_final_mae: 0.1162 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9168
Epoch 28/100
- 0s - loss: 0.6337 - P_final_loss: 0.0395 - q_final_loss: 0.2286 - stability_loss: 0.1086 - product_loss: 0.2450 - P_final_mae: 0.1286 - q_final_mae: 0.1118 - stability_accuracy: 0.9499 - product_accuracy: 0.9083 - val_loss: 0.4370 - val_P_final_loss: 0.0387 - val_q_final_loss: 0.0901 - val_stability_loss: 0.0879 - val_product_loss: 0.2255 - val_P_final_mae: 0.1236 - val_q_final_mae: 0.1165 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 29/100
- 0s - loss: 0.6258 - P_final_loss: 0.0366 - q_final_loss: 0.2223 - stability_loss: 0.1061 - product_loss: 0.2422 - P_final_mae: 0.1280 - q_final_mae: 0.1114 - stability_accuracy: 0.9499 - product_accuracy: 0.9078 - val_loss: 0.4366 - val_P_final_loss: 0.0382 - val_q_final_loss: 0.0908 - val_stability_loss: 0.0880 - val_product_loss: 0.2246 - val_P_final_mae: 0.1226 - val_q_final_mae: 0.1159 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 30/100
- 0s - loss: 0.6251 - P_final_loss: 0.0411 - q_final_loss: 0.2260 - stability_loss: 0.1064 - product_loss: 0.2360 - P_final_mae: 0.1269 - q_final_mae: 0.1115 - stability_accuracy: 0.9499 - product_accuracy: 0.9067 - val_loss: 0.4365 - val_P_final_loss: 0.0383 - val_q_final_loss: 0.0928 - val_stability_loss: 0.0876 - val_product_loss: 0.2230 - val_P_final_mae: 0.1222 - val_q_final_mae: 0.1169 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9190
Epoch 31/100
- 0s - loss: 0.6210 - P_final_loss: 0.0365 - q_final_loss: 0.2213 - stability_loss: 0.1085 - product_loss: 0.2384 - P_final_mae: 0.1267 - q_final_mae: 0.1109 - stability_accuracy: 0.9504 - product_accuracy: 0.9094 - val_loss: 0.4293 - val_P_final_loss: 0.0381 - val_q_final_loss: 0.0894 - val_stability_loss: 0.0857 - val_product_loss: 0.2213 - val_P_final_mae: 0.1219 - val_q_final_mae: 0.1159 - val_stability_accuracy: 0.9701 - val_product_accuracy: 0.9211
Epoch 32/100
- 0s - loss: 0.6218 - P_final_loss: 0.0387 - q_final_loss: 0.2231 - stability_loss: 0.1069 - product_loss: 0.2394 - P_final_mae: 0.1263 - q_final_mae: 0.1118 - stability_accuracy: 0.9499 - product_accuracy: 0.9078 - val_loss: 0.4392 - val_P_final_loss: 0.0381 - val_q_final_loss: 0.1015 - val_stability_loss: 0.0859 - val_product_loss: 0.2194 - val_P_final_mae: 0.1216 - val_q_final_mae: 0.1195 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 33/100
- 0s - loss: 0.6201 - P_final_loss: 0.0410 - q_final_loss: 0.2231 - stability_loss: 0.1073 - product_loss: 0.2398 - P_final_mae: 0.1266 - q_final_mae: 0.1120 - stability_accuracy: 0.9515 - product_accuracy: 0.9110 - val_loss: 0.4280 - val_P_final_loss: 0.0381 - val_q_final_loss: 0.0877 - val_stability_loss: 0.0869 - val_product_loss: 0.2203 - val_P_final_mae: 0.1216 - val_q_final_mae: 0.1166 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 34/100
- 0s - loss: 0.6131 - P_final_loss: 0.0353 - q_final_loss: 0.2219 - stability_loss: 0.1050 - product_loss: 0.2335 - P_final_mae: 0.1253 - q_final_mae: 0.1104 - stability_accuracy: 0.9499 - product_accuracy: 0.9110 - val_loss: 0.4388 - val_P_final_loss: 0.0376 - val_q_final_loss: 0.1009 - val_stability_loss: 0.0871 - val_product_loss: 0.2191 - val_P_final_mae: 0.1204 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9190
Epoch 35/100
- 0s - loss: 0.6124 - P_final_loss: 0.0353 - q_final_loss: 0.2179 - stability_loss: 0.1066 - product_loss: 0.2300 - P_final_mae: 0.1233 - q_final_mae: 0.1107 - stability_accuracy: 0.9488 - product_accuracy: 0.9136 - val_loss: 0.4365 - val_P_final_loss: 0.0376 - val_q_final_loss: 0.0988 - val_stability_loss: 0.0877 - val_product_loss: 0.2180 - val_P_final_mae: 0.1231 - val_q_final_mae: 0.1171 - val_stability_accuracy: 0.9659 - val_product_accuracy: 0.9275
Epoch 36/100
- 0s - loss: 0.6131 - P_final_loss: 0.0366 - q_final_loss: 0.2187 - stability_loss: 0.1062 - product_loss: 0.2282 - P_final_mae: 0.1285 - q_final_mae: 0.1129 - stability_accuracy: 0.9494 - product_accuracy: 0.9094 - val_loss: 0.4372 - val_P_final_loss: 0.0377 - val_q_final_loss: 0.1044 - val_stability_loss: 0.0858 - val_product_loss: 0.2152 - val_P_final_mae: 0.1205 - val_q_final_mae: 0.1185 - val_stability_accuracy: 0.9701 - val_product_accuracy: 0.9232
Epoch 37/100
- 0s - loss: 0.6091 - P_final_loss: 0.0375 - q_final_loss: 0.2183 - stability_loss: 0.1217 - product_loss: 0.2411 - P_final_mae: 0.1239 - q_final_mae: 0.1111 - stability_accuracy: 0.9494 - product_accuracy: 0.9120 - val_loss: 0.4389 - val_P_final_loss: 0.0382 - val_q_final_loss: 0.1046 - val_stability_loss: 0.0866 - val_product_loss: 0.2155 - val_P_final_mae: 0.1209 - val_q_final_mae: 0.1207 - val_stability_accuracy: 0.9701 - val_product_accuracy: 0.9211
Epoch 38/100
- 0s - loss: 0.6102 - P_final_loss: 0.0363 - q_final_loss: 0.2194 - stability_loss: 0.1047 - product_loss: 0.2241 - P_final_mae: 0.1234 - q_final_mae: 0.1131 - stability_accuracy: 0.9504 - product_accuracy: 0.9142 - val_loss: 0.4444 - val_P_final_loss: 0.0379 - val_q_final_loss: 0.1080 - val_stability_loss: 0.0884 - val_product_loss: 0.2161 - val_P_final_mae: 0.1233 - val_q_final_mae: 0.1191 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9211
Epoch 00038: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 39/100
- 0s - loss: 0.6020 - P_final_loss: 0.0370 - q_final_loss: 0.2138 - stability_loss: 0.1107 - product_loss: 0.2364 - P_final_mae: 0.1253 - q_final_mae: 0.1114 - stability_accuracy: 0.9499 - product_accuracy: 0.9120 - val_loss: 0.4409 - val_P_final_loss: 0.0376 - val_q_final_loss: 0.1055 - val_stability_loss: 0.0881 - val_product_loss: 0.2156 - val_P_final_mae: 0.1203 - val_q_final_mae: 0.1187 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 40/100
- 0s - loss: 0.6009 - P_final_loss: 0.0350 - q_final_loss: 0.2136 - stability_loss: 0.1064 - product_loss: 0.2326 - P_final_mae: 0.1224 - q_final_mae: 0.1113 - stability_accuracy: 0.9504 - product_accuracy: 0.9115 - val_loss: 0.4402 - val_P_final_loss: 0.0377 - val_q_final_loss: 0.1044 - val_stability_loss: 0.0884 - val_product_loss: 0.2155 - val_P_final_mae: 0.1200 - val_q_final_mae: 0.1179 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9275
Epoch 41/100
- 0s - loss: 0.6002 - P_final_loss: 0.0362 - q_final_loss: 0.2132 - stability_loss: 0.1052 - product_loss: 0.2273 - P_final_mae: 0.1219 - q_final_mae: 0.1105 - stability_accuracy: 0.9504 - product_accuracy: 0.9120 - val_loss: 0.4382 - val_P_final_loss: 0.0375 - val_q_final_loss: 0.1037 - val_stability_loss: 0.0880 - val_product_loss: 0.2148 - val_P_final_mae: 0.1194 - val_q_final_mae: 0.1180 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 42/100
- 0s - loss: 0.5996 - P_final_loss: 0.0342 - q_final_loss: 0.2136 - stability_loss: 0.1099 - product_loss: 0.2287 - P_final_mae: 0.1234 - q_final_mae: 0.1104 - stability_accuracy: 0.9494 - product_accuracy: 0.9131 - val_loss: 0.4367 - val_P_final_loss: 0.0372 - val_q_final_loss: 0.1029 - val_stability_loss: 0.0879 - val_product_loss: 0.2145 - val_P_final_mae: 0.1199 - val_q_final_mae: 0.1175 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9275
Epoch 43/100
- 0s - loss: 0.5994 - P_final_loss: 0.0347 - q_final_loss: 0.2137 - stability_loss: 0.1055 - product_loss: 0.2254 - P_final_mae: 0.1221 - q_final_mae: 0.1107 - stability_accuracy: 0.9499 - product_accuracy: 0.9136 - val_loss: 0.4375 - val_P_final_loss: 0.0374 - val_q_final_loss: 0.1055 - val_stability_loss: 0.0871 - val_product_loss: 0.2135 - val_P_final_mae: 0.1194 - val_q_final_mae: 0.1179 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 00043: ReduceLROnPlateau reducing learning rate to 8.000000525498762e-06.
Epoch 44/100
- 0s - loss: 0.5981 - P_final_loss: 0.0360 - q_final_loss: 0.2127 - stability_loss: 0.1025 - product_loss: 0.2239 - P_final_mae: 0.1216 - q_final_mae: 0.1107 - stability_accuracy: 0.9504 - product_accuracy: 0.9158 - val_loss: 0.4367 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1050 - val_stability_loss: 0.0870 - val_product_loss: 0.2133 - val_P_final_mae: 0.1192 - val_q_final_mae: 0.1178 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 45/100
- 0s - loss: 0.5980 - P_final_loss: 0.0359 - q_final_loss: 0.2130 - stability_loss: 0.1138 - product_loss: 0.2335 - P_final_mae: 0.1216 - q_final_mae: 0.1107 - stability_accuracy: 0.9504 - product_accuracy: 0.9158 - val_loss: 0.4364 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1047 - val_stability_loss: 0.0870 - val_product_loss: 0.2133 - val_P_final_mae: 0.1192 - val_q_final_mae: 0.1178 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9275
Epoch 46/100
- 0s - loss: 0.5980 - P_final_loss: 0.0360 - q_final_loss: 0.2380 - stability_loss: 0.1033 - product_loss: 0.2208 - P_final_mae: 0.1215 - q_final_mae: 0.1106 - stability_accuracy: 0.9499 - product_accuracy: 0.9152 - val_loss: 0.4366 - val_P_final_loss: 0.0374 - val_q_final_loss: 0.1050 - val_stability_loss: 0.0869 - val_product_loss: 0.2133 - val_P_final_mae: 0.1191 - val_q_final_mae: 0.1179 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 47/100
- 0s - loss: 0.5980 - P_final_loss: 0.0344 - q_final_loss: 0.2127 - stability_loss: 0.1104 - product_loss: 0.2419 - P_final_mae: 0.1214 - q_final_mae: 0.1106 - stability_accuracy: 0.9510 - product_accuracy: 0.9158 - val_loss: 0.4348 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1035 - val_stability_loss: 0.0868 - val_product_loss: 0.2131 - val_P_final_mae: 0.1190 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9254
Epoch 48/100
- 0s - loss: 0.5977 - P_final_loss: 0.0340 - q_final_loss: 0.2131 - stability_loss: 0.1093 - product_loss: 0.2295 - P_final_mae: 0.1213 - q_final_mae: 0.1106 - stability_accuracy: 0.9510 - product_accuracy: 0.9147 - val_loss: 0.4346 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1034 - val_stability_loss: 0.0868 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 00048: ReduceLROnPlateau reducing learning rate to 1.6000001778593287e-06.
Epoch 49/100
- 0s - loss: 0.5976 - P_final_loss: 0.0353 - q_final_loss: 0.2126 - stability_loss: 0.1081 - product_loss: 0.2291 - P_final_mae: 0.1213 - q_final_mae: 0.1105 - stability_accuracy: 0.9504 - product_accuracy: 0.9147 - val_loss: 0.4345 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1034 - val_stability_loss: 0.0867 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 50/100
- 0s - loss: 0.5976 - P_final_loss: 0.0340 - q_final_loss: 0.2126 - stability_loss: 0.1039 - product_loss: 0.2307 - P_final_mae: 0.1213 - q_final_mae: 0.1105 - stability_accuracy: 0.9504 - product_accuracy: 0.9147 - val_loss: 0.4347 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1035 - val_stability_loss: 0.0867 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 51/100
- 0s - loss: 0.5975 - P_final_loss: 0.0346 - q_final_loss: 0.2125 - stability_loss: 0.1043 - product_loss: 0.2282 - P_final_mae: 0.1213 - q_final_mae: 0.1105 - stability_accuracy: 0.9510 - product_accuracy: 0.9147 - val_loss: 0.4347 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1035 - val_stability_loss: 0.0868 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 52/100
- 0s - loss: 0.5975 - P_final_loss: 0.0342 - q_final_loss: 0.2128 - stability_loss: 0.1087 - product_loss: 0.2297 - P_final_mae: 0.1214 - q_final_mae: 0.1105 - stability_accuracy: 0.9510 - product_accuracy: 0.9147 - val_loss: 0.4346 - val_P_final_loss: 0.0373 - val_q_final_loss: 0.1035 - val_stability_loss: 0.0868 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 53/100
- 0s - loss: 0.5975 - P_final_loss: 0.0369 - q_final_loss: 0.2133 - stability_loss: 0.1100 - product_loss: 0.2328 - P_final_mae: 0.1214 - q_final_mae: 0.1105 - stability_accuracy: 0.9510 - product_accuracy: 0.9152 - val_loss: 0.4346 - val_P_final_loss: 0.0372 - val_q_final_loss: 0.1034 - val_stability_loss: 0.0868 - val_product_loss: 0.2130 - val_P_final_mae: 0.1189 - val_q_final_mae: 0.1176 - val_stability_accuracy: 0.9680 - val_product_accuracy: 0.9232
Epoch 00053: ReduceLROnPlateau reducing learning rate to 3.200000264769187e-07.
Epoch 00053: early stopping
Training results
target | mean | training score | test score |
---|---|---|---|
P_final | 444.678 | 60.521 | 59.283 |
q_final | 0.661 | 0.067 | 0.071 |
stability | - | 95.1% | 96.8% |
product | - | 91.5% | 92.3% |
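With `reduce_lr=True`, the log shows Keras' `ReduceLROnPlateau` callback lowering the learning rate whenever the validation loss stalls, until early stopping ends training at epoch 53. The logged values are consistent with a starting rate of 0.001 (the Keras default for Adam) multiplied by 0.2 at each reduction. This factor is inferred from the log, not taken from the NNaPS source. A quick check:

```python
# Learning rates reported by ReduceLROnPlateau in the training log above
logged = [
    0.00020000000949949026,   # epoch 21
    4.0000001899898055e-05,   # epoch 38
    8.000000525498762e-06,    # epoch 43
    1.6000001778593287e-06,   # epoch 48
    3.200000264769187e-07,    # epoch 53
]

initial_lr = 1e-3   # assumed: Keras' default Adam learning rate
factor = 0.2        # inferred: ratio between consecutive logged values

expected = [initial_lr * factor ** k for k in range(1, len(logged) + 1)]

for got, want in zip(logged, expected):
    # the logged values are float32, so allow a small relative difference
    print(f"{got:.3e}  vs  {want:.3e}  rel. diff {abs(got - want) / want:.1e}")
```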
We can now look at how the training progressed and at the final scores for each target. These are the accuracy for the classifiers and the mean absolute error for the regressors, in the same units as the target:
[14]:
pl.figure(figsize=(15,7))
predictor.print_score()
predictor.plot_training_history()
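You may want the same kind of scores for another set of MESA models. This minimal sketch uses scikit-learn. It assumes that `predictor.predict` returns a DataFrame with one column per target, as the plotting code further down relies on. The helper name `score_targets` and the toy frames are hypothetical. The sketch also does not cover whether NNaPS restricts the regressor scores to stable systems:

```python
import pandas as pd
from sklearn.metrics import accuracy_score, mean_absolute_error

def score_targets(truth, predicted,
                  regressors=('P_final', 'q_final'),
                  classifiers=('stability', 'product')):
    """Hypothetical helper: MAE for regression targets, accuracy for class labels."""
    scores = {}
    for col in regressors:
        scores[col] = mean_absolute_error(truth[col], predicted[col])
    for col in classifiers:
        scores[col] = accuracy_score(truth[col], predicted[col])
    return scores

# Toy example with made-up values, only to show the call
truth = pd.DataFrame({'P_final': [500.0, 800.0], 'q_final': [0.6, 0.7],
                      'stability': ['stable', 'stable'], 'product': ['sdB', 'HB']})
predicted = pd.DataFrame({'P_final': [520.0, 760.0], 'q_final': [0.62, 0.65],
                          'stability': ['stable', 'unstable'], 'product': ['sdB', 'HB']})
print(score_targets(truth, predicted))
```

With real data, the call would look like `score_targets(test_set, predictor.predict(test_set))`, where `test_set` is any DataFrame with the feature and target columns.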
For the classifiers we can also plot a confusion matrix:
[15]:
pl.figure(figsize=(16, 8))
predictor.plot_confusion_matrix()
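A confusion matrix counts, for each true class (rows), how often each class was predicted (columns). The diagonal holds the correct predictions. NNaPS draws it with `plot_confusion_matrix`. A minimal way to tabulate counts of this kind yourself is `pandas.crosstab`, shown here on made-up labels:

```python
import pandas as pd

# Toy labels (made up) standing in for true and predicted 'product' values
true = pd.Series(['sdB', 'HB', 'He-WD', 'He-WD', 'UK', 'sdB'], name='true')
pred = pd.Series(['sdB', 'sdB', 'He-WD', 'HB', 'UK', 'sdB'], name='predicted')

# Rows: true class, columns: predicted class, cells: number of systems
matrix = pd.crosstab(true, pred)
print(matrix)

# Fraction of correctly classified systems (diagonal sum divided by the total)
accuracy = (true == pred).mean()
```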
Making predictions
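Next, download four synthetic populations with different metallicity distributions: solar ([Fe/H] = 0.0), [Fe/H] = -1.0, uniform, and Galactic. We also load a table of observed orbital periods and mass ratios, with their uncertainties, to compare the predictions against.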
[16]:
sample_m1 = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/Sample_FeH_m1.csv')
sample_solar = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/Sample_FeH_solar.csv')
sample_uniform = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/Sample_FeH_uniform.csv')
sample_besancon = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/Sample_FeH_besancon.csv')
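The downloaded samples are passed to `predictor.predict`, so they need to provide the four feature columns. To try a metallicity distribution of your own, one option is to build such a table yourself. This minimal sketch draws M1_init, q_init and P_init uniformly within the range covered by the training models, because the network is unlikely to extrapolate reliably outside that range, and fixes [Fe/H]. The helper `make_sample` is hypothetical. Uniform draws are a simplifying assumption and do not reproduce how the tutorial's samples were made:

```python
import numpy as np
import pandas as pd

def make_sample(training, feh, n=10_000, seed=0,
                features=('M1_init', 'q_init', 'P_init')):
    """Hypothetical helper: uniform draws inside the training range, fixed [Fe/H]."""
    rng = np.random.default_rng(seed)
    sample = {}
    for col in features:
        lo, hi = training[col].min(), training[col].max()
        sample[col] = rng.uniform(lo, hi, size=n)
    sample['FeH_init'] = np.full(n, feh)
    return pd.DataFrame(sample)

# Tiny stand-in for the training table (values from the first rows shown above)
training = pd.DataFrame({'M1_init': [1.791, 0.954, 1.297],
                         'q_init': [1.573813, 2.687324, 1.499421],
                         'P_init': [3.570104, 410.120037, 371.590147]})
sample = make_sample(training, feh=-0.5, n=1000)
print(sample.describe())
```

In the notebook, this would be `make_sample(data, feh=-0.5)` followed by `predictor.predict(...)`, assuming `predict` accepts any DataFrame with these columns.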
[17]:
observations = pd.read_csv('http://www.astro.physik.uni-potsdam.de/~jorisvos/nnaps/observations_p_q.csv')
[18]:
pred_m1 = predictor.predict(sample_m1)
pred_solar = predictor.predict(sample_solar)
pred_uniform = predictor.predict(sample_uniform)
pred_besancon = predictor.predict(sample_besancon)
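The metallicity distribution can also change how many systems end up as each product. This small sketch tabulates the predicted product fractions side by side. It assumes that each prediction DataFrame has a `product` column, as used by `plot_set`. The helper name is hypothetical, and the toy frames only demonstrate the call:

```python
import pandas as pd

def product_fractions(predictions):
    """Hypothetical helper: fraction of each predicted product per sample."""
    columns = {name: pred['product'].value_counts(normalize=True)
               for name, pred in predictions.items()}
    # Products missing from a sample get a fraction of 0
    return pd.DataFrame(columns).fillna(0.0)

# Toy stand-ins; in the notebook pass {'solar': pred_solar, '-1.0': pred_m1, ...}
toy = {
    'A': pd.DataFrame({'product': ['sdB', 'He-WD', 'He-WD', 'UK']}),
    'B': pd.DataFrame({'product': ['sdB', 'sdB', 'HB', 'UK']}),
}
print(product_fractions(toy))
```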
[19]:
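def plot_set(data, ax):
    """Hexbin density of the predicted sdB systems in the P-q plane, with the observations overplotted."""
    ax.set_facecolor('w')
    # keep only the systems predicted to become hot subdwarfs (sdB)
    s = data[data['product'] == 'sdB']
    ax.hexbin(s['P_final'], s['q_final'], gridsize=100, cmap='Greys', vmax=5)
    # observed systems with their uncertainties
    ax.errorbar(observations['P'], observations['q'], xerr=observations['P_e'], yerr=observations['q_e'],
                marker='+', color='b', ls='', ms=10, mew=3, label='Obs')
    ax.set_xlim([200, 1750])
    ax.set_ylim([0.3, 0.9])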
pl.figure(1, figsize=(12, 10))
pl.subplots_adjust(left=0.07, right=0.98, bottom=0.07, top=0.95)
ax = pl.subplot(221)
plot_set(pred_solar, ax)
pl.title('[Fe/H] = 0.0')
ax = pl.subplot(222)
plot_set(pred_m1, ax)
pl.title('[Fe/H] = -1.0')
pl.plot([0], [0], 's', color='gray', label='Model')
pl.legend()
ax = pl.subplot(223)
plot_set(pred_uniform, ax)
pl.title('[Fe/H] = Uniform')
ax = pl.subplot(224)
plot_set(pred_besancon, ax)
pl.title('[Fe/H] = Galactic')
_ = pl.figtext(0.53, 0.01, 'Period (days)', ha='center', size=18)
_ = pl.figtext(0.01, 0.54, 'Mass ratio (M1 / M2)', va='center', size=18, rotation='vertical')