PyBADS Example 4: Noisy objective with user-provided noise estimates


In this example, we will show how to run PyBADS on a noisy target for which we can estimate the noise at each evaluation.

This notebook is Part 4 of a series of notebooks in which we present various example usages for BADS with the PyBADS package. The code used in this example is available as a script here.

import numpy as np
from pybads import BADS

1. Problem setup

We assume you are already familiar with optimization of noisy targets with PyBADS, as described in the previous notebook.

Sometimes you can estimate the noise associated with each function evaluation, for example via bootstrap or more sophisticated estimation methods such as inverse binomial sampling. If you can, we highly recommend doing so and telling PyBADS by activating the specify_target_noise option.
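As a rough illustration of the bootstrap route, the sketch below estimates the standard deviation of an objective that is assumed to be a sum of independent per-trial contributions (for instance, per-trial negative log-likelihood estimates). The names bootstrap_sd and per_trial are hypothetical and not part of PyBADS; the data are synthetic.

```python
import numpy as np

def bootstrap_sd(per_trial_values, n_boot=200, rng=None):
    """Bootstrap estimate of the SD of the sum of per-trial values.

    Hypothetical helper: assumes the objective is a sum of independent
    per-trial contributions.
    """
    rng = np.random.default_rng(rng)
    values = np.asarray(per_trial_values, dtype=float)
    n = values.size
    # Resample trials with replacement and recompute the total each time.
    idx = rng.integers(0, n, size=(n_boot, n))
    totals = values[idx].sum(axis=1)
    return totals.std(ddof=1)

# Synthetic per-trial contributions standing in for one noisy evaluation.
data_rng = np.random.default_rng(0)
per_trial = data_rng.normal(loc=1.0, scale=0.5, size=400)

estimate = per_trial.sum()                       # the (noisy) function value
sd = bootstrap_sd(per_trial, n_boot=500, rng=1)  # its estimated SD
print(f"estimate = {estimate:.2f}, bootstrap SD = {sd:.2f}")
```

A target function built this way would return the pair (estimate, sd), which is the two-output form described next.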

In the user-specified noise case, the target function fun is expected to return two outputs:

  • the (noisy) estimate of the function at x (as usual);

  • the estimate of the standard deviation of the noisy evaluation at x.

In the toy example below, we know the standard deviation sigma by construction. Note that the function is heteroskedastic: the noise depends on the input location x.

def noisy_sphere_estimated_noise(x, scale=1.0):
    """Quadratic function with heteroskedastic noise; also return noise estimate."""
    x_2d = np.atleast_2d(x)                  # accept a single point or a batch of points
    f = np.sum(x_2d**2, axis=1)              # noiseless sphere function
    sigma = scale * (1.0 + np.sqrt(f))       # noise SD grows with distance from the origin
    y = f + sigma * np.random.normal(size=x_2d.shape[0])
    return y, sigma

x0 = np.array([-3, -3])      # Starting point
lower_bounds = np.array([-5, -5])
upper_bounds = np.array([5, 5])
plausible_lower_bounds = np.array([-2, -2])
plausible_upper_bounds = np.array([2, 2])

options = {
    "uncertainty_handling": True,
    "specify_target_noise": True,
    "noise_final_samples": 100
}
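Before running the optimizer, it can be worth checking that the noise estimate returned by the target really matches the spread of its evaluations. The snippet below is a minimal sanity check, not part of the original notebook: it redefines the toy function so that it runs on its own, draws repeated evaluations at the starting point, and compares their empirical standard deviation with the reported sigma. The names x_test, samples and empirical_sd are illustrative.

```python
import numpy as np

def noisy_sphere_estimated_noise(x, scale=1.0):
    """Same toy target as above, repeated here so the check is self-contained."""
    x_2d = np.atleast_2d(x)
    f = np.sum(x_2d**2, axis=1)
    sigma = scale * (1.0 + np.sqrt(f))
    y = f + sigma * np.random.normal(size=x_2d.shape[0])
    return y, sigma

np.random.seed(0)
x_test = np.array([-3.0, -3.0])  # same location as the starting point x0

# Repeated noisy evaluations at a single point.
samples = np.array(
    [noisy_sphere_estimated_noise(x_test)[0][0] for _ in range(5000)]
)
_, sigma = noisy_sphere_estimated_noise(x_test)

empirical_sd = samples.std(ddof=1)
print(f"reported sigma = {sigma[0]:.3f}, empirical SD = {empirical_sd:.3f}")
```

At this point the reported noise is 1 + sqrt(18), roughly 5.24, which is the value shown in the SD[f(x)] column for the first evaluation in the optimization trace. For a real target with an estimated (rather than known) noise level, a large mismatch in a check like this would suggest the estimator needs revisiting.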

2. Run the optimization

As usual, we run bads with the user-defined options.

bads = BADS(
    noisy_sphere_estimated_noise, x0,
    lower_bounds, upper_bounds,
    plausible_lower_bounds, plausible_upper_bounds,
    options=options
)
optimize_result = bads.optimize()
Beginning optimization of a STOCHASTIC objective function (specified noise)

 Iteration    f-count      E[f(x)]        SD[f(x)]           MeshScale          Method              Actions
     0           1         27.5352         5.24264               1                                  
     0          33        -2.91026         5.24264               1          Initial mesh            Initial points
     0          37        -2.91026         2.90933             0.5          Refine grid             Train
     1          45        -2.91026         2.90933            0.25          Refine grid             Train
     2          46       -0.763687        0.868669            0.25      Successful search (ES-wcm)        
     2          57       -0.763687        0.868669           0.125          Refine grid             Train
     3          58        0.275083        0.508241           0.125      Successful search (ES-ell)        
     3          69        0.275083        0.508241          0.0625          Refine grid             Train
     4          70        0.836853        0.340929          0.0625      Successful search (ES-ell)        
     4          81        0.836853        0.340929         0.03125          Refine grid             Train
     5          82          1.4029        0.277549         0.03125      Successful search (ES-ell)        
     5          83         1.35267        0.401613         0.03125      Successful search (ES-wcm)        
     5          84         1.11699         0.55322         0.03125      Successful search (ES-wcm)        
     5          85        0.282006        0.579217         0.03125      Successful search (ES-wcm)        
     5          86       0.0522514        0.534398         0.03125      Successful search (ES-wcm)        
     5          88      -0.0275186        0.471638         0.03125      Successful search (ES-wcm)        
     5          97      -0.0275186        0.471638        0.015625          Refine grid             
     6          98          0.1502        0.247952        0.015625      Successful search (ES-wcm)        
     6          99     -0.00438043        0.233882        0.015625      Successful search (ES-wcm)        
     6         100      -0.0754532        0.227039        0.015625      Successful search (ES-wcm)        
     6         101       -0.134933        0.231195        0.015625      Successful search (ES-wcm)        
     6         109       -0.147044        0.193437         0.03125        Successful poll           
     7         117       -0.147044        0.193437        0.015625          Refine grid             
     8         121      -0.0511507        0.165127        0.015625      Incremental search (ES-ell)        
     8         122      -0.0929799         0.16533         0.03125        Successful poll           Train
     9         130       -0.136204        0.147907          0.0625        Successful poll           Train
    10         132       -0.136762        0.159908          0.0625      Incremental search (ES-wcm)        
    10         133       -0.143109        0.160052          0.0625      Incremental search (ES-ell)        
    10         134        -0.15408        0.159411          0.0625      Incremental search (ES-ell)        
    10         138       -0.165267        0.143339         0.03125          Refine grid             
    11         146       -0.142635        0.138263        0.015625          Refine grid             Train
    12         154       -0.103103        0.129585      0.00390625          Refine grid             Train
    13         156       -0.036772        0.122827      0.00390625      Incremental search (ES-wcm)        
    13         157      -0.0412922        0.121273      0.00390625      Successful search (ES-wcm)        
    13         158      -0.0521136        0.120051      0.00390625      Successful search (ES-wcm)        
    13         162      -0.0544248         0.11692      0.00390625      Successful search (ES-ell)        
    13         165       -0.066474        0.114724      0.00390625      Successful search (ES-ell)        
    13         166      -0.0730774        0.114275      0.00390625      Successful search (ES-ell)        
    13         168      -0.0737546        0.113082      0.00390625      Incremental search (ES-ell)        
    13         169      -0.0910154        0.113536      0.00390625      Successful search (ES-ell)        
    13         173      -0.0921021        0.111431      0.00390625      Successful search (ES-wcm)        
    13         174      -0.0921595        0.111604      0.00390625      Incremental search (ES-wcm)        
    13         182      -0.0921595        0.111604      0.00195312          Refine grid             
    14         183      -0.0564883        0.103933      0.00195312      Successful search (ES-wcm)        
    14         185      -0.0603416        0.103084      0.00195312      Successful search (ES-wcm)        
    14         186       -0.063154        0.102919      0.00195312      Successful search (ES-wcm)        
    14         194      -0.0686494       0.0990548      0.00390625        Successful poll           
    15         197      -0.0724876       0.0980083      0.00390625      Successful search (ES-wcm)        
    15         199      -0.0747856       0.0982266      0.00390625      Successful search (ES-ell)        
    15         204       -0.082064        0.096049      0.00390625      Successful search (ES-ell)        
    15         205      -0.0842201       0.0948564      0.00390625      Successful search (ES-ell)        
    15         206      -0.0855717       0.0946394      0.00390625      Successful search (ES-ell)        
    15         207      -0.0970346       0.0951909      0.00390625      Successful search (ES-ell)        
    15         210       -0.157229        0.234229      0.00390625      Successful search (ES-ell)        
    15         211       -0.174686         0.26718      0.00390625      Successful search (ES-ell)        
    15         222       -0.174686         0.26718      0.00195312          Refine grid             Train
    16         230      -0.0871417        0.189485     0.000488281          Refine grid             Train
    17         238      -0.0621804        0.092649     0.000244141          Refine grid             Train
    18         239      -0.0540518        0.109407     0.000244141      Successful search (ES-wcm)        
    18         240       -0.069655        0.109258     0.000244141      Successful search (ES-wcm)        
    18         241      -0.0697666        0.109354     0.000244141      Incremental search (ES-wcm)        
    18         242      -0.0788053        0.110422     0.000244141      Successful search (ES-wcm)        
    18         243       -0.100884        0.135117     0.000244141      Successful search (ES-wcm)        
    18         244       -0.132902        0.135261     0.000244141      Successful search (ES-wcm)        
    18         254       -0.132902        0.135261      0.00012207          Refine grid             Train
    19         262      -0.0550156        0.130951     0.000244141        Successful poll           Train
    20         266      -0.0294325        0.102518     0.000244141      Successful search (ES-wcm)        
    20         274      -0.0294325        0.102518      0.00012207          Refine grid             
    21         282     -0.00985551       0.0988365     6.10352e-05          Refine grid             
    22         283      0.00560284        0.095124     6.10352e-05      Successful search (ES-wcm)        
    22         286      0.00267631       0.0963349     6.10352e-05      Successful search (ES-wcm)        
    22         294      0.00267631       0.0963349     3.05176e-05          Refine grid             
    23         296       0.0131204       0.0933549     3.05176e-05      Successful search (ES-wcm)        
    23         297      0.00549339       0.0931934     3.05176e-05      Successful search (ES-wcm)        
    23         298      -0.0600505        0.226728     3.05176e-05      Successful search (ES-wcm)        
    23         299      -0.0945269        0.221667     3.05176e-05      Successful search (ES-wcm)        
    23         300      -0.0975214        0.217002     3.05176e-05      Successful search (ES-wcm)        
    23         303       -0.104548        0.212268     3.05176e-05      Successful search (ES-wcm)        
    23         304       -0.144592        0.208421     3.05176e-05      Successful search (ES-wcm)        
    23         305        -0.20207        0.204639     3.05176e-05      Successful search (ES-wcm)        
    23         314        -0.20207        0.204639     1.52588e-05          Refine grid             
    24         315       -0.103618        0.140386     1.52588e-05      Successful search (ES-wcm)        
    24         316       -0.134944        0.148889     1.52588e-05      Successful search (ES-wcm)        
    24         320       -0.510266        0.306413     1.52588e-05      Successful search (ES-wcm)        
    24         330       -0.510266        0.306413     7.62939e-06          Refine grid             
    25         331       -0.232391        0.214288     7.62939e-06      Successful search (ES-ell)        
    25         342       -0.232391        0.214288      3.8147e-06          Refine grid             Train
    26         350       -0.135337        0.202093     1.90735e-06          Refine grid             Train
    27         358      -0.0627757        0.195314     4.76837e-07          Refine grid             Train
Optimization terminated: change in the function value less than options['tol_mesh']
Estimated function value at minimum: -0.009079534944637901 ± 0.116836714211856 (mean ± SEM from 100 samples)

3. Results and conclusions

As in the basic noisy-target case, optimize_result['fval'] is the estimated function value at optimize_result['x'], here obtained by taking the weighted mean of the final sampled evaluations (each evaluation is weighted by its precision, or inverse variance).

The final samples can be found in optimize_result['yval_vec'], and their estimated standard deviation in optimize_result['ysd_vec'].

x_min = optimize_result['x']
fval = optimize_result['fval']
fsd = optimize_result['fsd']

print(f"BADS minimum at: x_min = {x_min.flatten()}, fval (estimated) = {fval:.4g} +/- {fsd:.2g}")
print(f"total f-count: {optimize_result['func_count']}, time: {round(optimize_result['total_time'], 2)} s")
print(f"final evaluations (shape): {optimize_result['yval_vec'].shape}")
print(f"final evaluations SD (shape): {optimize_result['ysd_vec'].shape}")
BADS minimum at: x_min = [-0.11760458  0.02207676], fval (estimated) = -0.00908 +/- 0.12
total f-count: 458, time: 17.98 s
final evaluations (shape): (100,)
final evaluations SD (shape): (100,)
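The precision weighting mentioned above can be sketched in a few lines. This is the textbook inverse-variance weighted mean and its standard error, applied to synthetic stand-ins for yval_vec and ysd_vec; it is an illustration of the idea, and the exact computation PyBADS uses for fval and fsd may differ. The helper name precision_weighted_mean is hypothetical.

```python
import numpy as np

def precision_weighted_mean(y, sd):
    """Inverse-variance weighted mean and its standard error (textbook formula)."""
    y = np.asarray(y, dtype=float)
    w = 1.0 / np.asarray(sd, dtype=float) ** 2   # precision of each evaluation
    mean = np.sum(w * y) / np.sum(w)
    sem = np.sqrt(1.0 / np.sum(w))
    return mean, sem

# Synthetic stand-ins for optimize_result['yval_vec'] and ['ysd_vec'].
rng = np.random.default_rng(0)
ysd_vec = np.full(100, 1.12)
yval_vec = ysd_vec * rng.standard_normal(100)    # true value 0, noise SD 1.12

mean, sem = precision_weighted_mean(yval_vec, ysd_vec)
print(f"weighted mean = {mean:.3f} +/- {sem:.3f}")
```

When all evaluations share the same noise level, as in this synthetic case, the weighted mean reduces to the ordinary mean and the standard error to sd / sqrt(n). The weighting only matters when the noise estimates differ between samples, in which case noisier evaluations contribute less.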

We can also check the ground-truth value of the target function at the returned point once we remove the noise:

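# With scale=0 the noise term vanishes, so the first output is the exact function value.
f_true = noisy_sphere_estimated_noise(x_min, scale=0)[0][0]
print(f"The true, noiseless value of f(x_min) is {f_true:.3g}.")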
The true, noiseless value of f(x_min) is 0.0143.

Compare this to the true global minimum of the sphere function at \(\textbf{x}^\star = [0,0]\), where \(f^\star = 0\).

Remarks

  • Due to the elevated level of noise, we do not necessarily expect high precision in the solution.

  • For more information on optimizing noisy objective functions, see the BADS wiki: https://github.com/acerbilab/bads/wiki#noisy-objective-function (this link points to the MATLAB wiki, but many of the questions and answers apply to PyBADS as well).